Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,15 +181,30 @@ def _parse_type_hint(hint: str) -> dict:
def _convert_type_hints_to_json_schema(func: Callable) -> dict:
type_hints = get_type_hints(func)
signature = inspect.signature(func)
# For methods, we need to ignore the first "self" or "cls" parameter. However, since unbound methods are just
# functions, we need to check the signature to see if it looks like a method rather than only relying on
# inspect.ismethod(), which returns False for unbound methods.
qualname = getattr(func, "__qualname__", "")
qualname_parts = qualname.split(".")
has_unbound_method_signature = isfunction(func) and len(qualname_parts) >= 2 and qualname_parts[-2] != "<locals>"
first_param_name = next(iter(signature.parameters), None)
implicit_arg_name = None
if first_param_name in {"self", "cls"} and (inspect.ismethod(func) or has_unbound_method_signature):
implicit_arg_name = first_param_name
Comment thread
qgallouedec marked this conversation as resolved.
Outdated

required = []
for param_name, param in signature.parameters.items():
if param_name == implicit_arg_name:
continue
if param.annotation == inspect.Parameter.empty:
raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
if param.default == inspect.Parameter.empty:
required.append(param_name)

properties = {}
for param_name, param_type in type_hints.items():
if param_name == implicit_arg_name:
continue
properties[param_name] = _parse_type_hint(param_type)

schema = {"type": "object", "properties": properties}
Expand Down Expand Up @@ -485,7 +500,7 @@ def render_jinja_template(
for tool in tools:
if isinstance(tool, dict):
tool_schemas.append(tool)
elif isfunction(tool):
elif isfunction(tool) or inspect.ismethod(tool):
tool_schemas.append(get_json_schema(tool))
else:
raise ValueError(
Expand Down
71 changes: 71 additions & 0 deletions tests/utils/test_chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,77 @@ def fn(x: int) -> None:
}
self.assertEqual(schema["function"], expected_schema)

def test_instance_method(self):
class Tool:
def fn(self, x: int):
"""
Test function

Args:
x: The input
"""
return x

expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema) # unbound case
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema) # bound case

def test_static_method(self):
class Tool:
@staticmethod
def fn(x: int):
"""
Test function

Args:
x: The input
"""
return x

expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema)
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema)

def test_class_method(self):
class Tool:
@classmethod
def fn(cls, x: int):
"""
Test function

Args:
x: The input
"""
return x

expected_schema = {
"name": "fn",
"description": "Test function",
"parameters": {
"type": "object",
"properties": {"x": {"type": "integer", "description": "The input"}},
"required": ["x"],
},
}
self.assertEqual(get_json_schema(Tool.fn)["function"], expected_schema)
self.assertEqual(get_json_schema(Tool().fn)["function"], expected_schema)

def test_everything_all_at_once(self):
def fn(x: str, y: list[str | int] | None, z: tuple[str | int, str] = (42, "hello")) -> tuple[int, str]:
"""
Expand Down