Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 83 additions & 18 deletions src/transformers/agents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import os
import tempfile
from functools import lru_cache, wraps
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

from huggingface_hub import create_repo, get_collection, hf_hub_download, metadata_update, upload_folder
Expand All @@ -45,7 +46,7 @@
is_vision_available,
logging,
)
from .agent_types import handle_agent_inputs, handle_agent_outputs
from .agent_types import ImageType, handle_agent_inputs, handle_agent_outputs


logger = logging.get_logger(__name__)
Expand Down Expand Up @@ -418,7 +419,9 @@ def push_to_hub(
)

@staticmethod
def from_space(space_id, name, description):
def from_space(
space_id: str, name: str, description: str, api_name: Optional[str] = None, token: Optional[str] = None
):
"""
Creates a [`Tool`] from a Space given its id on the Hub.

Expand All @@ -429,45 +432,107 @@ def from_space(space_id, name, description):
The name of the tool.
description (`str`):
The description of the tool.

api_name (`str`, *optional*):
The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api.
token (`str`, *optional*):
Add your token to access private spaces or increase your GPU quotas.
Returns:
[`Tool`]:
The created tool.
The Space, as a tool.

Example:
Examples:
```
tool = Tool.from_space("black-forest-labs/FLUX.1-schnell", "image-generator", "Generate an image from a prompt")
image_generator = Tool.from_space(
space_id="black-forest-labs/FLUX.1-schnell",
name="image-generator",
description="Generate an image from a prompt"
)
image = image_generator("Generate an image of a cool surfer in Tahiti")
```
```
face_swapper = Tool.from_space(
"tuan2308/face-swap",
"face_swapper",
"Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
)
image = face_swapper('./aymeric.jpeg', './ruth.jpg')
```
"""
from gradio_client import Client
from gradio_client import Client, handle_file
from gradio_client.utils import is_http_url_like

class SpaceToolWrapper(Tool):
def __init__(self, space_id, name, description):
self.client = Client(space_id)
def __init__(
self,
space_id: str,
name: str,
description: str,
api_name: Optional[str] = None,
token: Optional[str] = None,
):
self.client = Client(space_id, hf_token=token)
self.name = name
self.description = description
space_description = self.client.view_api(return_format="dict")["named_endpoints"]
route = list(space_description.keys())[0]
space_description_route = space_description[route]
space_description = self.client.view_api(return_format="dict", print_info=False)["named_endpoints"]

# If api_name is not defined, take the first of the available APIs for this space
if api_name is None:
api_name = list(space_description.keys())[0]
logger.warning(
f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`."
)
self.api_name = api_name

try:
space_description_api = space_description[api_name]
except KeyError:
raise KeyError(f"Could not find specified {api_name=} among available api names.")

self.inputs = {}
for parameter in space_description_route["parameters"]:
for parameter in space_description_api["parameters"]:
if not parameter["parameter_has_default"]:
parameter_type = parameter["type"]["type"]
if parameter_type == "object":
parameter_type = "any"
self.inputs[parameter["parameter_name"]] = {
"type": parameter["type"]["type"],
"type": parameter_type,
"description": parameter["python_type"]["description"],
}
output_component = space_description_route["returns"][0]["component"]
output_component = space_description_api["returns"][0]["component"]
if output_component == "Image":
self.output_type = "image"
elif output_component == "Audio":
self.output_type = "audio"
else:
self.output_type = "any"

def forward(self, *args, **kwargs):
return self.client.predict(*args, **kwargs)[0] # Usually the first output is the result
def sanitize_argument_for_prediction(self, arg):
if isinstance(arg, ImageType):
temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
arg.save(temp_file.name)
arg = temp_file.name
if (isinstance(arg, (str, Path)) and Path(arg).exists() and Path(arg).is_file()) or is_http_url_like(
arg
):
arg = handle_file(arg)
return arg

return SpaceToolWrapper(space_id, name, description)
def forward(self, *args, **kwargs):
# Preprocess args and kwargs:
args = list(args)
for i, arg in enumerate(args):
args[i] = self.sanitize_argument_for_prediction(arg)
for arg_name, arg in kwargs.items():
kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)

output = self.client.predict(*args, api_name=self.api_name, **kwargs)
if isinstance(output, tuple) or isinstance(output, list):
return output[
0
] # Sometime the space also returns the generation seed, in which case the result is at index 0
return output

return SpaceToolWrapper(space_id, name, description, api_name=api_name, token=token)

@staticmethod
def from_gradio(gradio_tool):
Expand Down