Skip to content

Commit

Permalink
core[patch]: Support injected tool args that are arbitrary types (#27045
Browse files Browse the repository at this point in the history
)

This adds support for inject tool args that are arbitrary types when
used with pydantic 2.

We'll need to add similar logic on the v1 path, and potentially mirror
the config from the original model when we're doing the subset.
  • Loading branch information
eyurtsev authored Oct 2, 2024
1 parent e806e9d commit 74bf620
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
7 changes: 5 additions & 2 deletions libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def _create_subset_model_v2(
fn_description: Optional[str] = None,
) -> type[pydantic.BaseModel]:
"""Create a pydantic model with a subset of the model fields."""
from pydantic import create_model
from pydantic import ConfigDict, create_model
from pydantic.fields import FieldInfo

descriptions_ = descriptions or {}
Expand All @@ -278,7 +278,10 @@ def _create_subset_model_v2(
if field.metadata:
field_info.metadata = field.metadata
fields[field_name] = (field.annotation, field_info)
rtn = create_model(name, **fields) # type: ignore

rtn = create_model( # type: ignore
name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True)
)

# TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations.
# This is done to preserve __annotations__ when working with pydantic 2.x
Expand Down
15 changes: 15 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2090,3 +2090,18 @@ class FooSchema(BaseModel):

with pytest.raises(NotImplementedError):
assert tool.invoke("hello") == "hello"


def test_injected_arg_with_complex_type() -> None:
"""Test that an injected tool arg can be a complex type."""

class Foo:
def __init__(self) -> None:
self.value = "bar"

@tool
def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str:
"""Tool that has an injected tool arg."""
return foo.value

assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore

0 comments on commit 74bf620

Please sign in to comment.