tools.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880
  1. #!/usr/bin/env python
  2. # coding=utf-8
  3. # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import base64
  17. import importlib
  18. import inspect
  19. import io
  20. import json
  21. import os
  22. import tempfile
  23. from functools import lru_cache, wraps
  24. from typing import Any, Callable, Dict, List, Optional, Union
  25. from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
  26. from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
  27. from packaging import version
  28. from ..dynamic_module_utils import (
  29. custom_object_save,
  30. get_class_from_dynamic_module,
  31. get_imports,
  32. )
  33. from ..models.auto import AutoProcessor
  34. from ..utils import (
  35. CONFIG_NAME,
  36. TypeHintParsingException,
  37. cached_file,
  38. get_json_schema,
  39. is_accelerate_available,
  40. is_torch_available,
  41. is_vision_available,
  42. logging,
  43. )
  44. from .agent_types import handle_agent_inputs, handle_agent_outputs
  45. logger = logging.get_logger(__name__)
  46. if is_torch_available():
  47. import torch
  48. if is_accelerate_available():
  49. from accelerate import PartialState
  50. from accelerate.utils import send_to_device
  51. TOOL_CONFIG_FILE = "tool_config.json"
  52. def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
  53. if repo_type is not None:
  54. return repo_type
  55. try:
  56. hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
  57. return "space"
  58. except RepositoryNotFoundError:
  59. try:
  60. hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
  61. return "model"
  62. except RepositoryNotFoundError:
  63. raise EnvironmentError(f"`{repo_id}` does not seem to be a valid repo identifier on the Hub.")
  64. except Exception:
  65. return "model"
  66. except Exception:
  67. return "space"
  68. # docstyle-ignore
  69. APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
  70. from {module_name} import {class_name}
  71. launch_gradio_demo({class_name})
  72. """
  73. def validate_after_init(cls):
  74. original_init = cls.__init__
  75. @wraps(original_init)
  76. def new_init(self, *args, **kwargs):
  77. original_init(self, *args, **kwargs)
  78. if not isinstance(self, PipelineTool):
  79. self.validate_arguments()
  80. cls.__init__ = new_init
  81. return cls
  82. @validate_after_init
  83. class Tool:
  84. """
  85. A base class for the functions used by the agent. Subclass this and implement the `__call__` method as well as the
  86. following class attributes:
  87. - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
  88. will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
  89. returns the text contained in the file'.
  90. - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
  91. `"text-classifier"` or `"image_generator"`.
  92. - **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
  93. It has one `type`key and a `description`key.
  94. This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
  95. description for your tool.
  96. - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
  97. or to make a nice space from your tool, and also can be used in the generated description for your tool.
  98. You can also override the method [`~Tool.setup`] if your tool as an expensive operation to perform before being
  99. usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
  100. instantiation.
  101. """
  102. name: str
  103. description: str
  104. inputs: Dict[str, Dict[str, Union[str, type]]]
  105. output_type: type
  106. def __init__(self, *args, **kwargs):
  107. self.is_initialized = False
  108. def validate_arguments(self):
  109. required_attributes = {
  110. "description": str,
  111. "name": str,
  112. "inputs": Dict,
  113. "output_type": str,
  114. }
  115. authorized_types = ["string", "integer", "number", "image", "audio", "any", "boolean"]
  116. for attr, expected_type in required_attributes.items():
  117. attr_value = getattr(self, attr, None)
  118. if not isinstance(attr_value, expected_type):
  119. raise TypeError(f"You must set an attribute {attr} of type {expected_type.__name__}.")
  120. for input_name, input_content in self.inputs.items():
  121. assert "type" in input_content, f"Input '{input_name}' should specify a type."
  122. if input_content["type"] not in authorized_types:
  123. raise Exception(
  124. f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {authorized_types}."
  125. )
  126. assert "description" in input_content, f"Input '{input_name}' should have a description."
  127. assert getattr(self, "output_type", None) in authorized_types
  128. if not isinstance(self, PipelineTool):
  129. signature = inspect.signature(self.forward)
  130. if not set(signature.parameters.keys()) == set(self.inputs.keys()):
  131. raise Exception(
  132. "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
  133. )
  134. def forward(self, *args, **kwargs):
  135. return NotImplemented("Write this method in your subclass of `Tool`.")
  136. def __call__(self, *args, **kwargs):
  137. args, kwargs = handle_agent_inputs(*args, **kwargs)
  138. outputs = self.forward(*args, **kwargs)
  139. return handle_agent_outputs(outputs, self.output_type)
  140. def setup(self):
  141. """
  142. Overwrite this method here for any operation that is expensive and needs to be executed before you start using
  143. your tool. Such as loading a big model.
  144. """
  145. self.is_initialized = True
  146. def save(self, output_dir):
  147. """
  148. Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
  149. tool in `output_dir` as well as autogenerate:
  150. - a config file named `tool_config.json`
  151. - an `app.py` file so that your tool can be converted to a space
  152. - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
  153. code)
  154. You should only use this method to save tools that are defined in a separate module (not `__main__`).
  155. Args:
  156. output_dir (`str`): The folder in which you want to save your tool.
  157. """
  158. os.makedirs(output_dir, exist_ok=True)
  159. # Save module file
  160. if self.__module__ == "__main__":
  161. raise ValueError(
  162. f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
  163. "have to put this code in a separate module so we can include it in the saved folder."
  164. )
  165. module_files = custom_object_save(self, output_dir)
  166. module_name = self.__class__.__module__
  167. last_module = module_name.split(".")[-1]
  168. full_name = f"{last_module}.{self.__class__.__name__}"
  169. # Save config file
  170. config_file = os.path.join(output_dir, "tool_config.json")
  171. if os.path.isfile(config_file):
  172. with open(config_file, "r", encoding="utf-8") as f:
  173. tool_config = json.load(f)
  174. else:
  175. tool_config = {}
  176. tool_config = {
  177. "tool_class": full_name,
  178. "description": self.description,
  179. "name": self.name,
  180. "inputs": self.inputs,
  181. "output_type": str(self.output_type),
  182. }
  183. with open(config_file, "w", encoding="utf-8") as f:
  184. f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
  185. # Save app file
  186. app_file = os.path.join(output_dir, "app.py")
  187. with open(app_file, "w", encoding="utf-8") as f:
  188. f.write(APP_FILE_TEMPLATE.format(module_name=last_module, class_name=self.__class__.__name__))
  189. # Save requirements file
  190. requirements_file = os.path.join(output_dir, "requirements.txt")
  191. imports = []
  192. for module in module_files:
  193. imports.extend(get_imports(module))
  194. imports = list(set(imports))
  195. with open(requirements_file, "w", encoding="utf-8") as f:
  196. f.write("\n".join(imports) + "\n")
  197. @classmethod
  198. def from_hub(
  199. cls,
  200. repo_id: str,
  201. model_repo_id: Optional[str] = None,
  202. token: Optional[str] = None,
  203. **kwargs,
  204. ):
  205. """
  206. Loads a tool defined on the Hub.
  207. <Tip warning={true}>
  208. Loading a tool from the Hub means that you'll download the tool and execute it locally.
  209. ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
  210. installing a package using pip/npm/apt.
  211. </Tip>
  212. Args:
  213. repo_id (`str`):
  214. The name of the repo on the Hub where your tool is defined.
  215. model_repo_id (`str`, *optional*):
  216. If your tool uses a model and you want to use a different model than the default, you can pass a second
  217. repo ID or an endpoint url to this argument.
  218. token (`str`, *optional*):
  219. The token to identify you on hf.co. If unset, will use the token generated when running
  220. `huggingface-cli login` (stored in `~/.huggingface`).
  221. kwargs (additional keyword arguments, *optional*):
  222. Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
  223. `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
  224. others will be passed along to its init.
  225. """
  226. hub_kwargs_names = [
  227. "cache_dir",
  228. "force_download",
  229. "resume_download",
  230. "proxies",
  231. "revision",
  232. "repo_type",
  233. "subfolder",
  234. "local_files_only",
  235. ]
  236. hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
  237. # Try to get the tool config first.
  238. hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
  239. resolved_config_file = cached_file(
  240. repo_id,
  241. TOOL_CONFIG_FILE,
  242. token=token,
  243. **hub_kwargs,
  244. _raise_exceptions_for_gated_repo=False,
  245. _raise_exceptions_for_missing_entries=False,
  246. _raise_exceptions_for_connection_errors=False,
  247. )
  248. is_tool_config = resolved_config_file is not None
  249. if resolved_config_file is None:
  250. resolved_config_file = cached_file(
  251. repo_id,
  252. CONFIG_NAME,
  253. token=token,
  254. **hub_kwargs,
  255. _raise_exceptions_for_gated_repo=False,
  256. _raise_exceptions_for_missing_entries=False,
  257. _raise_exceptions_for_connection_errors=False,
  258. )
  259. if resolved_config_file is None:
  260. raise EnvironmentError(
  261. f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
  262. )
  263. with open(resolved_config_file, encoding="utf-8") as reader:
  264. config = json.load(reader)
  265. if not is_tool_config:
  266. if "custom_tool" not in config:
  267. raise EnvironmentError(
  268. f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
  269. )
  270. custom_tool = config["custom_tool"]
  271. else:
  272. custom_tool = config
  273. tool_class = custom_tool["tool_class"]
  274. tool_class = get_class_from_dynamic_module(tool_class, repo_id, token=token, **hub_kwargs)
  275. if len(tool_class.name) == 0:
  276. tool_class.name = custom_tool["name"]
  277. if tool_class.name != custom_tool["name"]:
  278. logger.warning(
  279. f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
  280. "configuration name."
  281. )
  282. tool_class.name = custom_tool["name"]
  283. if len(tool_class.description) == 0:
  284. tool_class.description = custom_tool["description"]
  285. if tool_class.description != custom_tool["description"]:
  286. logger.warning(
  287. f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
  288. "tool configuration description."
  289. )
  290. tool_class.description = custom_tool["description"]
  291. if tool_class.inputs != custom_tool["inputs"]:
  292. tool_class.inputs = custom_tool["inputs"]
  293. if tool_class.output_type != custom_tool["output_type"]:
  294. tool_class.output_type = custom_tool["output_type"]
  295. return tool_class(**kwargs)
  296. def push_to_hub(
  297. self,
  298. repo_id: str,
  299. commit_message: str = "Upload tool",
  300. private: Optional[bool] = None,
  301. token: Optional[Union[bool, str]] = None,
  302. create_pr: bool = False,
  303. ) -> str:
  304. """
  305. Upload the tool to the Hub.
  306. For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
  307. For instance:
  308. ```
  309. from my_tool_module import MyTool
  310. my_tool = MyTool()
  311. my_tool.push_to_hub("my-username/my-space")
  312. ```
  313. Parameters:
  314. repo_id (`str`):
  315. The name of the repository you want to push your tool to. It should contain your organization name when
  316. pushing to a given organization.
  317. commit_message (`str`, *optional*, defaults to `"Upload tool"`):
  318. Message to commit while pushing.
  319. private (`bool`, *optional*):
  320. Whether or not the repository created should be private.
  321. token (`bool` or `str`, *optional*):
  322. The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
  323. when running `huggingface-cli login` (stored in `~/.huggingface`).
  324. create_pr (`bool`, *optional*, defaults to `False`):
  325. Whether or not to create a PR with the uploaded files or directly commit.
  326. """
  327. repo_url = create_repo(
  328. repo_id=repo_id,
  329. token=token,
  330. private=private,
  331. exist_ok=True,
  332. repo_type="space",
  333. space_sdk="gradio",
  334. )
  335. repo_id = repo_url.repo_id
  336. metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
  337. with tempfile.TemporaryDirectory() as work_dir:
  338. # Save all files.
  339. self.save(work_dir)
  340. logger.info(f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}")
  341. return upload_folder(
  342. repo_id=repo_id,
  343. commit_message=commit_message,
  344. folder_path=work_dir,
  345. token=token,
  346. create_pr=create_pr,
  347. repo_type="space",
  348. )
  349. @staticmethod
  350. def from_gradio(gradio_tool):
  351. """
  352. Creates a [`Tool`] from a gradio tool.
  353. """
  354. import inspect
  355. class GradioToolWrapper(Tool):
  356. def __init__(self, _gradio_tool):
  357. super().__init__()
  358. self.name = _gradio_tool.name
  359. self.description = _gradio_tool.description
  360. self.output_type = "string"
  361. self._gradio_tool = _gradio_tool
  362. func_args = list(inspect.signature(_gradio_tool.run).parameters.keys())
  363. self.inputs = {key: "" for key in func_args}
  364. def forward(self, *args, **kwargs):
  365. return self._gradio_tool.run(*args, **kwargs)
  366. return GradioToolWrapper(gradio_tool)
  367. @staticmethod
  368. def from_langchain(langchain_tool):
  369. """
  370. Creates a [`Tool`] from a langchain tool.
  371. """
  372. class LangChainToolWrapper(Tool):
  373. def __init__(self, _langchain_tool):
  374. super().__init__()
  375. self.name = _langchain_tool.name.lower()
  376. self.description = _langchain_tool.description
  377. self.inputs = parse_langchain_args(_langchain_tool.args)
  378. self.output_type = "string"
  379. self.langchain_tool = _langchain_tool
  380. def forward(self, *args, **kwargs):
  381. tool_input = kwargs.copy()
  382. for index, argument in enumerate(args):
  383. if index < len(self.inputs):
  384. input_key = next(iter(self.inputs))
  385. tool_input[input_key] = argument
  386. return self.langchain_tool.run(tool_input)
  387. return LangChainToolWrapper(langchain_tool)
  388. DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
  389. - {{ tool.name }}: {{ tool.description }}
  390. Takes inputs: {{tool.inputs}}
  391. Returns an output of type: {{tool.output_type}}
  392. """
  393. def get_tool_description_with_args(tool: Tool, description_template: str = DEFAULT_TOOL_DESCRIPTION_TEMPLATE) -> str:
  394. compiled_template = compile_jinja_template(description_template)
  395. rendered = compiled_template.render(
  396. tool=tool,
  397. )
  398. return rendered
  399. @lru_cache
  400. def compile_jinja_template(template):
  401. try:
  402. import jinja2
  403. from jinja2.exceptions import TemplateError
  404. from jinja2.sandbox import ImmutableSandboxedEnvironment
  405. except ImportError:
  406. raise ImportError("template requires jinja2 to be installed.")
  407. if version.parse(jinja2.__version__) < version.parse("3.1.0"):
  408. raise ImportError("template requires jinja2>=3.1.0 to be installed. Your version is " f"{jinja2.__version__}.")
  409. def raise_exception(message):
  410. raise TemplateError(message)
  411. jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
  412. jinja_env.globals["raise_exception"] = raise_exception
  413. return jinja_env.from_string(template)
  414. class PipelineTool(Tool):
  415. """
  416. A [`Tool`] tailored towards Transformer models. On top of the class attributes of the base class [`Tool`], you will
  417. need to specify:
  418. - **model_class** (`type`) -- The class to use to load the model in this tool.
  419. - **default_checkpoint** (`str`) -- The default checkpoint that should be used when the user doesn't specify one.
  420. - **pre_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
  421. pre-processor
  422. - **post_processor_class** (`type`, *optional*, defaults to [`AutoProcessor`]) -- The class to use to load the
  423. post-processor (when different from the pre-processor).
  424. Args:
  425. model (`str` or [`PreTrainedModel`], *optional*):
  426. The name of the checkpoint to use for the model, or the instantiated model. If unset, will default to the
  427. value of the class attribute `default_checkpoint`.
  428. pre_processor (`str` or `Any`, *optional*):
  429. The name of the checkpoint to use for the pre-processor, or the instantiated pre-processor (can be a
  430. tokenizer, an image processor, a feature extractor or a processor). Will default to the value of `model` if
  431. unset.
  432. post_processor (`str` or `Any`, *optional*):
  433. The name of the checkpoint to use for the post-processor, or the instantiated pre-processor (can be a
  434. tokenizer, an image processor, a feature extractor or a processor). Will default to the `pre_processor` if
  435. unset.
  436. device (`int`, `str` or `torch.device`, *optional*):
  437. The device on which to execute the model. Will default to any accelerator available (GPU, MPS etc...), the
  438. CPU otherwise.
  439. device_map (`str` or `dict`, *optional*):
  440. If passed along, will be used to instantiate the model.
  441. model_kwargs (`dict`, *optional*):
  442. Any keyword argument to send to the model instantiation.
  443. token (`str`, *optional*):
  444. The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated when
  445. running `huggingface-cli login` (stored in `~/.huggingface`).
  446. hub_kwargs (additional keyword arguments, *optional*):
  447. Any additional keyword argument to send to the methods that will load the data from the Hub.
  448. """
  449. pre_processor_class = AutoProcessor
  450. model_class = None
  451. post_processor_class = AutoProcessor
  452. default_checkpoint = None
  453. description = "This is a pipeline tool"
  454. name = "pipeline"
  455. inputs = {"prompt": str}
  456. output_type = str
  457. def __init__(
  458. self,
  459. model=None,
  460. pre_processor=None,
  461. post_processor=None,
  462. device=None,
  463. device_map=None,
  464. model_kwargs=None,
  465. token=None,
  466. **hub_kwargs,
  467. ):
  468. if not is_torch_available():
  469. raise ImportError("Please install torch in order to use this tool.")
  470. if not is_accelerate_available():
  471. raise ImportError("Please install accelerate in order to use this tool.")
  472. if model is None:
  473. if self.default_checkpoint is None:
  474. raise ValueError("This tool does not implement a default checkpoint, you need to pass one.")
  475. model = self.default_checkpoint
  476. if pre_processor is None:
  477. pre_processor = model
  478. self.model = model
  479. self.pre_processor = pre_processor
  480. self.post_processor = post_processor
  481. self.device = device
  482. self.device_map = device_map
  483. self.model_kwargs = {} if model_kwargs is None else model_kwargs
  484. if device_map is not None:
  485. self.model_kwargs["device_map"] = device_map
  486. self.hub_kwargs = hub_kwargs
  487. self.hub_kwargs["token"] = token
  488. super().__init__()
  489. def setup(self):
  490. """
  491. Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
  492. """
  493. if isinstance(self.pre_processor, str):
  494. self.pre_processor = self.pre_processor_class.from_pretrained(self.pre_processor, **self.hub_kwargs)
  495. if isinstance(self.model, str):
  496. self.model = self.model_class.from_pretrained(self.model, **self.model_kwargs, **self.hub_kwargs)
  497. if self.post_processor is None:
  498. self.post_processor = self.pre_processor
  499. elif isinstance(self.post_processor, str):
  500. self.post_processor = self.post_processor_class.from_pretrained(self.post_processor, **self.hub_kwargs)
  501. if self.device is None:
  502. if self.device_map is not None:
  503. self.device = list(self.model.hf_device_map.values())[0]
  504. else:
  505. self.device = PartialState().default_device
  506. if self.device_map is None:
  507. self.model.to(self.device)
  508. super().setup()
  509. def encode(self, raw_inputs):
  510. """
  511. Uses the `pre_processor` to prepare the inputs for the `model`.
  512. """
  513. return self.pre_processor(raw_inputs)
  514. def forward(self, inputs):
  515. """
  516. Sends the inputs through the `model`.
  517. """
  518. with torch.no_grad():
  519. return self.model(**inputs)
  520. def decode(self, outputs):
  521. """
  522. Uses the `post_processor` to decode the model output.
  523. """
  524. return self.post_processor(outputs)
  525. def __call__(self, *args, **kwargs):
  526. args, kwargs = handle_agent_inputs(*args, **kwargs)
  527. if not self.is_initialized:
  528. self.setup()
  529. encoded_inputs = self.encode(*args, **kwargs)
  530. tensor_inputs = {k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)}
  531. non_tensor_inputs = {k: v for k, v in encoded_inputs.items() if not isinstance(v, torch.Tensor)}
  532. encoded_inputs = send_to_device(tensor_inputs, self.device)
  533. outputs = self.forward({**encoded_inputs, **non_tensor_inputs})
  534. outputs = send_to_device(outputs, "cpu")
  535. decoded_outputs = self.decode(outputs)
  536. return handle_agent_outputs(decoded_outputs, self.output_type)
  537. def launch_gradio_demo(tool_class: Tool):
  538. """
  539. Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
  540. `inputs` and `output_type`.
  541. Args:
  542. tool_class (`type`): The class of the tool for which to launch the demo.
  543. """
  544. try:
  545. import gradio as gr
  546. except ImportError:
  547. raise ImportError("Gradio should be installed in order to launch a gradio demo.")
  548. tool = tool_class()
  549. def fn(*args, **kwargs):
  550. return tool(*args, **kwargs)
  551. gradio_inputs = []
  552. for input_name, input_details in tool_class.inputs.items():
  553. input_type = input_details["type"]
  554. if input_type == "image":
  555. gradio_inputs.append(gr.Image(label=input_name))
  556. elif input_type == "audio":
  557. gradio_inputs.append(gr.Audio(label=input_name))
  558. elif input_type in ["string", "integer", "number"]:
  559. gradio_inputs.append(gr.Textbox(label=input_name))
  560. else:
  561. error_message = f"Input type '{input_type}' not supported."
  562. raise ValueError(error_message)
  563. gradio_output = tool_class.output_type
  564. assert gradio_output in ["string", "image", "audio"], f"Output type '{gradio_output}' not supported."
  565. gr.Interface(
  566. fn=fn,
  567. inputs=gradio_inputs,
  568. outputs=gradio_output,
  569. title=tool_class.__name__,
  570. article=tool.description,
  571. ).launch()
  572. TOOL_MAPPING = {
  573. "document_question_answering": "DocumentQuestionAnsweringTool",
  574. "image_question_answering": "ImageQuestionAnsweringTool",
  575. "speech_to_text": "SpeechToTextTool",
  576. "text_to_speech": "TextToSpeechTool",
  577. "translation": "TranslationTool",
  578. "python_interpreter": "PythonInterpreterTool",
  579. "web_search": "DuckDuckGoSearchTool",
  580. }
  581. def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
  582. """
  583. Main function to quickly load a tool, be it on the Hub or in the Transformers library.
  584. <Tip warning={true}>
  585. Loading a tool means that you'll download the tool and execute it locally.
  586. ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
  587. installing a package using pip/npm/apt.
  588. </Tip>
  589. Args:
  590. task_or_repo_id (`str`):
  591. The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
  592. are:
  593. - `"document_question_answering"`
  594. - `"image_question_answering"`
  595. - `"speech_to_text"`
  596. - `"text_to_speech"`
  597. - `"translation"`
  598. model_repo_id (`str`, *optional*):
  599. Use this argument to use a different model than the default one for the tool you selected.
  600. token (`str`, *optional*):
  601. The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
  602. login` (stored in `~/.huggingface`).
  603. kwargs (additional keyword arguments, *optional*):
  604. Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
  605. `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
  606. will be passed along to its init.
  607. """
  608. if task_or_repo_id in TOOL_MAPPING:
  609. tool_class_name = TOOL_MAPPING[task_or_repo_id]
  610. main_module = importlib.import_module("transformers")
  611. tools_module = main_module.agents
  612. tool_class = getattr(tools_module, tool_class_name)
  613. return tool_class(model_repo_id, token=token, **kwargs)
  614. else:
  615. logger.warning_once(
  616. f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
  617. f"trust as the code within that tool will be executed on your machine. Always verify the code of "
  618. f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
  619. f"code that you have checked."
  620. )
  621. return Tool.from_hub(task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs)
  622. def add_description(description):
  623. """
  624. A decorator that adds a description to a function.
  625. """
  626. def inner(func):
  627. func.description = description
  628. func.name = func.__name__
  629. return func
  630. return inner
  631. ## Will move to the Hub
  632. class EndpointClient:
  633. def __init__(self, endpoint_url: str, token: Optional[str] = None):
  634. self.headers = {
  635. **build_hf_headers(token=token),
  636. "Content-Type": "application/json",
  637. }
  638. self.endpoint_url = endpoint_url
  639. @staticmethod
  640. def encode_image(image):
  641. _bytes = io.BytesIO()
  642. image.save(_bytes, format="PNG")
  643. b64 = base64.b64encode(_bytes.getvalue())
  644. return b64.decode("utf-8")
  645. @staticmethod
  646. def decode_image(raw_image):
  647. if not is_vision_available():
  648. raise ImportError(
  649. "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
  650. )
  651. from PIL import Image
  652. b64 = base64.b64decode(raw_image)
  653. _bytes = io.BytesIO(b64)
  654. return Image.open(_bytes)
  655. def __call__(
  656. self,
  657. inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
  658. params: Optional[Dict] = None,
  659. data: Optional[bytes] = None,
  660. output_image: bool = False,
  661. ) -> Any:
  662. # Build payload
  663. payload = {}
  664. if inputs:
  665. payload["inputs"] = inputs
  666. if params:
  667. payload["parameters"] = params
  668. # Make API call
  669. response = get_session().post(self.endpoint_url, headers=self.headers, json=payload, data=data)
  670. # By default, parse the response for the user.
  671. if output_image:
  672. return self.decode_image(response.content)
  673. else:
  674. return response.json()
  675. def parse_langchain_args(args: Dict[str, str]) -> Dict[str, str]:
  676. """Parse the args attribute of a LangChain tool to create a matching inputs dictionary."""
  677. inputs = args.copy()
  678. for arg_details in inputs.values():
  679. if "title" in arg_details:
  680. arg_details.pop("title")
  681. return inputs
  682. class ToolCollection:
  683. """
  684. Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
  685. > [!NOTE]
  686. > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
  687. > like for this collection to showcase them.
  688. Args:
  689. collection_slug (str):
  690. The collection slug referencing the collection.
  691. token (str, *optional*):
  692. The authentication token if the collection is private.
  693. Example:
  694. ```py
  695. >>> from transformers import ToolCollection, ReactCodeAgent
  696. >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
  697. >>> agent = ReactCodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
  698. >>> agent.run("Please draw me a picture of rivers and lakes.")
  699. ```
  700. """
  701. def __init__(self, collection_slug: str, token: Optional[str] = None):
  702. self._collection = get_collection(collection_slug, token=token)
  703. self._hub_repo_ids = {item.item_id for item in self._collection.items if item.item_type == "space"}
  704. self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
  705. def tool(tool_function: Callable) -> Tool:
  706. """
  707. Converts a function into an instance of a Tool subclass.
  708. Args:
  709. tool_function: Your function. Should have type hints for each input and a type hint for the output.
  710. Should also have a docstring description including an 'Args:' part where each argument is described.
  711. """
  712. parameters = get_json_schema(tool_function)["function"]
  713. if "return" not in parameters:
  714. raise TypeHintParsingException("Tool return type not found: make sure your function has a return type hint!")
  715. class_name = f"{parameters['name'].capitalize()}Tool"
  716. class SpecificTool(Tool):
  717. name = parameters["name"]
  718. description = parameters["description"]
  719. inputs = parameters["parameters"]["properties"]
  720. output_type = parameters["return"]["type"]
  721. @wraps(tool_function)
  722. def forward(self, *args, **kwargs):
  723. return tool_function(*args, **kwargs)
  724. original_signature = inspect.signature(tool_function)
  725. new_parameters = [inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)] + list(
  726. original_signature.parameters.values()
  727. )
  728. new_signature = original_signature.replace(parameters=new_parameters)
  729. SpecificTool.forward.__signature__ = new_signature
  730. SpecificTool.__name__ = class_name
  731. return SpecificTool()