Skip to content
36 changes: 34 additions & 2 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:

Currently handles:
- Converting JSON dictionaries to Pydantic model instances where expected
- Converting lists of JSON dictionaries to lists of Pydantic model instances

Future extensions could include:
- Type coercion for other complex types
Expand Down Expand Up @@ -129,8 +130,39 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
if len(non_none_types) == 1:
target_type = non_none_types[0]

# Check if the target type is a list
if get_origin(target_type) is list:
list_args = get_args(target_type)
if list_args:
element_type = list_args[0]

# Check if the element type is a Pydantic model
if inspect.isclass(element_type) and issubclass(
element_type, pydantic.BaseModel
):
# Skip conversion if the value is None
if args[param_name] is None:
continue

# Convert list elements to Pydantic models
if isinstance(args[param_name], list):
converted_list = []
for item in args[param_name]:
try:
converted_list.append(element_type.model_validate(item))
except Exception as e:
logger.warning(
f"Failed to convert item in '{param_name}' to Pydantic"
f' model {element_type.__name__}: {e}'
)

# Keep the original value if conversion fails
converted_list.append(item)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When a list element fails Pydantic validation you now append the original dict back into converted_list. The downstream tool function still receives that list, so any comprehensions or attribute access (user.name) will now raise (dict has no attribute name).

I think we skipped the bad element before, so the tool could still run on the valid ones. If we want to preserve the original data for feedback, we either need to keep the entire argument unconverted so tool decides or continue skipping/raising


converted_args[param_name] = converted_list

# Check if the target type is a Pydantic model
if inspect.isclass(target_type) and issubclass(
elif inspect.isclass(target_type) and issubclass(
target_type, pydantic.BaseModel
):
# Skip conversion if the value is None and the parameter is Optional
Expand All @@ -146,7 +178,7 @@ def _preprocess_args(self, args: dict[str, Any]) -> dict[str, Any]:
except Exception as e:
logger.warning(
f"Failed to convert argument '{param_name}' to Pydantic model"
f' {target_type.__name__}: {e}'
f' model {target_type.__name__}: {e}'
)
# Keep the original value if conversion fails
pass
Expand Down
130 changes: 128 additions & 2 deletions tests/unittests/tools/test_function_tool_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,5 +280,131 @@ async def test_run_async_with_optional_pydantic_models():
assert result["theme"] == "dark"
assert result["notifications"] is True
assert result["preferences_type"] == "PreferencesModel"
assert result["preferences_type"] == "PreferencesModel"
assert result["preferences_type"] == "PreferencesModel"


def function_with_list_of_pydantic_models(users: list[UserModel]) -> dict:
"""Function that takes a list of Pydantic models."""
return {
"count": len(users),
"names": [user.name for user in users],
"ages": [user.age for user in users],
"types": [type(user).__name__ for user in users],
}


def function_with_optional_list_of_pydantic_models(
users: Optional[list[UserModel]] = None,
) -> dict:
"""Function that takes an optional list of Pydantic models."""
if users is None:
return {"count": 0, "names": []}
return {
"count": len(users),
"names": [user.name for user in users],
}


def test_preprocess_args_with_list_of_dicts_to_pydantic_models():
"""Test _preprocess_args converts list of dicts to list of Pydantic models."""
tool = FunctionTool(function_with_list_of_pydantic_models)

input_args = {
"users": [
{"name": "Alice", "age": 30, "email": "[email protected]"},
{"name": "Bob", "age": 25},
{"name": "Charlie", "age": 35, "email": "[email protected]"},
]
}

processed_args = tool._preprocess_args(input_args)

# Check that the list of dicts was converted to a list of Pydantic models
assert "users" in processed_args
users = processed_args["users"]
assert isinstance(users, list)
assert len(users) == 3

# Check each element is a Pydantic model with correct data
assert isinstance(users[0], UserModel)
assert users[0].name == "Alice"
assert users[0].age == 30
assert users[0].email == "[email protected]"

assert isinstance(users[1], UserModel)
assert users[1].name == "Bob"
assert users[1].age == 25
assert users[1].email is None

assert isinstance(users[2], UserModel)
assert users[2].name == "Charlie"
assert users[2].age == 35
assert users[2].email == "[email protected]"


def test_preprocess_args_with_optional_list_of_pydantic_models_none():
"""Test _preprocess_args handles None for optional list parameter."""
tool = FunctionTool(function_with_optional_list_of_pydantic_models)

input_args = {"users": None}

processed_args = tool._preprocess_args(input_args)

# Check that None is preserved
assert "users" in processed_args
assert processed_args["users"] is None


def test_preprocess_args_with_optional_list_of_pydantic_models_with_data():
"""Test _preprocess_args converts list for optional list parameter."""
tool = FunctionTool(function_with_optional_list_of_pydantic_models)

input_args = {
"users": [
{"name": "Alice", "age": 30},
{"name": "Bob", "age": 25},
]
}

processed_args = tool._preprocess_args(input_args)

# Check conversion
assert "users" in processed_args
users = processed_args["users"]
assert len(users) == 2
assert all(isinstance(user, UserModel) for user in users)
assert users[0].name == "Alice"
assert users[1].name == "Bob"


def test_preprocess_args_with_list_keeps_invalid_items_as_original():
"""Test _preprocess_args keeps original data for items that fail validation."""
tool = FunctionTool(function_with_list_of_pydantic_models)

input_args = {
"users": [
{"name": "Alice", "age": 30},
{"name": "Invalid"}, # Missing required 'age' field
{"name": "Bob", "age": 25},
]
}

processed_args = tool._preprocess_args(input_args)

# Check that all items are preserved
assert "users" in processed_args
users = processed_args["users"]
assert len(users) == 3 # All items preserved

# First item should be converted to UserModel
assert isinstance(users[0], UserModel)
assert users[0].name == "Alice"
assert users[0].age == 30

# Second item should remain as dict (failed validation)
assert isinstance(users[1], dict)
assert users[1] == {"name": "Invalid"}

# Third item should be converted to UserModel
assert isinstance(users[2], UserModel)
assert users[2].name == "Bob"
assert users[2].age == 25