default_tools.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
import importlib.util
import json
import math
from dataclasses import dataclass
from math import sqrt
from typing import Any, Dict, Optional

from huggingface_hub import hf_hub_download, list_spaces

from ..utils import is_offline_mode
from .python_interpreter import LIST_SAFE_MODULES, evaluate_python_code
from .tools import TOOL_CONFIG_FILE, TOOL_MAPPING, Tool
  26. def custom_print(*args):
  27. return None
  28. BASE_PYTHON_TOOLS = {
  29. "print": custom_print,
  30. "isinstance": isinstance,
  31. "range": range,
  32. "float": float,
  33. "int": int,
  34. "bool": bool,
  35. "str": str,
  36. "set": set,
  37. "list": list,
  38. "dict": dict,
  39. "tuple": tuple,
  40. "round": round,
  41. "ceil": math.ceil,
  42. "floor": math.floor,
  43. "log": math.log,
  44. "exp": math.exp,
  45. "sin": math.sin,
  46. "cos": math.cos,
  47. "tan": math.tan,
  48. "asin": math.asin,
  49. "acos": math.acos,
  50. "atan": math.atan,
  51. "atan2": math.atan2,
  52. "degrees": math.degrees,
  53. "radians": math.radians,
  54. "pow": math.pow,
  55. "sqrt": sqrt,
  56. "len": len,
  57. "sum": sum,
  58. "max": max,
  59. "min": min,
  60. "abs": abs,
  61. "enumerate": enumerate,
  62. "zip": zip,
  63. "reversed": reversed,
  64. "sorted": sorted,
  65. "all": all,
  66. "any": any,
  67. "map": map,
  68. "filter": filter,
  69. "ord": ord,
  70. "chr": chr,
  71. "next": next,
  72. "iter": iter,
  73. "divmod": divmod,
  74. "callable": callable,
  75. "getattr": getattr,
  76. "hasattr": hasattr,
  77. "setattr": setattr,
  78. "issubclass": issubclass,
  79. "type": type,
  80. }
  81. @dataclass
  82. class PreTool:
  83. name: str
  84. inputs: Dict[str, str]
  85. output_type: type
  86. task: str
  87. description: str
  88. repo_id: str
  89. HUGGINGFACE_DEFAULT_TOOLS_FROM_HUB = [
  90. "image-transformation",
  91. "text-to-image",
  92. ]
  93. def get_remote_tools(logger, organization="huggingface-tools"):
  94. if is_offline_mode():
  95. logger.info("You are in offline mode, so remote tools are not available.")
  96. return {}
  97. spaces = list_spaces(author=organization)
  98. tools = {}
  99. for space_info in spaces:
  100. repo_id = space_info.id
  101. resolved_config_file = hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space")
  102. with open(resolved_config_file, encoding="utf-8") as reader:
  103. config = json.load(reader)
  104. task = repo_id.split("/")[-1]
  105. tools[config["name"]] = PreTool(
  106. task=task,
  107. description=config["description"],
  108. repo_id=repo_id,
  109. name=task,
  110. inputs=config["inputs"],
  111. output_type=config["output_type"],
  112. )
  113. return tools
  114. def setup_default_tools(logger):
  115. default_tools = {}
  116. main_module = importlib.import_module("transformers")
  117. tools_module = main_module.agents
  118. for task_name, tool_class_name in TOOL_MAPPING.items():
  119. tool_class = getattr(tools_module, tool_class_name)
  120. tool_instance = tool_class()
  121. default_tools[tool_class.name] = PreTool(
  122. name=tool_instance.name,
  123. inputs=tool_instance.inputs,
  124. output_type=tool_instance.output_type,
  125. task=task_name,
  126. description=tool_instance.description,
  127. repo_id=None,
  128. )
  129. return default_tools
  130. class PythonInterpreterTool(Tool):
  131. name = "python_interpreter"
  132. description = "This is a tool that evaluates python code. It can be used to perform calculations."
  133. output_type = "string"
  134. def __init__(self, *args, authorized_imports=None, **kwargs):
  135. if authorized_imports is None:
  136. self.authorized_imports = list(set(LIST_SAFE_MODULES))
  137. else:
  138. self.authorized_imports = list(set(LIST_SAFE_MODULES) | set(authorized_imports))
  139. self.inputs = {
  140. "code": {
  141. "type": "string",
  142. "description": (
  143. "The code snippet to evaluate. All variables used in this snippet must be defined in this same snippet, "
  144. f"else you will get an error. This code can only import the following python libraries: {authorized_imports}."
  145. ),
  146. }
  147. }
  148. super().__init__(*args, **kwargs)
  149. def forward(self, code):
  150. output = str(
  151. evaluate_python_code(code, static_tools=BASE_PYTHON_TOOLS, authorized_imports=self.authorized_imports)
  152. )
  153. return output
  154. class FinalAnswerTool(Tool):
  155. name = "final_answer"
  156. description = "Provides a final answer to the given problem."
  157. inputs = {"answer": {"type": "any", "description": "The final answer to the problem"}}
  158. output_type = "any"
  159. def forward(self, answer):
  160. return answer