From aba1bc40e0e728f9461684c3f4ce5481ed5ac6b6 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 11 Nov 2024 17:18:11 +0100 Subject: [PATCH 1/3] Allow handling files as args for a tool created with Tool.from_space --- src/transformers/agents/tools.py | 68 +++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 15 deletions(-) diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 994e1bdd817b..927804f57e1e 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -22,6 +22,7 @@ import os import tempfile from functools import lru_cache, wraps +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder @@ -414,7 +415,7 @@ def push_to_hub( ) @staticmethod - def from_space(space_id, name, description): + def from_space(space_id: str, name: str, description: str, api_name: Optional[str] = None): """ Creates a [`Tool`] from a Space given its id on the Hub. @@ -425,34 +426,63 @@ def from_space(space_id, name, description): The name of the tool. description (`str`): The description of the tool. + api_name (`str`, *optional*): + The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api. Returns: [`Tool`]: - The created tool. + The Space, as a tool. - Example: + Examples: ``` - tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt") + image_generator = Tool.from_space( + space_id="black-forest-labs/FLUX.1-schnell", + name="image-generator", + description="Generate an image from a prompt" + ) + image = image_generator("Generate an image of a cool surfer in Tahiti") + ``` + ``` + face_swapper = Tool.from_space( + "tuan2308/face-swap", + "face_swapper", + "Tool that puts the face shown on the first image on the second image. You can give it paths to images.", + ) + image = face_swapper('./aymeric.jpeg', './ruth.jpg') ``` """ - from gradio_client import Client + from gradio_client import Client, handle_file + from gradio_client.utils import is_http_url_like class SpaceToolWrapper(Tool): - def __init__(self, space_id, name, description): + def __init__(self, space_id: str, name: str, description: str, api_name: Optional[str] = None): self.client = Client(space_id) self.name = name self.description = description - space_description = self.client.view_api(return_format="dict")["named_endpoints"] - route = list(space_description.keys())[0] - space_description_route = space_description[route] + space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] + + # If api_name is not defined, take the first of the available APIs for this space + if api_name is None: + api_name = list(space_description.keys())[0] + logger.warning(f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`.") + self.api_name = api_name + + try: + space_description_api = space_description[api_name] + except KeyError: + raise KeyError(f"Could not find specified {api_name=} among available api names.") + self.inputs = {} - for parameter in space_description_route["parameters"]: + for parameter in space_description_api["parameters"]: if not parameter["parameter_has_default"]: + parameter_type = parameter["type"]["type"] + if parameter_type == "object": + parameter_type = "any" self.inputs[parameter["parameter_name"]] = { - "type": parameter["type"]["type"], + "type": parameter_type, "description": parameter["python_type"]["description"], } - output_component = space_description_route["returns"][0]["component"] + output_component = space_description_api["returns"][0]["component"] if output_component == "Image": self.output_type = "image" elif output_component == "Audio": @@ -461,9 +491,17 @@ def __init__(self, space_id, name, description): self.output_type = "any" def forward(self, *args, **kwargs): - return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result - - return SpaceToolWrapper(space_id, name, description) + # Test if any arg is a file and processes it accordingly: + args = list(args) + for i, arg in enumerate(args): + if ( + isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file() + ) or is_http_url_like(arg): + args[i] = handle_file(arg) + output = self.client.predict(*args, api_name=self.api_name, **kwargs) + return output[0] # Usually the first output is the result + + return SpaceToolWrapper(space_id, name, description, api_name=api_name) @staticmethod def from_gradio(gradio_tool): From d9a09ea2419efdd93d5a11a1600db810e78a7748 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 11 Nov 2024 21:40:26 +0100 Subject: [PATCH 2/3] Format --- src/transformers/agents/tools.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 927804f57e1e..ebf556eef1f8 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -464,7 +464,9 @@ def __init__(self, space_id: str, name: str, description: str, api_name: Optiona # If api_name is not defined, take the first of the available APIs for this space if api_name is None: api_name = list(space_description.keys())[0] - logger.warning(f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`.") + logger.warning( + f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`." + ) self.api_name = api_name try: From f88d317881196a77a49ea3abd577001ee28e54a0 Mon Sep 17 00:00:00 2001 From: Aymeric Date: Mon, 18 Nov 2024 18:19:01 +0100 Subject: [PATCH 3/3] Add arg sanitizing --- src/transformers/agents/tools.py | 49 ++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 12 deletions(-) diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 98ca67b5ba20..6d3401bf30e9 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -46,7 +46,7 @@ is_vision_available, logging, ) -from .agent_types import handle_agent_inputs, handle_agent_outputs +from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs logger = logging.get_logger(__name__) @@ -419,7 +419,9 @@ def push_to_hub( ) @staticmethod - def from_space(space_id: str, name: str, description: str, api_name: Optional[str] = None): + def from_space( + space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None + ): """ Creates a [`Tool`] from a Space given its id on the Hub. @@ -432,7 +434,8 @@ def from_space(space_id: str, name: str, description: str, api_name: Optional[st The description of the tool. api_name (`str`, *optional*): The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api. - + token (`str`, *optional*): + Add your token to access private spaces or increase your GPU quotas. Returns: [`Tool`]: The Space, as a tool. @@ -459,8 +462,15 @@ def from_space(space_id: str, name: str, description: str, api_name: Optional[st from gradio_client.utils import is_http_url_like class SpaceToolWrapper(Tool): - def __init__(self, space_id: str, name: str, description: str, api_name: Optional[str] = None): - self.client = Client(space_id) + def __init__( + self, + space_id: str, + name: str, + description: str, + api_name: Optional[str] = None, + token: Optional[str] = None, + ): + self.client = Client(space_id, hf_token=token) self.name = name self.description = description space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] @@ -496,18 +506,33 @@ def __init__(self, space_id: str, name: str, description: str, api_name: Optiona else: self.output_type = "any" + def sanitize_argument_for_prediction(self, arg): + if isinstance(arg, ImageType): + temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + arg.save(temp_file.name) + arg = temp_file.name + if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like( + arg + ): + arg = handle_file(arg) + return arg + def forward(self, *args, **kwargs): - # Test if any arg is a file and processes it accordingly: + # Preprocess args and kwargs: args = list(args) for i, arg in enumerate(args): - if ( - isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file() - ) or is_http_url_like(arg): - args[i] = handle_file(arg) + args[i] = self.sanitize_argument_for_prediction(arg) + for arg_name, arg in kwargs.items(): + kwargs[arg_name] = self.sanitize_argument_for_prediction(arg) + output = self.client.predict(*args, api_name=self.api_name, **kwargs) - return output[0] # Usually the first output is the result + if isinstance(output, tuple) or isinstance(output, list): + return output[ + 0 + ] # Sometime the space also returns the generation seed, in which case the result is at index 0 + return output - return SpaceToolWrapper(space_id, name, description, api_name=api_name) + return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token) @staticmethod def from_gradio(gradio_tool):