diff --git a/src/transformers/agents/tools.py b/src/transformers/agents/tools.py index 84bcf0fde61f..6d3401bf30e9 100644 --- a/src/transformers/agents/tools.py +++ b/src/transformers/agents/tools.py @@ -23,6 +23,7 @@ import os import tempfile from functools import lru_cache, wraps +from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Union from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder @@ -45,7 +46,7 @@ is_vision_available, logging, ) -from .agent_types import handle_agent_inputs, handle_agent_outputs +from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs logger = logging.get_logger(__name__) @@ -418,7 +419,9 @@ def push_to_hub( ) @staticmethod - def from_space(space_id, name, description): + def from_space( + space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None + ): """ Creates a [`Tool`] from a Space given its id on the Hub. @@ -429,34 +432,73 @@ def from_space(space_id, name, description): The name of the tool. description (`str`): The description of the tool. - + api_name (`str`, *optional*): + The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api. + token (`str`, *optional*): + Add your token to access private spaces or increase your GPU quotas. Returns: [`Tool`]: - The created tool. + The Space, as a tool. - Example: + Examples: ``` - tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt") + image_generator = Tool.from_space( + space_id="black-forest-labs/FLUX.1-schnell", + name="image-generator", + description="Generate an image from a prompt" + ) + image = image_generator("Generate an image of a cool surfer in Tahiti") + ``` + ``` + face_swapper = Tool.from_space( + "tuan2308/face-swap", + "face_swapper", + "Tool that puts the face shown on the first image on the second image. You can give it paths to images.", + ) + image = face_swapper('./aymeric.jpeg', './ruth.jpg') ``` """ - from gradio_client import Client + from gradio_client import Client, handle_file + from gradio_client.utils import is_http_url_like class SpaceToolWrapper(Tool): - def __init__(self, space_id, name, description): - self.client = Client(space_id) + def __init__( + self, + space_id: str, + name: str, + description: str, + api_name: Optional[str] = None, + token: Optional[str] = None, + ): + self.client = Client(space_id, hf_token=token) self.name = name self.description = description - space_description = self.client.view_api(return_format="dict")["named_endpoints"] - route = list(space_description.keys())[0] - space_description_route = space_description[route] + space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"] + + # If api_name is not defined, take the first of the available APIs for this space + if api_name is None: + api_name = list(space_description.keys())[0] + logger.warning( + f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`." + ) + self.api_name = api_name + + try: + space_description_api = space_description[api_name] + except KeyError: + raise KeyError(f"Could not find specified {api_name=} among available api names.") + self.inputs = {} - for parameter in space_description_route["parameters"]: + for parameter in space_description_api["parameters"]: if not parameter["parameter_has_default"]: + parameter_type = parameter["type"]["type"] + if parameter_type == "object": + parameter_type = "any" self.inputs[parameter["parameter_name"]] = { - "type": parameter["type"]["type"], + "type": parameter_type, "description": parameter["python_type"]["description"], } - output_component = space_description_route["returns"][0]["component"] + output_component = space_description_api["returns"][0]["component"] if output_component == "Image": self.output_type = "image" elif output_component == "Audio": @@ -464,10 +506,33 @@ def __init__(self, space_id, name, description): else: self.output_type = "any" - def forward(self, *args, **kwargs): - return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result + def sanitize_argument_for_prediction(self, arg): + if isinstance(arg, ImageType): + temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + arg.save(temp_file.name) + arg = temp_file.name + if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like( + arg + ): + arg = handle_file(arg) + return arg - return SpaceToolWrapper(space_id, name, description) + def forward(self, *args, **kwargs): + # Preprocess args and kwargs: + args = list(args) + for i, arg in enumerate(args): + args[i] = self.sanitize_argument_for_prediction(arg) + for arg_name, arg in kwargs.items(): + kwargs[arg_name] = self.sanitize_argument_for_prediction(arg) + + output = self.client.predict(*args, api_name=self.api_name, **kwargs) + if isinstance(output, tuple) or isinstance(output, list): + return output[ + 0 + ] # Sometime the space also returns the generation seed, in which case the result is at index 0 + return output + + return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token) @staticmethod def from_gradio(gradio_tool):