From b3f8bd2bc9dc28cf8770ce4b21d688b735223e4c Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 20 May 2025 19:44:59 +0000 Subject: [PATCH 01/70] Output handling refactoring borrowed from output_mode PR --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 48 ++-- pydantic_ai_slim/pydantic_ai/_output.py | 235 ++++++++++++++----- pydantic_ai_slim/pydantic_ai/agent.py | 30 +-- pydantic_ai_slim/pydantic_ai/result.py | 129 +++------- pydantic_ai_slim/pydantic_ai/tools.py | 5 +- 5 files changed, 243 insertions(+), 204 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index ccc1d18f77..9a83b0911e 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -24,7 +24,7 @@ result, usage as _usage, ) -from .result import OutputDataT, ToolOutput +from .result import OutputDataT from .settings import ModelSettings, merge_model_settings from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc @@ -251,7 +251,7 @@ async def add_mcp_server_tools(server: MCPServer) -> None: output_schema = ctx.deps.output_schema return models.ModelRequestParameters( function_tools=function_tool_defs, - allow_text_output=allow_text_output(output_schema), + allow_text_output=_output.allow_text_output(output_schema), output_tools=output_schema.tool_defs() if output_schema is not None else [], ) @@ -437,7 +437,7 @@ async def _run_stream() -> AsyncIterator[_messages.HandleResponseEvent]: # when the model has already returned text along side tool calls # in this scenario, if text responses are allowed, we return text from the most recent model # response, if any - if allow_text_output(ctx.deps.output_schema): + if _output.allow_text_output(ctx.deps.output_schema): for message in reversed(ctx.state.message_history): if isinstance(message, _messages.ModelResponse): last_texts = [p.content for p in message.parts if isinstance(p, _messages.TextPart)] @@ -520,27 +520,22 @@ async def _handle_text_response( output_schema = ctx.deps.output_schema text = '\n\n'.join(texts) - if allow_text_output(output_schema): - # The following cast is safe because we know `str` is an allowed result type - result_data_input = cast(NodeRunEndT, text) - try: - result_data = await _validate_output(result_data_input, ctx, None) - except _output.ToolRetryError as e: - ctx.state.increment_retries(ctx.deps.max_result_retries) - return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) + try: + if _output.allow_text_output(output_schema): + # The following cast is safe because we know `str` is an allowed result type + result_data = cast(NodeRunEndT, text) else: - return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) - else: - ctx.state.increment_retries(ctx.deps.max_result_retries) - return ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest( - parts=[ - _messages.RetryPromptPart( - content='Plain text responses are not permitted, please include your response in a tool call', - ) - ] + m = _messages.RetryPromptPart( + content='Plain text responses are not permitted, please include your response in a tool call', ) - ) + raise _output.ToolRetryError(m) + + result_data = await _validate_output(result_data, ctx, None) + except _output.ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries) + return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) + else: + return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> RunContext[DepsT]: @@ -782,11 +777,6 @@ async def _validate_output( return result_data -def allow_text_output(output_schema: _output.OutputSchema[Any] | None) -> bool: - """Check if the result schema allows text results.""" - return output_schema is None or output_schema.allow_text_output - - @dataclasses.dataclass class _RunMessages: messages: list[_messages.ModelMessage] @@ -836,7 +826,9 @@ def get_captured_run_messages() -> _RunMessages: def build_agent_graph( - name: str | None, deps_type: type[DepsT], output_type: type[OutputT] | ToolOutput[OutputT] + name: str | None, + deps_type: type[DepsT], + output_type: _output.OutputType[OutputT], ) -> Graph[GraphAgentState, GraphAgentDeps[DepsT, result.FinalResult[OutputT]], result.FinalResult[OutputT]]: """Build the execution [Graph][pydantic_graph.Graph] for a given agent.""" nodes = ( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index f2246ffb92..5e5565de08 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -12,11 +12,43 @@ from . import _utils, messages as _messages from .exceptions import ModelRetry -from .result import DEFAULT_OUTPUT_TOOL_NAME, OutputDataT, OutputDataT_inv, OutputValidatorFunc, ToolOutput -from .tools import AgentDepsT, GenerateToolJsonSchema, RunContext, ToolDefinition +from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition T = TypeVar('T') """An invariant TypeVar.""" +OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) +""" +An invariant type variable for the result data of a model. + +We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used +in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types +possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and +changing it would have negative consequences for the ergonomics of the library. + +At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would +resolve these potential variance issues. +""" +OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) +"""Covariant type variable for the result data type of a run.""" + +OutputValidatorFunc = Union[ + Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], + Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]], + Callable[[OutputDataT_inv], OutputDataT_inv], + Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]], +] +""" +A function that always takes and returns the same type of data (which is the result type of an agent run), and: + +* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument +* may or may not be async + +Usage `OutputValidatorFunc[AgentDepsT, T]`. +""" + + +DEFAULT_OUTPUT_TOOL_NAME = 'final_result' +DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' @dataclass @@ -76,69 +108,99 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): super().__init__() +@dataclass(init=False) +class ToolOutput(Generic[OutputDataT]): + """Marker class to use tools for outputs, and customize the tool.""" + + output_type: type[OutputDataT] + name: str | None + description: str | None + max_retries: int | None + strict: bool | None + + def __init__( + self, + *, + type_: type[OutputDataT], + name: str | None = None, + description: str | None = None, + max_retries: int | None = None, + strict: bool | None = None, + ): + self.output_type = type_ + self.name = name + self.description = description + self.max_retries = max_retries + self.strict = strict + + +# TODO: Use TypeAliasType +type OutputType[OutputDataT] = type[OutputDataT] | ToolOutput[OutputDataT] + + @dataclass class OutputSchema(Generic[OutputDataT]): - """Model the final response from an agent run. + """Model the final output from an agent run. Similar to `Tool` but for the final output of running an agent. """ - tools: dict[str, OutputSchemaTool[OutputDataT]] + tools: dict[str, OutputTool[OutputDataT]] allow_text_output: bool @classmethod def build( cls: type[OutputSchema[T]], - output_type: type[T] | ToolOutput[T], + output_type: OutputType[T], name: str | None = None, description: str | None = None, strict: bool | None = None, ) -> OutputSchema[T] | None: - """Build an OutputSchema dataclass from a response type.""" + """Build an OutputSchema dataclass from an output type.""" if output_type is str: return None + allow_text_output = False if isinstance(output_type, ToolOutput): # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads name = output_type.name description = output_type.description output_type_ = output_type.output_type strict = output_type.strict - else: - output_type_ = output_type - - if output_type_option := extract_str_from_union(output_type): - output_type_ = output_type_option.value + elif output_type_other_than_str := extract_str_from_union(output_type): + output_type_ = output_type_other_than_str.value allow_text_output = True else: - allow_text_output = False + output_type_ = output_type - tools: dict[str, OutputSchemaTool[T]] = {} + tools: dict[str, OutputTool[T]] = {} if args := get_union_args(output_type_): for i, arg in enumerate(args, start=1): tool_name = raw_tool_name = union_tool_name(name, arg) while tool_name in tools: tool_name = f'{raw_tool_name}_{i}' + + parameters_schema = OutputObjectSchema(output_type=arg, description=description, strict=strict) tools[tool_name] = cast( - OutputSchemaTool[T], - OutputSchemaTool( - output_type=arg, name=tool_name, description=description, multiple=True, strict=strict - ), + OutputTool[T], + OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=True), ) else: - name = name or DEFAULT_OUTPUT_TOOL_NAME - tools[name] = cast( - OutputSchemaTool[T], - OutputSchemaTool( - output_type=output_type_, name=name, description=description, multiple=False, strict=strict - ), + tool_name = name or DEFAULT_OUTPUT_TOOL_NAME + parameters_schema = OutputObjectSchema(output_type=output_type_, description=description, strict=strict) + tools[tool_name] = cast( + OutputTool[T], + OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=False), ) - return cls(tools=tools, allow_text_output=allow_text_output) + return cls( + tools=tools, + allow_text_output=allow_text_output, + ) def find_named_tool( self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]] | None: + ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: """Find a tool that matches one of the calls, with a specific name.""" for part in parts: # pragma: no branch if isinstance(part, _messages.ToolCallPart): # pragma: no branch @@ -148,7 +210,7 @@ def find_named_tool( def find_tool( self, parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputSchemaTool[OutputDataT]]]: + ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: """Find a tool that matches one of the calls.""" for part in parts: if isinstance(part, _messages.ToolCallPart): # pragma: no branch @@ -164,55 +226,116 @@ def tool_defs(self) -> list[ToolDefinition]: return [t.tool_def for t in self.tools.values()] -DEFAULT_DESCRIPTION = 'The final response which ends this conversation' +def allow_text_output(output_schema: OutputSchema[Any] | None) -> bool: + return output_schema is None or output_schema.allow_text_output + + +@dataclass +class OutputObjectDefinition: + name: str + json_schema: ObjectJsonSchema + description: str | None = None + strict: bool | None = None @dataclass(init=False) -class OutputSchemaTool(Generic[OutputDataT]): - tool_def: ToolDefinition +class OutputObjectSchema(Generic[OutputDataT]): + definition: OutputObjectDefinition type_adapter: TypeAdapter[Any] + outer_typed_dict_key: str | None = None def __init__( - self, *, output_type: type[OutputDataT], name: str, description: str | None, multiple: bool, strict: bool | None + self, + *, + output_type: type[OutputDataT], + name: str | None = None, + description: str | None = None, + strict: bool | None = None, ): - """Build a OutputSchemaTool from a response type.""" if _utils.is_model_like(output_type): self.type_adapter = TypeAdapter(output_type) - outer_typed_dict_key: str | None = None - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) else: + self.outer_typed_dict_key = 'response' response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', {'response': output_type}, # pyright: ignore[reportInvalidTypeForm] ) self.type_adapter = TypeAdapter(response_data_typed_dict) - outer_typed_dict_key = 'response' - # noinspection PyArgumentList - parameters_json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) + + json_schema = _utils.check_object_json_schema( + self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) + ) + if self.outer_typed_dict_key: # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM - parameters_json_schema.pop('title') + json_schema.pop('title') - if json_schema_description := parameters_json_schema.pop('description', None): + if json_schema_description := json_schema.pop('description', None): if description is None: - tool_description = json_schema_description + description = json_schema_description else: - tool_description = f'{description}. {json_schema_description}' # pragma: no cover + description = f'{description}. {json_schema_description}' + + self.definition = OutputObjectDefinition( + name=name or getattr(output_type, '__name__', DEFAULT_OUTPUT_TOOL_NAME), + description=description, + json_schema=json_schema, + strict=strict, + ) + + def validate( + self, data: str | dict[str, Any], allow_partial: bool = False, wrap_validation_errors: bool = True + ) -> OutputDataT: + """Validate an output message. + + Args: + data: The output data to validate. + allow_partial: If true, allow partial validation. + wrap_validation_errors: If true, wrap the validation errors in a retry message. + + Returns: + Either the validated output data (left) or a retry message (right). + """ + try: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + output = self.type_adapter.validate_json(data, experimental_allow_partial=pyd_allow_partial) + else: + output = self.type_adapter.validate_python(data, experimental_allow_partial=pyd_allow_partial) + except ValidationError as e: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=e.errors(include_url=False), + ) + raise ToolRetryError(m) from e + else: + raise else: - tool_description = description or DEFAULT_DESCRIPTION + if k := self.outer_typed_dict_key: + output = output[k] + return output + + +@dataclass(init=False) +class OutputTool(Generic[OutputDataT]): + parameters_schema: OutputObjectSchema[OutputDataT] + tool_def: ToolDefinition + + def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDataT], multiple: bool): + self.parameters_schema = parameters_schema + definition = parameters_schema.definition + + description = definition.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION if multiple: - tool_description = f'{union_arg_name(output_type)}: {tool_description}' + description = f'{definition.name}: {description}' self.tool_def = ToolDefinition( name=name, - description=tool_description, - parameters_json_schema=parameters_json_schema, - outer_typed_dict_key=outer_typed_dict_key, - strict=strict, + description=description, + parameters_json_schema=definition.json_schema, + strict=definition.strict, + outer_typed_dict_key=parameters_schema.outer_typed_dict_key, ) def validate( @@ -229,11 +352,9 @@ def validate( Either the validated output data (left) or a retry message (right). """ try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(tool_call.args, str): - output = self.type_adapter.validate_json(tool_call.args, experimental_allow_partial=pyd_allow_partial) - else: - output = self.type_adapter.validate_python(tool_call.args, experimental_allow_partial=pyd_allow_partial) + output = self.parameters_schema.validate( + tool_call.args, allow_partial=allow_partial, wrap_validation_errors=False + ) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -245,8 +366,6 @@ def validate( else: raise # pragma: lax no cover else: - if k := self.tool_def.outer_typed_dict_key: - output = output[k] return output diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3242478865..636684bdff 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -29,7 +29,7 @@ usage as _usage, ) from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model -from .result import FinalResult, OutputDataT, StreamedRunResult, ToolOutput +from .result import FinalResult, OutputDataT, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( AgentDepsT, @@ -127,7 +127,7 @@ class Agent(Generic[AgentDepsT, OutputDataT]): be merged with this value, with the runtime argument taking priority. """ - output_type: type[OutputDataT] | ToolOutput[OutputDataT] + output_type: _output.OutputType[OutputDataT] """ The type of data output by agent runs, used to validate the data returned by the model, defaults to `str`. """ @@ -166,7 +166,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str, + output_type: _output.OutputType[OutputDataT] = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] @@ -203,7 +203,7 @@ def __init__( name: str | None = None, model_settings: ModelSettings | None = None, retries: int = 1, - result_tool_name: str = 'final_result', + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, result_tool_description: str | None = None, result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), @@ -218,7 +218,7 @@ def __init__( self, model: models.Model | models.KnownModelName | str | None = None, *, - # TODO change this back to `output_type: type[OutputDataT] | ToolOutput[OutputDataT] = str,` when we remove the overloads + # TODO change this back to `output_type: _output.OutputType[OutputDataT] = str,` when we remove the overloads output_type: Any = str, instructions: str | _system_prompt.SystemPromptFunc[AgentDepsT] @@ -378,7 +378,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -408,7 +408,7 @@ async def run( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -496,7 +496,7 @@ def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -528,7 +528,7 @@ async def iter( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -783,7 +783,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -813,7 +813,7 @@ def run_sync( self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -896,7 +896,7 @@ def run_stream( self, user_prompt: str | Sequence[_messages.UserContent], *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT], + output_type: _output.OutputType[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -927,7 +927,7 @@ async def run_stream( # noqa C901 self, user_prompt: str | Sequence[_messages.UserContent] | None = None, *, - output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None = None, + output_type: _output.OutputType[RunOutputDataT] | None = None, message_history: list[_messages.ModelMessage] | None = None, model: models.Model | models.KnownModelName | str | None = None, deps: AgentDepsT = None, @@ -1007,7 +1007,7 @@ async def stream_to_final( if isinstance(maybe_part_event, _messages.PartStartEvent): new_part = maybe_part_event.part if isinstance(new_part, _messages.TextPart): - if _agent_graph.allow_text_output(output_schema): + if _output.allow_text_output(output_schema): return FinalResult(s, None, None) elif isinstance(new_part, _messages.ToolCallPart) and output_schema: for call, _ in output_schema.find_tool([new_part]): @@ -1641,7 +1641,7 @@ def last_run_messages(self) -> list[_messages.ModelMessage]: raise AttributeError('The `last_run_messages` attribute has been removed, use `capture_run_messages` instead.') def _prepare_output_schema( - self, output_type: type[RunOutputDataT] | ToolOutput[RunOutputDataT] | None + self, output_type: _output.OutputType[RunOutputDataT] | None ) -> _output.OutputSchema[RunOutputDataT] | None: if output_type is not None: if self._output_validators: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 6d9d397392..9b742062c8 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,100 +5,35 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import TYPE_CHECKING, Generic, Union, cast +from typing import Generic, cast from typing_extensions import TypeVar, assert_type, deprecated, overload -from . import _utils, exceptions, messages as _messages, models +from . import _output, _utils, exceptions, messages as _messages, models +from ._output import ( + OutputDataT, + OutputDataT_inv, + OutputSchema, + OutputValidator, + OutputValidatorFunc, + ToolOutput, +) from .messages import AgentStreamEvent, FinalResultEvent from .tools import AgentDepsT, RunContext from .usage import Usage, UsageLimits -if TYPE_CHECKING: - from . import _output - __all__ = 'OutputDataT', 'OutputDataT_inv', 'ToolOutput', 'OutputValidatorFunc' T = TypeVar('T') """An invariant TypeVar.""" -OutputDataT_inv = TypeVar('OutputDataT_inv', default=str) -""" -An invariant type variable for the result data of a model. - -We need to use an invariant typevar for `OutputValidator` and `OutputValidatorFunc` because the output data type is used -in both the input and output of a `OutputValidatorFunc`. This can theoretically lead to some issues assuming that types -possessing OutputValidator's are covariant in the result data type, but in practice this is rarely an issue, and -changing it would have negative consequences for the ergonomics of the library. - -At some point, it may make sense to change the input to OutputValidatorFunc to be `Any` or `object` as doing that would -resolve these potential variance issues. -""" -OutputDataT = TypeVar('OutputDataT', default=str, covariant=True) -"""Covariant type variable for the result data type of a run.""" - -OutputValidatorFunc = Union[ - Callable[[RunContext[AgentDepsT], OutputDataT_inv], OutputDataT_inv], - Callable[[RunContext[AgentDepsT], OutputDataT_inv], Awaitable[OutputDataT_inv]], - Callable[[OutputDataT_inv], OutputDataT_inv], - Callable[[OutputDataT_inv], Awaitable[OutputDataT_inv]], -] -""" -A function that always takes and returns the same type of data (which is the result type of an agent run), and: - -* may or may not take [`RunContext`][pydantic_ai.tools.RunContext] as a first argument -* may or may not be async - -Usage `OutputValidatorFunc[AgentDepsT, T]`. -""" - -DEFAULT_OUTPUT_TOOL_NAME = 'final_result' - - -@dataclass(init=False) -class ToolOutput(Generic[OutputDataT]): - """Marker class to use tools for structured outputs, and customize the tool.""" - - output_type: type[OutputDataT] - # TODO: Add `output_call` support, for calling a function to get the output - # output_call: Callable[..., OutputDataT] | None - name: str - description: str | None - max_retries: int | None - strict: bool | None - - def __init__( - self, - *, - type_: type[OutputDataT], - # call: Callable[..., OutputDataT] | None = None, - name: str = 'final_result', - description: str | None = None, - max_retries: int | None = None, - strict: bool | None = None, - ): - self.output_type = type_ - self.name = name - self.description = description - self.max_retries = max_retries - self.strict = strict - - # TODO: add support for call and make type_ optional, with the following logic: - # if type_ is None and call is None: - # raise ValueError('Either type_ or call must be provided') - # if call is not None: - # if type_ is None: - # type_ = get_type_hints(call).get('return') - # if type_ is None: - # raise ValueError('Unable to determine type_ from call signature; please provide it explicitly') - # self.output_call = call @dataclass class AgentStream(Generic[AgentDepsT, OutputDataT]): _raw_stream_response: models.StreamedResponse - _output_schema: _output.OutputSchema[OutputDataT] | None - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] + _output_schema: OutputSchema[OutputDataT] | None + _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None @@ -144,6 +79,7 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" + call = None if self._output_schema is not None and output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, output_tool_name) if match is None: @@ -153,20 +89,14 @@ async def _validate_response( call, output_tool = match result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - for validator in self._output_validators: - text = await validator.validate( - text, - None, - self._run_ctx, - ) - # Since there is no output tool, we can assume that str is compatible with OutputDataT - return cast(OutputDataT, text) + # The following cast is safe because we know `str` is an allowed output type + result_data = cast(OutputDataT, text) + + for validator in self._output_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) + return result_data def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -180,7 +110,6 @@ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: async def aiter(): output_schema = self._output_schema - allow_text_output = output_schema is None or output_schema.allow_text_output def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None: """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" @@ -192,7 +121,7 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. return _messages.FinalResultEvent( tool_name=call.tool_name, tool_call_id=call.tool_call_id ) - elif allow_text_output: # pragma: no branch + elif _output.allow_text_output(output_schema): # pragma: no branch assert_type(e, _messages.PartStartEvent) return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) @@ -224,9 +153,9 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _usage_limits: UsageLimits | None _stream_response: models.StreamedResponse - _output_schema: _output.OutputSchema[OutputDataT] | None + _output_schema: OutputSchema[OutputDataT] | None _run_ctx: RunContext[AgentDepsT] - _output_validators: list[_output.OutputValidator[AgentDepsT, OutputDataT]] + _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] @@ -458,6 +387,7 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" + call = None if self._output_schema is not None and self._output_tool_name is not None: match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) if match is None: @@ -467,16 +397,13 @@ async def validate_structured_output( call, output_tool = match result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) - - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) - for validator in self._output_validators: - text = await validator.validate(text, None, self._run_ctx) # pragma: no cover - # Since there is no output tool, we can assume that str is compatible with OutputDataT - return cast(OutputDataT, text) + result_data = cast(OutputDataT, text) + + for validator in self._output_validators: + result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover + return result_data async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index cee99dfd2b..584d77e02a 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -13,10 +13,11 @@ from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar -from . import _pydantic, _utils, messages as _messages, models +from . import _pydantic, _utils, messages as _messages from .exceptions import ModelRetry, UnexpectedModelBehavior if TYPE_CHECKING: + from .models import Model from .result import Usage __all__ = ( @@ -45,7 +46,7 @@ class RunContext(Generic[AgentDepsT]): deps: AgentDepsT """Dependencies for the agent.""" - model: models.Model + model: Model """The model used in this run.""" usage: Usage """LLM usage associated with the run.""" From 2ebe6a96c7392ec692dc221311b12ec3e6ef3fe6 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 20 May 2025 22:22:19 +0000 Subject: [PATCH 02/70] Support functions as output_type, as well as lists of functions and other types --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/_output.py | 138 ++++--- pydantic_ai_slim/pydantic_ai/result.py | 4 +- tests/test_agent.py | 407 ++++++++++++++++++- 4 files changed, 498 insertions(+), 53 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 9a83b0911e..a65b177e24 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -465,7 +465,7 @@ async def _handle_tool_calls( if output_schema is not None: for call, output_tool in output_schema.find_tool(tool_calls): try: - result_data = output_tool.validate(call) + result_data = await output_tool.process(call) result_data = await _validate_output(result_data, ctx, call) except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 5e5565de08..83a17459dc 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import inspect -from collections.abc import Awaitable, Iterable, Iterator +from collections.abc import Awaitable, Iterable, Iterator, Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union, cast @@ -112,7 +112,7 @@ def __init__(self, tool_retry: _messages.RetryPromptPart): class ToolOutput(Generic[OutputDataT]): """Marker class to use tools for outputs, and customize the tool.""" - output_type: type[OutputDataT] + output_type: SimpleOutputType[OutputDataT] name: str | None description: str | None max_retries: int | None @@ -121,7 +121,7 @@ class ToolOutput(Generic[OutputDataT]): def __init__( self, *, - type_: type[OutputDataT], + type_: SimpleOutputType[OutputDataT], name: str | None = None, description: str | None = None, max_retries: int | None = None, @@ -135,7 +135,10 @@ def __init__( # TODO: Use TypeAliasType -type OutputType[OutputDataT] = type[OutputDataT] | ToolOutput[OutputDataT] +type OutputCallable[OutputDataT] = Callable[..., OutputDataT | Awaitable[OutputDataT]] +type SimpleOutputType[OutputDataT] = type[OutputDataT] | OutputCallable[OutputDataT] +type SimpleOutputTypeOrMarker[OutputDataT] = SimpleOutputType[OutputDataT] | ToolOutput[OutputDataT] +type OutputType[OutputDataT] = SimpleOutputTypeOrMarker[OutputDataT] | Sequence[SimpleOutputTypeOrMarker[OutputDataT]] @dataclass @@ -150,48 +153,72 @@ class OutputSchema(Generic[OutputDataT]): @classmethod def build( - cls: type[OutputSchema[T]], - output_type: OutputType[T], + cls: type[OutputSchema[OutputDataT]], + output_type: OutputType[OutputDataT], name: str | None = None, description: str | None = None, strict: bool | None = None, - ) -> OutputSchema[T] | None: + ) -> OutputSchema[OutputDataT] | None: """Build an OutputSchema dataclass from an output type.""" if output_type is str: return None - allow_text_output = False - if isinstance(output_type, ToolOutput): - # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads - name = output_type.name - description = output_type.description - output_type_ = output_type.output_type - strict = output_type.strict - elif output_type_other_than_str := extract_str_from_union(output_type): - output_type_ = output_type_other_than_str.value - allow_text_output = True + multiple = False + + output_types_or_markers: Sequence[SimpleOutputTypeOrMarker[OutputDataT]] + if isinstance(output_type, Sequence): + output_types_or_markers = output_type + multiple = True else: - output_type_ = output_type + output_types_or_markers = [output_type] + + allow_text_output = False + tools: dict[str, OutputTool[OutputDataT]] = {} + for output_type_or_marker in output_types_or_markers: + tool_name = name + tool_description = description + tool_strict = strict + custom_tool_name = False + if isinstance(output_type_or_marker, ToolOutput): + output_type_ = output_type_or_marker.output_type + # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads + tool_name = output_type_or_marker.name + tool_description = output_type_or_marker.description + tool_strict = output_type_or_marker.strict + custom_tool_name = True + elif output_type_other_than_str := extract_str_from_union(output_type_or_marker): + output_type_ = output_type_other_than_str.value + allow_text_output = True + else: + output_type_ = output_type_or_marker + + disambiguate_tool_name = multiple and not custom_tool_name + if args := get_union_args(output_type_): + multiple = True + disambiguate_tool_name = True + else: + args = (output_type_,) + + base_tool_name = tool_name + for arg in args: + tool_name = base_tool_name or DEFAULT_OUTPUT_TOOL_NAME - tools: dict[str, OutputTool[T]] = {} - if args := get_union_args(output_type_): - for i, arg in enumerate(args, start=1): - tool_name = raw_tool_name = union_tool_name(name, arg) + if disambiguate_tool_name: + tool_name += f'_{arg.__name__}' + + i = 1 + original_tool_name = tool_name while tool_name in tools: - tool_name = f'{raw_tool_name}_{i}' + tool_name = f'{original_tool_name}_{i}' + i += 1 - parameters_schema = OutputObjectSchema(output_type=arg, description=description, strict=strict) + parameters_schema = OutputObjectSchema( + output_type=arg, description=tool_description, strict=tool_strict + ) tools[tool_name] = cast( - OutputTool[T], - OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=True), + OutputTool[OutputDataT], + OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=multiple), ) - else: - tool_name = name or DEFAULT_OUTPUT_TOOL_NAME - parameters_schema = OutputObjectSchema(output_type=output_type_, description=description, strict=strict) - tools[tool_name] = cast( - OutputTool[T], - OutputTool(name=tool_name, parameters_schema=parameters_schema, multiple=False), - ) return cls( tools=tools, @@ -247,18 +274,18 @@ class OutputObjectSchema(Generic[OutputDataT]): def __init__( self, *, - output_type: type[OutputDataT], + output_type: SimpleOutputType[OutputDataT], name: str | None = None, description: str | None = None, strict: bool | None = None, ): - if _utils.is_model_like(output_type): + if inspect.isfunction(output_type) or _utils.is_model_like(output_type): self.type_adapter = TypeAdapter(output_type) else: self.outer_typed_dict_key = 'response' response_data_typed_dict = TypedDict( # noqa: UP013 'response_data_typed_dict', - {'response': output_type}, # pyright: ignore[reportInvalidTypeForm] + {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm] ) self.type_adapter = TypeAdapter(response_data_typed_dict) @@ -282,10 +309,10 @@ def __init__( strict=strict, ) - def validate( + async def process( self, data: str | dict[str, Any], allow_partial: bool = False, wrap_validation_errors: bool = True ) -> OutputDataT: - """Validate an output message. + """Process an output message, performing validation and (if necessary) calling the output function. Args: data: The output data to validate. @@ -296,6 +323,7 @@ def validate( Either the validated output data (left) or a retry message (right). """ try: + # TODO: Inject RunContext? pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' if isinstance(data, str): output = self.type_adapter.validate_json(data, experimental_allow_partial=pyd_allow_partial) @@ -309,10 +337,20 @@ def validate( raise ToolRetryError(m) from e else: raise + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=r.message) + raise ToolRetryError(m) from r + else: + raise else: if k := self.outer_typed_dict_key: output = output[k] - return output + + # A non-awaitable callable output_type will already have been executed by the TypeAdapter's validation method + if inspect.isawaitable(output): + output = await output + return output @dataclass(init=False) @@ -338,7 +376,7 @@ def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDat outer_typed_dict_key=parameters_schema.outer_typed_dict_key, ) - def validate( + async def process( self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True ) -> OutputDataT: """Validate an output message. @@ -352,7 +390,7 @@ def validate( Either the validated output data (left) or a retry message (right). """ try: - output = self.parameters_schema.validate( + output = await self.parameters_schema.process( tool_call.args, allow_partial=allow_partial, wrap_validation_errors=False ) except ValidationError as e: @@ -365,18 +403,20 @@ def validate( raise ToolRetryError(m) from e else: raise # pragma: lax no cover + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + tool_name=tool_call.tool_name, + content=r.message, + tool_call_id=tool_call.tool_call_id, + ) + raise ToolRetryError(m) from r + else: + raise else: return output -def union_tool_name(base_name: str | None, union_arg: Any) -> str: - return f'{base_name or DEFAULT_OUTPUT_TOOL_NAME}_{union_arg_name(union_arg)}' - - -def union_arg_name(union_arg: Any) -> str: - return union_arg.__name__ - - def extract_str_from_union(output_type: Any) -> _utils.Option[Any]: """Extract the string type from a Union, return the remaining union or remaining type.""" union_args = get_union_args(output_type) diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 9b742062c8..7af756ee53 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -88,7 +88,7 @@ async def _validate_response( ) call, output_tool = match - result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + result_data = await output_tool.process(call, allow_partial=allow_partial, wrap_validation_errors=False) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) # The following cast is safe because we know `str` is an allowed output type @@ -396,7 +396,7 @@ async def validate_structured_output( ) call, output_tool = match - result_data = output_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False) + result_data = await output_tool.process(call, allow_partial=allow_partial, wrap_validation_errors=False) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = cast(OutputDataT, text) diff --git a/tests/test_agent.py b/tests/test_agent.py index 5e93df4e57..7c044574b0 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -29,7 +29,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.result import Usage +from pydantic_ai.result import ToolOutput, Usage from pydantic_ai.tools import ToolDefinition from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -446,6 +446,8 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'union_code', [ pytest.param('OutputType = Union[Foo, Bar]'), + pytest.param('OutputType = [Foo, Bar]'), + pytest.param('OutputType = ToolOutput(type_=Union[Foo, Bar])'), pytest.param('OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')), pytest.param( 'OutputType: TypeAlias = Foo | Bar', @@ -461,6 +463,7 @@ def test_response_multiple_return_tools(create_module: Callable[[str], Any], uni from pydantic import BaseModel from typing import Union from typing_extensions import TypeAlias +from pydantic_ai import ToolOutput class Foo(BaseModel): a: int @@ -531,6 +534,408 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert got_tool_call_name == snapshot('final_result_Bar') +def test_output_type_callable(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_callable_with_retry(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + if city != 'Mexico City': + raise ModelRetry('City not found, I only know Mexico City') + return Weather(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + args_json = '{"city": "New York City"}' + else: + args_json = '{"city": "Mexico City"}' + + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('New York City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='New York City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "New York City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=53, response_tokens=7, total_tokens=60), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='City not found, I only know Mexico City', + tool_name='final_result', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=68, response_tokens=13, total_tokens=81), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +def test_output_type_async_callable(): + class Weather(BaseModel): + temperature: float + description: str + + async def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_callable_with_custom_tool_name(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=ToolOutput(type_=get_weather, name='get_weather')) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='get_weather', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_callable_or_model(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=[get_weather, Weather]) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result_get_weather', + description='get_weather: The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ), + ToolDefinition( + name='final_result_Weather', + description='Weather: The final response which ends this conversation', + parameters_json_schema={ + 'properties': {'temperature': {'type': 'number'}, 'description': {'type': 'string'}}, + 'required': ['temperature', 'description'], + 'title': 'Weather', + 'type': 'object', + }, + ), + ] + ) + + +def test_output_type_handoff_to_agent(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + + handoff_result = None + + async def handoff(city: str) -> Weather: + result = await agent.run(f'Get me the weather in {city}') + nonlocal handoff_result + handoff_result = result + return result.output + + def call_handoff_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + supervisor_agent = Agent(FunctionModel(call_handoff_tool), output_type=handoff) + + result = supervisor_agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Mexico City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=52, response_tokens=6, total_tokens=58), + model_name='function:call_handoff_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + assert handoff_result is not None + assert handoff_result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Get me the weather in Mexico City', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"city": "Mexico City"}', + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=57, response_tokens=6, total_tokens=63), + model_name='function:call_tool:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='final_result', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +def test_output_type_multiple_custom_tools(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(city: str) -> Weather: + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent( + FunctionModel(call_tool), + output_type=[ + ToolOutput(type_=get_weather, name='get_weather'), + ToolOutput(type_=Weather, name='return_weather'), + ], + ) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='get_weather', + description='get_weather: The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ), + ToolDefinition( + name='return_weather', + description='Weather: The final response which ends this conversation', + parameters_json_schema={ + 'properties': {'temperature': {'type': 'number'}, 'description': {'type': 'string'}}, + 'required': ['temperature', 'description'], + 'title': 'Weather', + 'type': 'object', + }, + ), + ] + ) + + def test_run_with_history_new(): m = TestModel() From ab576d7e099ce3dab7d4ee21e5a82eaec42708ab Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 20 May 2025 23:35:22 +0000 Subject: [PATCH 03/70] Fix tests --- pydantic_ai_slim/pydantic_ai/_output.py | 18 ++++++++++++------ tests/test_examples.py | 2 +- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 83a17459dc..84dd575abf 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Generic, Literal, Union, cast from pydantic import TypeAdapter, ValidationError -from typing_extensions import TypedDict, TypeVar, get_args, get_origin +from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin @@ -134,11 +134,16 @@ def __init__( self.strict = strict -# TODO: Use TypeAliasType -type OutputCallable[OutputDataT] = Callable[..., OutputDataT | Awaitable[OutputDataT]] -type SimpleOutputType[OutputDataT] = type[OutputDataT] | OutputCallable[OutputDataT] -type SimpleOutputTypeOrMarker[OutputDataT] = SimpleOutputType[OutputDataT] | ToolOutput[OutputDataT] -type OutputType[OutputDataT] = SimpleOutputTypeOrMarker[OutputDataT] | Sequence[SimpleOutputTypeOrMarker[OutputDataT]] +# TODO: Comments to explain what's supported +T_co = TypeVar('T_co', covariant=True) +OutputCallable = TypeAliasType('OutputCallable', Callable[..., T_co | Awaitable[T_co]], type_params=(T_co,)) +SimpleOutputType = TypeAliasType('SimpleOutputType', type[T_co] | OutputCallable[T_co], type_params=(T_co,)) +SimpleOutputTypeOrMarker = TypeAliasType( + 'SimpleOutputTypeOrMarker', SimpleOutputType[T_co] | ToolOutput[T_co], type_params=(T_co,) +) +OutputType = TypeAliasType( + 'OutputType', SimpleOutputTypeOrMarker[T_co] | Sequence[SimpleOutputTypeOrMarker[T_co]], type_params=(T_co,) +) @dataclass @@ -279,6 +284,7 @@ def __init__( description: str | None = None, strict: bool | None = None, ): + # TODO: Support bound instance methods if inspect.isfunction(output_type) or _utils.is_model_like(output_type): self.type_adapter = TypeAdapter(output_type) else: diff --git a/tests/test_examples.py b/tests/test_examples.py index ff74ff80cb..b1f4986805 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -322,7 +322,7 @@ async def list_tools() -> list[None]: args={'response': ['red', 'blue', 'green']}, ), 'square size 10, circle size 20, triangle size 30': ToolCallPart( - tool_name='final_result_list_2', + tool_name='final_result_list_1', args={'response': [10, 20, 30]}, ), 'get me users who were last active yesterday.': ToolCallPart( From 6c4fcec4f3273c8e0a883f309054187f1605b71e Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 20 May 2025 23:40:03 +0000 Subject: [PATCH 04/70] Make Python 3.9 happy --- pydantic_ai_slim/pydantic_ai/_output.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 84dd575abf..5705dd40cd 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -136,13 +136,13 @@ def __init__( # TODO: Comments to explain what's supported T_co = TypeVar('T_co', covariant=True) -OutputCallable = TypeAliasType('OutputCallable', Callable[..., T_co | Awaitable[T_co]], type_params=(T_co,)) -SimpleOutputType = TypeAliasType('SimpleOutputType', type[T_co] | OutputCallable[T_co], type_params=(T_co,)) +OutputCallable = TypeAliasType('OutputCallable', Callable[..., Union[T_co, Awaitable[T_co]]], type_params=(T_co,)) +SimpleOutputType = TypeAliasType('SimpleOutputType', Union[type[T_co], OutputCallable[T_co]], type_params=(T_co,)) SimpleOutputTypeOrMarker = TypeAliasType( - 'SimpleOutputTypeOrMarker', SimpleOutputType[T_co] | ToolOutput[T_co], type_params=(T_co,) + 'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,) ) OutputType = TypeAliasType( - 'OutputType', SimpleOutputTypeOrMarker[T_co] | Sequence[SimpleOutputTypeOrMarker[T_co]], type_params=(T_co,) + 'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,) ) From 1bd16dcbda1f5cfc41e8845002199ab75ffa77d0 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 21 May 2025 18:09:40 +0000 Subject: [PATCH 05/70] Support output_type = bound instance method --- pydantic_ai_slim/pydantic_ai/_output.py | 3 +- tests/test_agent.py | 49 +++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 5705dd40cd..5aa86d01ca 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -284,8 +284,7 @@ def __init__( description: str | None = None, strict: bool | None = None, ): - # TODO: Support bound instance methods - if inspect.isfunction(output_type) or _utils.is_model_like(output_type): + if inspect.isfunction(output_type) or inspect.ismethod(output_type) or _utils.is_model_like(output_type): self.type_adapter = TypeAdapter(output_type) else: self.outer_typed_dict_key = 'response' diff --git a/tests/test_agent.py b/tests/test_agent.py index 7c044574b0..2028ad3b7e 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -10,6 +10,7 @@ from inline_snapshot import snapshot from pydantic import BaseModel, TypeAdapter, field_validator from pydantic_core import to_json +from typing_extensions import Self from pydantic_ai import Agent, ModelRetry, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai.agent import AgentRunResult @@ -534,7 +535,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert got_tool_call_name == snapshot('final_result_Bar') -def test_output_type_callable(): +def test_output_type_function(): class Weather(BaseModel): temperature: float description: str @@ -572,7 +573,47 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) -def test_output_type_callable_with_retry(): +def test_output_type_bound_instance_method(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, city: str) -> Self: + return self + + weather = Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=weather.get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_function_with_retry(): class Weather(BaseModel): temperature: float description: str @@ -691,7 +732,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) -def test_output_type_callable_with_custom_tool_name(): +def test_output_type_function_with_custom_tool_name(): class Weather(BaseModel): temperature: float description: str @@ -729,7 +770,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) -def test_output_type_callable_or_model(): +def test_output_type_function_or_model(): class Weather(BaseModel): temperature: float description: str From 98e64d4e148767b374dae8bce1c4daa293959a07 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 21 May 2025 21:44:27 +0000 Subject: [PATCH 06/70] Support RunContext arg on output_type function using same logic as tools --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 4 +- .../{_pydantic.py => _function_schema.py} | 56 +++++++++-- pydantic_ai_slim/pydantic_ai/_output.py | 97 ++++++++++++------- pydantic_ai_slim/pydantic_ai/result.py | 8 +- pydantic_ai_slim/pydantic_ai/tools.py | 83 +++++----------- tests/test_agent.py | 94 ++++++++++++++++-- 6 files changed, 229 insertions(+), 113 deletions(-) rename pydantic_ai_slim/pydantic_ai/{_pydantic.py => _function_schema.py} (82%) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index a65b177e24..c95bfb77ac 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -458,6 +458,7 @@ async def _handle_tool_calls( tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: output_schema = ctx.deps.output_schema + run_context = build_run_context(ctx) # first, look for the output tool call final_result: result.FinalResult[NodeRunEndT] | None = None @@ -465,7 +466,7 @@ async def _handle_tool_calls( if output_schema is not None: for call, output_tool in output_schema.find_tool(tool_calls): try: - result_data = await output_tool.process(call) + result_data = await output_tool.process(call, run_context) result_data = await _validate_output(result_data, ctx, call) except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call @@ -492,7 +493,6 @@ async def _handle_tool_calls( else: if tool_responses: parts.extend(tool_responses) - run_context = build_run_context(ctx) instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( _messages.ModelRequest(parts=parts, instructions=instructions) diff --git a/pydantic_ai_slim/pydantic_ai/_pydantic.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py similarity index 82% rename from pydantic_ai_slim/pydantic_ai/_pydantic.py rename to pydantic_ai_slim/pydantic_ai/_function_schema.py index 8f5f0ee73a..b76510e092 100644 --- a/pydantic_ai_slim/pydantic_ai/_pydantic.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -5,6 +5,9 @@ from __future__ import annotations as _annotations +import inspect +from collections.abc import Awaitable +from dataclasses import dataclass from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Callable, cast @@ -15,10 +18,12 @@ from pydantic.json_schema import GenerateJsonSchema from pydantic.plugin._schema_validator import create_schema_validator from pydantic_core import SchemaValidator, core_schema -from typing_extensions import TypedDict, get_origin +from typing_extensions import get_origin + +from pydantic_ai.tools import RunContext from ._griffe import doc_descriptions -from ._utils import check_object_json_schema, is_model_like +from ._utils import check_object_json_schema, is_model_like, run_in_executor if TYPE_CHECKING: from .tools import DocstringFormat, ObjectJsonSchema @@ -27,24 +32,53 @@ __all__ = ('function_schema',) -class FunctionSchema(TypedDict): +@dataclass +class FunctionSchema: """Internal information about a function schema.""" + function: Callable[..., Any] description: str validator: SchemaValidator json_schema: ObjectJsonSchema # if not None, the function takes a single by that name (besides potentially `info`) + takes_ctx: bool + is_async: bool single_arg_name: str | None positional_fields: list[str] var_positional_field: str | None + async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any: + args, kwargs = self._call_args(args_dict, ctx) + if self.is_async: + function = cast(Callable[[Any], Awaitable[str]], self.function) + return await function(*args, **kwargs) + else: + function = cast(Callable[[Any], str], self.function) + return await run_in_executor(function, *args, **kwargs) + + def _call_args( + self, + args_dict: dict[str, Any], + ctx: RunContext[Any], + ) -> tuple[list[Any], dict[str, Any]]: + if self.single_arg_name: + args_dict = {self.single_arg_name: args_dict} + + args = [ctx] if self.takes_ctx else [] + for positional_field in self.positional_fields: + args.append(args_dict.pop(positional_field)) # pragma: no cover + if self.var_positional_field: + args.extend(args_dict.pop(self.var_positional_field)) + + return args, args_dict + def function_schema( # noqa: C901 function: Callable[..., Any], - takes_ctx: bool, - docstring_format: DocstringFormat, - require_parameter_descriptions: bool, schema_generator: type[GenerateJsonSchema], + takes_ctx: bool | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, ) -> FunctionSchema: """Build a Pydantic validator and JSON schema from a tool function. @@ -58,6 +92,9 @@ def function_schema( # noqa: C901 Returns: A `FunctionSchema` instance. """ + if takes_ctx is None: + takes_ctx = _takes_ctx(function) + config = ConfigDict(title=function.__name__, use_attribute_docstrings=True) config_wrapper = ConfigWrapper(config) gen_schema = _generate_schema.GenerateSchema(config_wrapper) @@ -169,10 +206,13 @@ def function_schema( # noqa: C901 single_arg_name=single_arg_name, positional_fields=positional_fields, var_positional_field=var_positional_field, + takes_ctx=takes_ctx, + is_async=inspect.iscoroutinefunction(function), + function=function, ) -def takes_ctx(function: Callable[..., Any]) -> bool: +def _takes_ctx(function: Callable[..., Any]) -> bool: """Check if a function takes a `RunContext` first argument. Args: @@ -189,7 +229,7 @@ def takes_ctx(function: Callable[..., Any]) -> bool: else: type_hints = _typing_extra.get_function_type_hints(function) annotation = type_hints[first_param_name] - return annotation is not sig.empty and _is_call_ctx(annotation) + return True is not sig.empty and _is_call_ctx(annotation) def _build_schema( diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 5aa86d01ca..113a5aa2a1 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -6,11 +6,12 @@ from typing import Any, Callable, Generic, Literal, Union, cast from pydantic import TypeAdapter, ValidationError +from pydantic_core import SchemaValidator from typing_extensions import TypeAliasType, TypedDict, TypeVar, get_args, get_origin from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin -from . import _utils, messages as _messages +from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry from .tools import AgentDepsT, GenerateToolJsonSchema, ObjectJsonSchema, RunContext, ToolDefinition @@ -134,13 +135,16 @@ def __init__( self.strict = strict -# TODO: Comments to explain what's supported T_co = TypeVar('T_co', covariant=True) -OutputCallable = TypeAliasType('OutputCallable', Callable[..., Union[T_co, Awaitable[T_co]]], type_params=(T_co,)) -SimpleOutputType = TypeAliasType('SimpleOutputType', Union[type[T_co], OutputCallable[T_co]], type_params=(T_co,)) +# output_type=Type or output_type=function or output_type=object.method +SimpleOutputType = TypeAliasType( + 'SimpleOutputType', Union[type[T_co], Callable[..., Union[T_co, Awaitable[T_co]]]], type_params=(T_co,) +) +# output_type=ToolOutput() or SimpleOutputTypeOrMarker = TypeAliasType( 'SimpleOutputTypeOrMarker', Union[SimpleOutputType[T_co], ToolOutput[T_co]], type_params=(T_co,) ) +# output_type= or [, ...] OutputType = TypeAliasType( 'OutputType', Union[SimpleOutputTypeOrMarker[T_co], Sequence[SimpleOutputTypeOrMarker[T_co]]], type_params=(T_co,) ) @@ -273,7 +277,8 @@ class OutputObjectDefinition: @dataclass(init=False) class OutputObjectSchema(Generic[OutputDataT]): definition: OutputObjectDefinition - type_adapter: TypeAdapter[Any] + validator: SchemaValidator + function_schema: _function_schema.FunctionSchema | None = None outer_typed_dict_key: str | None = None def __init__( @@ -284,22 +289,31 @@ def __init__( description: str | None = None, strict: bool | None = None, ): - if inspect.isfunction(output_type) or inspect.ismethod(output_type) or _utils.is_model_like(output_type): - self.type_adapter = TypeAdapter(output_type) + if inspect.isfunction(output_type) or inspect.ismethod(output_type): + self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) + self.validator = self.function_schema.validator + json_schema = self.function_schema.json_schema else: - self.outer_typed_dict_key = 'response' - response_data_typed_dict = TypedDict( # noqa: UP013 - 'response_data_typed_dict', - {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm] + type_adapter: TypeAdapter[Any] + if _utils.is_model_like(output_type): + type_adapter = TypeAdapter(output_type) + else: + self.outer_typed_dict_key = 'response' + response_data_typed_dict = TypedDict( # noqa: UP013 + 'response_data_typed_dict', + {'response': cast(type[OutputDataT], output_type)}, # pyright: ignore[reportInvalidTypeForm] + ) + type_adapter = TypeAdapter(response_data_typed_dict) + + # Really a PluggableSchemaValidator, but it's API-compatible + self.validator = cast(SchemaValidator, type_adapter.validator) + json_schema = _utils.check_object_json_schema( + type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) - self.type_adapter = TypeAdapter(response_data_typed_dict) - json_schema = _utils.check_object_json_schema( - self.type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) - ) - if self.outer_typed_dict_key: - # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM - json_schema.pop('title') + if self.outer_typed_dict_key: + # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM + json_schema.pop('title') if json_schema_description := json_schema.pop('description', None): if description is None: @@ -315,12 +329,17 @@ def __init__( ) async def process( - self, data: str | dict[str, Any], allow_partial: bool = False, wrap_validation_errors: bool = True + self, + data: str | dict[str, Any], + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, ) -> OutputDataT: """Process an output message, performing validation and (if necessary) calling the output function. Args: data: The output data to validate. + run_context: The current run context. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -328,12 +347,11 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - # TODO: Inject RunContext? pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' if isinstance(data, str): - output = self.type_adapter.validate_json(data, experimental_allow_partial=pyd_allow_partial) + output = self.validator.validate_json(data, allow_partial=pyd_allow_partial) else: - output = self.type_adapter.validate_python(data, experimental_allow_partial=pyd_allow_partial) + output = self.validator.validate_python(data, allow_partial=pyd_allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -342,19 +360,19 @@ async def process( raise ToolRetryError(m) from e else: raise - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart(content=r.message) - raise ToolRetryError(m) from r - else: - raise - else: - if k := self.outer_typed_dict_key: - output = output[k] - # A non-awaitable callable output_type will already have been executed by the TypeAdapter's validation method - if inspect.isawaitable(output): - output = await output + if self.function_schema: + try: + output = await self.function_schema.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart(content=r.message) + raise ToolRetryError(m) from r + else: + raise + + if k := self.outer_typed_dict_key: + output = output[k] return output @@ -382,12 +400,17 @@ def __init__(self, *, name: str, parameters_schema: OutputObjectSchema[OutputDat ) async def process( - self, tool_call: _messages.ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True + self, + tool_call: _messages.ToolCallPart, + run_context: RunContext[AgentDepsT], + allow_partial: bool = False, + wrap_validation_errors: bool = True, ) -> OutputDataT: - """Validate an output message. + """Process an output message. Args: tool_call: The tool call from the LLM to validate. + run_context: The current run context. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -396,7 +419,7 @@ async def process( """ try: output = await self.parameters_schema.process( - tool_call.args, allow_partial=allow_partial, wrap_validation_errors=False + tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False ) except ValidationError as e: if wrap_validation_errors: diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 7af756ee53..5b833fc672 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -88,7 +88,9 @@ async def _validate_response( ) call, output_tool = match - result_data = await output_tool.process(call, allow_partial=allow_partial, wrap_validation_errors=False) + result_data = await output_tool.process( + call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) # The following cast is safe because we know `str` is an allowed output type @@ -396,7 +398,9 @@ async def validate_structured_output( ) call, output_tool = match - result_data = await output_tool.process(call, allow_partial=allow_partial, wrap_validation_errors=False) + result_data = await output_tool.process( + call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + ) else: text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = cast(OutputDataT, text) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 584d77e02a..df07de7678 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,19 +1,18 @@ from __future__ import annotations as _annotations import dataclasses -import inspect import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union from opentelemetry.trace import Tracer from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue -from pydantic_core import SchemaValidator, core_schema +from pydantic_core import core_schema from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar -from . import _pydantic, _utils, messages as _messages +from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry, UnexpectedModelBehavior if TYPE_CHECKING: @@ -207,12 +206,7 @@ class Tool(Generic[AgentDepsT]): docstring_format: DocstringFormat require_parameter_descriptions: bool strict: bool | None - _is_async: bool = field(init=False) - _single_arg_name: str | None = field(init=False) - _positional_fields: list[str] = field(init=False) - _var_positional_field: str | None = field(init=False) - _validator: SchemaValidator = field(init=False, repr=False) - _base_parameters_json_schema: ObjectJsonSchema = field(init=False) + function_schema: _function_schema.FunctionSchema """ The base JSON schema for the tool's parameters. @@ -236,6 +230,7 @@ def __init__( require_parameter_descriptions: bool = False, schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, strict: bool | None = None, + function_schema: _function_schema.FunctionSchema | None = None, ): """Create a new tool instance. @@ -288,28 +283,24 @@ async def prep_my_tool( schema_generator: The JSON schema generator class to use. Defaults to `GenerateToolJsonSchema`. strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + function_schema: The function schema to use for the tool. If not provided, it will be generated. """ - if takes_ctx is None: - takes_ctx = _pydantic.takes_ctx(function) - - f = _pydantic.function_schema( - function, takes_ctx, docstring_format, require_parameter_descriptions, schema_generator - ) self.function = function - self.takes_ctx = takes_ctx + self.function_schema = function_schema or _function_schema.function_schema( + function, + schema_generator, + takes_ctx=takes_ctx, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + ) + self.takes_ctx = self.function_schema.takes_ctx self.max_retries = max_retries self.name = name or function.__name__ - self.description = description or f['description'] + self.description = description or self.function_schema.description self.prepare = prepare self.docstring_format = docstring_format self.require_parameter_descriptions = require_parameter_descriptions self.strict = strict - self._is_async = inspect.iscoroutinefunction(self.function) - self._single_arg_name = f['single_arg_name'] - self._positional_fields = f['positional_fields'] - self._var_positional_field = f['var_positional_field'] - self._validator = f['validator'] - self._base_parameters_json_schema = f['json_schema'] async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. @@ -323,7 +314,7 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition tool_def = ToolDefinition( name=self.name, description=self.description, - parameters_json_schema=self._base_parameters_json_schema, + parameters_json_schema=self.function_schema.json_schema, strict=self.strict, ) if self.prepare is not None: @@ -365,21 +356,22 @@ async def _run( self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: try: + validator = self.function_schema.validator if isinstance(message.args, str): - args_dict = self._validator.validate_json(message.args or '{}') + args_dict = validator.validate_json(message.args or '{}') else: - args_dict = self._validator.validate_python(message.args) + args_dict = validator.validate_python(message.args) except ValidationError as e: return self._on_error(e, message) - args, kwargs = self._call_args(args_dict, message, run_context) + ctx = dataclasses.replace( + run_context, + retry=self.current_retry, + tool_name=message.tool_name, + tool_call_id=message.tool_call_id, + ) try: - if self._is_async: - function = cast(Callable[[Any], Awaitable[str]], self.function) - response_content = await function(*args, **kwargs) - else: - function = cast(Callable[[Any], str], self.function) - response_content = await _utils.run_in_executor(function, *args, **kwargs) + response_content = await self.function_schema.call(args_dict, ctx) except ModelRetry as e: return self._on_error(e, message) @@ -390,29 +382,6 @@ async def _run( tool_call_id=message.tool_call_id, ) - def _call_args( - self, - args_dict: dict[str, Any], - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - ) -> tuple[list[Any], dict[str, Any]]: - if self._single_arg_name: - args_dict = {self._single_arg_name: args_dict} - - ctx = dataclasses.replace( - run_context, - retry=self.current_retry, - tool_name=message.tool_name, - tool_call_id=message.tool_call_id, - ) - args = [ctx] if self.takes_ctx else [] - for positional_field in self._positional_fields: - args.append(args_dict.pop(positional_field)) # pragma: no cover - if self._var_positional_field: - args.extend(args_dict.pop(self._var_positional_field)) - - return args, args_dict - def _on_error( self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart ) -> _messages.RetryPromptPart: diff --git a/tests/test_agent.py b/tests/test_agent.py index 2028ad3b7e..5492325126 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -564,7 +564,46 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, - 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_function_with_run_context(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(ctx: RunContext[None], city: str) -> Weather: + assert ctx is not None + return Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, @@ -604,7 +643,48 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, - 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + ) + ] + ) + + +def test_output_type_bound_instance_method_with_run_context(): + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, ctx: RunContext[None], city: str) -> Self: + assert ctx is not None + return self + + weather = Weather(temperature=28.7, description='sunny') + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent(FunctionModel(call_tool), output_type=weather.get_weather) + result = agent.run_sync('Mexico City') + assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='The final response which ends this conversation', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, @@ -694,7 +774,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) -def test_output_type_async_callable(): +def test_output_type_async_function(): class Weather(BaseModel): temperature: float description: str @@ -723,7 +803,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, - 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, @@ -761,7 +841,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: description='The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, - 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, @@ -799,7 +879,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: description='get_weather: The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, - 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, @@ -958,7 +1038,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: description='get_weather: The final response which ends this conversation', parameters_json_schema={ 'additionalProperties': False, - 'properties': {'city': {'title': 'City', 'type': 'string'}}, + 'properties': {'city': {'type': 'string'}}, 'required': ['city'], 'type': 'object', }, From 60d789e98591da1efb03ca38528e2eab44154930 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 21 May 2025 22:13:34 +0000 Subject: [PATCH 07/70] Improve test coverage --- pydantic_ai_slim/pydantic_ai/_output.py | 34 ++++++------------------- tests/test_agent.py | 27 ++++++++++++++++++++ 2 files changed, 35 insertions(+), 26 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 113a5aa2a1..592b8dda34 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -346,30 +346,14 @@ async def process( Returns: Either the validated output data (left) or a retry message (right). """ - try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self.validator.validate_json(data, allow_partial=pyd_allow_partial) - else: - output = self.validator.validate_python(data, allow_partial=pyd_allow_partial) - except ValidationError as e: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - content=e.errors(include_url=False), - ) - raise ToolRetryError(m) from e - else: - raise + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + output = self.validator.validate_json(data, allow_partial=pyd_allow_partial) + else: + output = self.validator.validate_python(data, allow_partial=pyd_allow_partial) if self.function_schema: - try: - output = await self.function_schema.call(output, run_context) - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart(content=r.message) - raise ToolRetryError(m) from r - else: - raise + output = await self.function_schema.call(output, run_context) if k := self.outer_typed_dict_key: output = output[k] @@ -418,9 +402,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - output = await self.parameters_schema.process( - tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False - ) + output = await self.parameters_schema.process(tool_call.args, run_context, allow_partial=allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -440,7 +422,7 @@ async def process( ) raise ToolRetryError(m) from r else: - raise + raise # pragma: lax no cover else: return output diff --git a/tests/test_agent.py b/tests/test_agent.py index 5492325126..c272225df0 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -535,6 +535,33 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: assert got_tool_call_name == snapshot('final_result_Bar') +def test_output_type_with_two_descriptions(): + class MyOutput(BaseModel): + """Description from docstring""" + + valid: bool + + m = TestModel() + agent = Agent(m, output_type=ToolOutput(type_=MyOutput, description='Description from ToolOutput')) + result = agent.run_sync('Hello') + assert result.output == snapshot(MyOutput(valid=False)) + assert m.last_model_request_parameters is not None + assert m.last_model_request_parameters.output_tools == snapshot( + [ + ToolDefinition( + name='final_result', + description='Description from ToolOutput. Description from docstring', + parameters_json_schema={ + 'properties': {'valid': {'type': 'boolean'}}, + 'required': ['valid'], + 'title': 'MyOutput', + 'type': 'object', + }, + ) + ] + ) + + def test_output_type_function(): class Weather(BaseModel): temperature: float From 14e69d0554e0c93ab06a985a6e2bcfdc40bf63bc Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 22 May 2025 21:34:19 +0000 Subject: [PATCH 08/70] Start output tool name disambiguation counter at 2 --- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- tests/test_examples.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 592b8dda34..a447c01f63 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -218,8 +218,8 @@ def build( i = 1 original_tool_name = tool_name while tool_name in tools: - tool_name = f'{original_tool_name}_{i}' i += 1 + tool_name = f'{original_tool_name}_{i}' parameters_schema = OutputObjectSchema( output_type=arg, description=tool_description, strict=tool_strict diff --git a/tests/test_examples.py b/tests/test_examples.py index b1f4986805..ff74ff80cb 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -322,7 +322,7 @@ async def list_tools() -> list[None]: args={'response': ['red', 'blue', 'green']}, ), 'square size 10, circle size 20, triangle size 30': ToolCallPart( - tool_name='final_result_list_1', + tool_name='final_result_list_2', args={'response': [10, 20, 30]}, ), 'get me users who were last active yesterday.': ToolCallPart( From ee95a80d24b8596e6b76db0efe08fa86ee617a7d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 22 May 2025 21:43:56 +0000 Subject: [PATCH 09/70] Stop requiring explicitly specifying type_ kwarg name on ToolOutput --- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- tests/test_agent.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index a447c01f63..cb737ee930 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -121,8 +121,8 @@ class ToolOutput(Generic[OutputDataT]): def __init__( self, - *, type_: SimpleOutputType[OutputDataT], + *, name: str | None = None, description: str | None = None, max_retries: int | None = None, diff --git a/tests/test_agent.py b/tests/test_agent.py index c272225df0..e58ed67170 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -448,7 +448,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: [ pytest.param('OutputType = Union[Foo, Bar]'), pytest.param('OutputType = [Foo, Bar]'), - pytest.param('OutputType = ToolOutput(type_=Union[Foo, Bar])'), + pytest.param('OutputType = ToolOutput(Union[Foo, Bar])'), pytest.param('OutputType = Foo | Bar', marks=pytest.mark.skipif(sys.version_info < (3, 10), reason='3.10+')), pytest.param( 'OutputType: TypeAlias = Foo | Bar', @@ -542,7 +542,7 @@ class MyOutput(BaseModel): valid: bool m = TestModel() - agent = Agent(m, output_type=ToolOutput(type_=MyOutput, description='Description from ToolOutput')) + agent = Agent(m, output_type=ToolOutput(MyOutput, description='Description from ToolOutput')) result = agent.run_sync('Hello') assert result.output == snapshot(MyOutput(valid=False)) assert m.last_model_request_parameters is not None @@ -858,7 +858,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: args_json = '{"city": "Mexico City"}' return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) - agent = Agent(FunctionModel(call_tool), output_type=ToolOutput(type_=get_weather, name='get_weather')) + agent = Agent(FunctionModel(call_tool), output_type=ToolOutput(get_weather, name='get_weather')) result = agent.run_sync('Mexico City') assert result.output == snapshot(Weather(temperature=28.7, description='sunny')) assert output_tools == snapshot( @@ -1052,8 +1052,8 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: agent = Agent( FunctionModel(call_tool), output_type=[ - ToolOutput(type_=get_weather, name='get_weather'), - ToolOutput(type_=Weather, name='return_weather'), + ToolOutput(get_weather, name='get_weather'), + ToolOutput(Weather, name='return_weather'), ], ) result = agent.run_sync('Mexico City') From f1f093e2f94497ab37e0cc6fc002134e1acb7cb2 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 22 May 2025 22:02:26 +0000 Subject: [PATCH 10/70] Remove runtime assertion from typed_agent.py as the file is only typechecked, not executed --- tests/typed_agent.py | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 676cb34229..9668c49b72 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -1,7 +1,6 @@ """This file is used to test static typing, it's analyzed with pyright and mypy.""" -from collections.abc import Awaitable, Iterator -from contextlib import contextmanager +from collections.abc import Awaitable from dataclasses import dataclass from typing import Callable, TypeAlias, Union @@ -37,16 +36,6 @@ def system_prompt_ok2() -> str: assert_type(system_prompt_ok2, Callable[[], str]) -@contextmanager -def expect_error(error_type: type[Exception]) -> Iterator[None]: - try: - yield None - except Exception as e: - assert isinstance(e, error_type), f'Expected {error_type}, got {type(e)}' - else: - raise AssertionError('Expected an error') - - @typed_agent.tool async def ok_tool(ctx: RunContext[MyDeps], x: str) -> str: assert_type(ctx.deps, MyDeps) @@ -108,13 +97,6 @@ async def bad_tool2(ctx: RunContext[int], x: str) -> str: return f'{x} {ctx.deps}' -with expect_error(ValueError): - - @typed_agent.tool # type: ignore[arg-type] - async def bad_tool3(x: str) -> str: - return x - - @typed_agent.output_validator def ok_validator_simple(data: str) -> str: return data From 66e54055c2cccd9b5e2a10f7b1b88eff44d59829 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 22 May 2025 22:15:42 +0000 Subject: [PATCH 11/70] Add typing tests for Agent(output_type=) --- pydantic_ai_slim/pydantic_ai/_output.py | 2 +- tests/typed_agent.py | 24 ++++++++++++++++++++++-- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index cb737ee930..80a45f9e6d 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -138,7 +138,7 @@ def __init__( T_co = TypeVar('T_co', covariant=True) # output_type=Type or output_type=function or output_type=object.method SimpleOutputType = TypeAliasType( - 'SimpleOutputType', Union[type[T_co], Callable[..., Union[T_co, Awaitable[T_co]]]], type_params=(T_co,) + 'SimpleOutputType', Union[type[T_co], Callable[..., T_co], Callable[..., Awaitable[T_co]]], type_params=(T_co,) ) # output_type=ToolOutput() or SimpleOutputTypeOrMarker = TypeAliasType( diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 9668c49b72..3315b5c972 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -7,6 +7,7 @@ from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool +from pydantic_ai._output import ToolOutput from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition @@ -169,8 +170,27 @@ def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> str: return f'{x} {y}' -def foobar_plain(x: str, y: int) -> str: - return f'{x} {y}' +async def foobar_plain(x: int, y: int) -> int: + return x * y + + +class MyClass: + def my_method(self) -> bool: + return True + + +agent = Agent(output_type=foobar_ctx) +assert_type(agent, Agent[None, str]) + +agent = Agent(output_type=foobar_plain) +assert_type(agent, Agent[None, int]) + +agent = Agent(output_type=MyClass().my_method) +assert_type(agent, Agent[None, bool]) + +marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore[call-overload] +agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker]) +assert_type(agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) Tool(foobar_ctx, takes_ctx=True) From 3aad6fca187fdb1d33a1551b52e191b863170ca0 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 14:11:05 +0000 Subject: [PATCH 12/70] Update typing test for mypy --- tests/typed_agent.py | 40 ++++++++++++++++++++++++++++++---------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 3315b5c972..9885643282 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -11,6 +11,9 @@ from pydantic_ai.agent import AgentRunResult from pydantic_ai.tools import ToolDefinition +# Define here so we can check `if MYPY` below. This will not be executed, MYPY will always set it to True +MYPY = False + @dataclass class MyDeps: @@ -179,18 +182,36 @@ def my_method(self) -> bool: return True -agent = Agent(output_type=foobar_ctx) -assert_type(agent, Agent[None, str]) +if MYPY: + # mypy requires the generic parameters to be specified explicitly to figure out what's going on here + str_function_agent = Agent[None, str](output_type=foobar_ctx) + assert_type(str_function_agent, Agent[None, str]) + + int_function_agent = Agent[None, int](output_type=foobar_plain) + assert_type(int_function_agent, Agent[None, int]) -agent = Agent(output_type=foobar_plain) -assert_type(agent, Agent[None, int]) + bool_method_agent = Agent[None, bool](output_type=MyClass().my_method) + assert_type(bool_method_agent, Agent[None, bool]) -agent = Agent(output_type=MyClass().my_method) -assert_type(agent, Agent[None, bool]) + marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore + complex_output_agent = Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]( + output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker] + ) + assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) +else: + # pyright is able to correctly infer the output type here + str_function_agent = Agent(output_type=foobar_ctx) + assert_type(str_function_agent, Agent[None, str]) -marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore[call-overload] -agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker]) -assert_type(agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) + int_function_agent = Agent(output_type=foobar_plain) + assert_type(int_function_agent, Agent[None, int]) + + bool_method_agent = Agent(output_type=MyClass().my_method) + assert_type(bool_method_agent, Agent[None, bool]) + + marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore + complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker]) + assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) Tool(foobar_ctx, takes_ctx=True) @@ -237,7 +258,6 @@ async def prepare_greet(ctx: RunContext[str], tool_def: ToolDefinition) -> ToolD result = greet_agent.run_sync('testing...', deps='human') assert result.output == '{"greet":"hello a"}' -MYPY = False if not MYPY: default_agent = Agent() assert_type(default_agent, Agent[None, str]) From 3ff6e7478015ca57c40d4a21035c384fe13ab4e3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 14:13:03 +0000 Subject: [PATCH 13/70] Treat str in an output_type list the same as in a union --- pydantic_ai_slim/pydantic_ai/_output.py | 9 +++++++-- tests/test_agent.py | 4 ++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index ad76d49b72..23b0ad8092 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -173,15 +173,19 @@ def build( return None multiple = False + allow_text_output = False output_types_or_markers: Sequence[SimpleOutputTypeOrMarker[OutputDataT]] if isinstance(output_type, Sequence): output_types_or_markers = output_type - multiple = True + if str in output_types_or_markers: + allow_text_output = True + output_types_or_markers = [t for t in output_types_or_markers if t is not str] + if len(output_types_or_markers) > 1: + multiple = True else: output_types_or_markers = [output_type] - allow_text_output = False tools: dict[str, OutputTool[OutputDataT]] = {} for output_type_or_marker in output_types_or_markers: tool_name = name @@ -293,6 +297,7 @@ def __init__( self.function_schema = _function_schema.function_schema(output_type, GenerateToolJsonSchema) self.validator = self.function_schema.validator json_schema = self.function_schema.json_schema + json_schema['description'] = self.function_schema.description else: type_adapter: TypeAdapter[Any] if _utils.is_model_like(output_type): diff --git a/tests/test_agent.py b/tests/test_agent.py index 1b1ba1388a..aea538c590 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -390,8 +390,8 @@ def test_response_tuple(): @pytest.mark.parametrize( 'input_union_callable', - [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str], - ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str'], + [lambda: Union[str, Foo], lambda: Union[Foo, str], lambda: str | Foo, lambda: Foo | str, lambda: [Foo, str]], + ids=['Union[str, Foo]', 'Union[Foo, str]', 'str | Foo', 'Foo | str', '[Foo, str]'], ) def test_response_union_allow_str(input_union_callable: Callable[[], Any]): try: From a7fd8acd4fb73f2a30bd236b6480fdf863ca75f7 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 17:47:34 +0000 Subject: [PATCH 14/70] Set ToolRetryError as cause on UnexpectedModelBehavior when available --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index ae2a57aee6..bf83971742 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -64,12 +64,14 @@ class GraphAgentState: retries: int run_step: int - def increment_retries(self, max_result_retries: int) -> None: + def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: - raise exceptions.UnexpectedModelBehavior( - f'Exceeded maximum retries ({max_result_retries}) for result validation' - ) + message = f'Exceeded maximum retries ({max_result_retries}) for result validation' + if error: + raise exceptions.UnexpectedModelBehavior(message) from error + else: + raise exceptions.UnexpectedModelBehavior(message) @dataclasses.dataclass @@ -484,7 +486,7 @@ async def _handle_tool_calls( except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries) + ctx.state.increment_retries(ctx.deps.max_result_retries, e) parts.append(e.tool_retry) else: final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) @@ -545,7 +547,7 @@ async def _handle_text_response( result_data = await _validate_output(result_data, ctx, None) except _output.ToolRetryError as e: - ctx.state.increment_retries(ctx.deps.max_result_retries) + ctx.state.increment_retries(ctx.deps.max_result_retries, e) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: return self._handle_final_result(ctx, result.FinalResult(result_data, None, None), []) From 5399df297eb3a34e103135b5c1bd6c8ebff1b572 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 17:51:37 +0000 Subject: [PATCH 15/70] Drop end line from example test parameterized test ID to make it easier to rerun when the length of the example changes --- tests/test_examples.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_examples.py b/tests/test_examples.py index ff74ff80cb..468d42d169 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -72,8 +72,12 @@ def find_filter_examples() -> Iterable[ParameterSet]: for ex in find_examples('docs', 'pydantic_ai_slim', 'pydantic_graph', 'pydantic_evals'): if ex.path.name != '_utils.py': + try: + path = ex.path.relative_to(Path.cwd()) + except ValueError: + path = ex.path + test_id = f'{path}:{ex.start_line}' prefix_settings = ex.prefix_settings() - test_id = str(ex) if opt_title := prefix_settings.get('title'): test_id += f':{opt_title}' yield pytest.param(ex, id=test_id) From e503edb0784e2dbdd3d1c35f98df9f15ad7e4044 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 18:03:00 +0000 Subject: [PATCH 16/70] Document output functions --- docs/multi-agent-applications.md | 1 + docs/output.md | 148 ++++++++++++++++++++++++++++--- docs/tools.md | 4 +- tests/test_examples.py | 89 +++++++++++++++++++ 4 files changed, 230 insertions(+), 12 deletions(-) diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 33299a7766..7a2ce36f37 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -12,6 +12,7 @@ Of course, you can combine multiple strategies in a single application. ## Agent delegation "Agent delegation" refers to the scenario where an agent delegates work to another agent, then takes back control when the delegate agent (the agent called from within a tool) finishes. +If you want to hand off control to another agent completely, without coming back to the first agent, you can use an [output function](output.md#output-functions). Since agents are stateless and designed to be global, you do not need to include the agent itself in agent [dependencies](dependencies.md). diff --git a/docs/output.md b/docs/output.md index 09eeb27580..19e7aae5c2 100644 --- a/docs/output.md +++ b/docs/output.md @@ -1,4 +1,4 @@ -"Output" refers to the final value returned from [running an agent](agents.md#running-agents) these can be either plain text or structured data. +"Output" refers to the final value returned from [running an agent](agents.md#running-agents). This can be either plain text, [structured data](#structured-output), or the result of a [function](#output-functions) called with arguments provided by the model. The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) @@ -25,27 +25,29 @@ print(result.usage()) _(This example is complete, it can be run "as is")_ -Runs end when either a plain text response is received or the model calls a tool associated with one of the structured output types (run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits)). +A run ends when a plain text response is received (assuming no output type is specified or `str` is one of the allowed options), or when the model responds with one of the structured output types by calling a special output tool. A run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits). ## Output data {#structured-output} -When the output type is `str`, or a union including `str`, plain text responses are enabled on the model, and the raw text response from the model is used as the response data. +When no output type is specified, or when the output type is `str` or a union or list of types including `str`, the model is allowed to respond with plain text, and this text is used as the output data. +If `str` is not among the allowed output types, the model is not allowed to respond with plain text and is forced to use an output tool to return structured data. -If the output type is a union with multiple members (after removing `str` from the members), each member is registered as a separate tool with the model in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly. +If the output type is a union or list with multiple members, each member (except for `str`, if it is a member) is registered with the model as a separate output tool in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly. If the output type schema is not of type `"object"` (e.g. it's `int` or `list[int]`), the output type is wrapped in a single element object, so the schema of all tools registered with the model are object schemas. Structured outputs (like tools) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model. !!! note "Bring on PEP-747" - Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, unions are not valid as `type`s in Python. + Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, type checkers will not consider unions a valid value for `output_type`, even though PydanticAI supports them. - When creating the agent we need to `# type: ignore` the `output_type` argument, and add a type hint to tell type checkers about the type of the agent. + To work around this, we can use a list of types instead of a union, which is supported by type checkers. -Here's an example of returning either text or a structured value + Alternatively, we can add `# type: ignore` to the `output_type` argument when creating the agent, and add an explicit type hint to the agent's variable to inform the type checker. This is shown in the second example. + +Here's an example of returning either text or a structured value: ```python {title="box_or_error.py"} -from typing import Union from pydantic import BaseModel @@ -59,9 +61,9 @@ class Box(BaseModel): units: str -agent: Agent[None, Union[Box, str]] = Agent( +agent = Agent( 'openai:gpt-4o-mini', - output_type=Union[Box, str], # type: ignore + output_type=[Box, str], system_prompt=( "Extract me the dimensions of a box, " "if you can't extract all data, ask the user to try again." @@ -103,10 +105,134 @@ print(result.output) _(This example is complete, it can be run "as is")_ -### Output validator functions +### Output functions + +Instead of plain text or structured data, you may want the output of your agent run to be the result of a function called with arguments provided by the model, for example to further process or validate the data provided through the arguments (with the option to tell the model to try again), or to hand off to another agent. + +Output functions are similar to [function tools](tools.md), but the model is forced to call one of them, the call ends the agent run, and the result is not passed back to the model. + +As with tool functions, output function arguments provided by the model are validated using Pydantic, they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with updated arguments. + +To specify output functions, you set the agent's `output_type` to either a single function (or bound instance method), or a list of functions. The list can also contain other output types like simple scalars or entire Pydantic models. +You typically do not want to also register your output function as a tool (using the `@agent.tool` decorator or `tools` argument), as this could confuse the model about which it should be calling. + +Here's an example of all of these features in action: + +```python {title="output_functions.py"} +import re + +from pydantic import BaseModel + +from pydantic_ai import Agent, ModelRetry, RunContext +from pydantic_ai._output import ToolRetryError +from pydantic_ai.exceptions import UnexpectedModelBehavior + + +class Row(BaseModel): + name: str + country: str + + +tables = { + 'capital_cities': [ + Row(name='Amsterdam', country='Netherlands'), + Row(name='Mexico City', country='Mexico'), + ] +} + + +class SQLFailure(BaseModel): + """An unrecoverable failure. Only use this when you can't change the query to make it work.""" + + explanation: str + + +def run_sql_query(query: str) -> list[Row]: + """Run a SQL query on the database.""" + + select_table = re.match(r'SELECT (.+) FROM (\w+)', query) + if select_table: + column_names = select_table.group(1) + if column_names != '*': + raise ModelRetry("Only 'SELECT *' is supported, you'll have to do column filtering manually.") + + table_name = select_table.group(2) + if table_name not in tables: + raise ModelRetry( + f"Unknown table '{table_name}' in query '{query}'. Available tables: {', '.join(tables.keys())}." + ) + + return tables[table_name] + + raise ModelRetry(f"Unsupported query: '{query}'.") + + +sql_agent: Agent[None, list[Row] | SQLFailure] = Agent( + 'openai:gpt-4o', + output_type=[run_sql_query, SQLFailure], + instructions='You are a SQL agent that can run SQL queries on a database.', +) + + +async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]: + """I take natural language queries, turn them into SQL, and run them on a database.""" + + # Drop the final message with the output tool call, as it shouldn't be passed on to the SQL agent + messages = ctx.messages[:-1] + try: + result = await sql_agent.run(query, message_history=messages) + output = result.output + if isinstance(output, SQLFailure): + raise ModelRetry(f'SQL agent failed: {output.explanation}') + return output + except UnexpectedModelBehavior as e: + # Bubble up potentially retryable errors to the router agent + if (cause := e.__cause__) and isinstance(cause, ToolRetryError): + raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e + else: + raise + + +class RouterFailure(BaseModel): + """Use me when no appropriate agent is found or the used agent failed.""" + + explanation: str + + +router_agent: Agent[None, float | list[Row] | RouterFailure] = Agent( + 'openai:gpt-4o', + output_type=[hand_off_to_sql_agent, RouterFailure], + instructions='You are a router to other agents. Never try to solve a problem yourself, just pass it on.', +) + +result = router_agent.run_sync('Select the names and countries of all capitals') +print(result.output) +""" +[ + Row(name='Amsterdam', country='Netherlands'), + Row(name='Mexico City', country='Mexico'), +] +""" + +result = router_agent.run_sync('Select all pets') +print(result.output) +""" +explanation = "The requested table 'pets' does not exist in the database. The only available table is 'capital_cities', which does not contain data about pets." +""" + +result = router_agent.run_sync('How do I fly from Amsterdam to Mexico City?') +print(result.output) +""" +explanation = 'I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.' +""" +``` + +### Output validators Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. +If you want to implement separate validation logic for different output types, it's recommended to use [output functions](#output-functions) instead, to save you from having to do `isinstance` checks inside the output validator. + Here's a simplified variant of the [SQL Generation example](examples/sql-gen.md): ```python {title="sql_gen.py"} diff --git a/docs/tools.md b/docs/tools.md index 33bfdeb3a0..921edd95e1 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -2,7 +2,9 @@ Function tools provide a mechanism for models to retrieve extra information to help them generate a response. -They're useful when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. +They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. + +If you want a model to be able to call a function as its final action, without the result being sent back to the model, you can use an [output function](output.md#output-functions) instead. !!! info "Function tools vs. RAG" Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. diff --git a/tests/test_examples.py b/tests/test_examples.py index 468d42d169..ad377bedbf 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -405,6 +405,32 @@ async def list_tools() -> list[None]: args={'numerator': '123', 'denominator': '456'}, tool_call_id='pyd_ai_2e0e396768a14fe482df90a29a78dc7b', ), + 'Select the names and countries of all capitals': ToolCallPart( + tool_name='final_result_hand_off_to_sql_agent', + args={'query': 'SELECT name, country FROM capitals;'}, + ), + 'SELECT name, country FROM capitals;': ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT name, country FROM capitals;'}, + ), + 'SELECT * FROM capital_cities;': ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT * FROM capital_cities;'}, + ), + 'Select all pets': ToolCallPart( + tool_name='final_result_hand_off_to_sql_agent', + args={'query': 'SELECT * FROM pets;'}, + ), + 'SELECT * FROM pets;': ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT * FROM pets;'}, + ), + 'How do I fly from Amsterdam to Mexico City?': ToolCallPart( + tool_name='final_result_RouterFailure', + args={ + 'explanation': 'I am not equipped to provide travel information, such as flights from Amsterdam to Mexico City.' + }, + ), } tool_responses: dict[tuple[str, str], str] = { @@ -586,6 +612,69 @@ async def model_logic( # noqa: C901 return ModelResponse( parts=[ToolCallPart(tool_name='get_document', args={}, tool_call_id='pyd_ai_tool_call_id')] ) + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_run_sql_query' + and m.content == "Only 'SELECT *' is supported, you'll have to do column filtering manually." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_run_sql_query', + args={'query': 'SELECT * FROM capitals;'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_hand_off_to_sql_agent' + and m.content + == "SQL agent failed: Unknown table 'capitals' in query 'SELECT * FROM capitals;'. Available tables: capital_cities." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_hand_off_to_sql_agent', + args={'query': 'SELECT * FROM capital_cities;'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_run_sql_query' + and m.content == "Unknown table 'pets' in query 'SELECT * FROM pets;'. Available tables: capital_cities." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_SQLFailure', + args={ + 'explanation': "The table 'pets' does not exist in the database. Only the table 'capital_cities' is available." + }, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) + # SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available. + elif ( + isinstance(m, RetryPromptPart) + and m.tool_name == 'final_result_hand_off_to_sql_agent' + and m.content + == "SQL agent failed: The table 'pets' does not exist in the database. Only the table 'capital_cities' is available." + ): + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result_RouterFailure', + args={ + 'explanation': "The requested table 'pets' does not exist in the database. The only available table is 'capital_cities', which does not contain data about pets." + }, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') From f43679729a7b53eb13916469cd8e2e0bb339438d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 18:24:44 +0000 Subject: [PATCH 17/70] Fix docs --- docs/output.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/output.md b/docs/output.md index 19e7aae5c2..fd903a7e3b 100644 --- a/docs/output.md +++ b/docs/output.md @@ -120,6 +120,7 @@ Here's an example of all of these features in action: ```python {title="output_functions.py"} import re +from typing import Union from pydantic import BaseModel @@ -167,7 +168,7 @@ def run_sql_query(query: str) -> list[Row]: raise ModelRetry(f"Unsupported query: '{query}'.") -sql_agent: Agent[None, list[Row] | SQLFailure] = Agent( +sql_agent: Agent[None, Union[list[Row], SQLFailure]] = Agent( 'openai:gpt-4o', output_type=[run_sql_query, SQLFailure], instructions='You are a SQL agent that can run SQL queries on a database.', @@ -199,7 +200,7 @@ class RouterFailure(BaseModel): explanation: str -router_agent: Agent[None, float | list[Row] | RouterFailure] = Agent( +router_agent: Agent[None, Union[list[Row], RouterFailure]] = Agent( 'openai:gpt-4o', output_type=[hand_off_to_sql_agent, RouterFailure], instructions='You are a router to other agents. Never try to solve a problem yourself, just pass it on.', @@ -227,7 +228,7 @@ explanation = 'I am not equipped to provide travel information, such as flights """ ``` -### Output validators +### Output validators {#output-validator-functions} Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. From 046b81356db852750ea5d03715be61ab199b62e4 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 19:15:06 +0000 Subject: [PATCH 18/70] Update output_type typing tests --- tests/typed_agent.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 9885643282..eeaf8fbd3a 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -182,16 +182,22 @@ def my_method(self) -> bool: return True +str_function_agent = Agent(output_type=foobar_ctx) +assert_type(str_function_agent, Agent[None, str]) + +bool_method_agent = Agent(output_type=MyClass().my_method) +assert_type(bool_method_agent, Agent[None, bool]) + if MYPY: - # mypy requires the generic parameters to be specified explicitly to figure out what's going on here - str_function_agent = Agent[None, str](output_type=foobar_ctx) - assert_type(str_function_agent, Agent[None, str]) + # mypy requires the generic parameters to be specified explicitly to be happy here + async_int_function_agent = Agent[None, int](output_type=foobar_plain) + assert_type(async_int_function_agent, Agent[None, int]) - int_function_agent = Agent[None, int](output_type=foobar_plain) - assert_type(int_function_agent, Agent[None, int]) + two_models_output_agent = Agent[None, Foo | Bar](output_type=[Foo, Bar]) + assert_type(two_models_output_agent, Agent[None, Foo | Bar]) - bool_method_agent = Agent[None, bool](output_type=MyClass().my_method) - assert_type(bool_method_agent, Agent[None, bool]) + two_scalars_output_agent = Agent[None, int | bool](output_type=[int, bool]) + assert_type(two_scalars_output_agent, Agent[None, int | bool]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore complex_output_agent = Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]( @@ -199,15 +205,16 @@ def my_method(self) -> bool: ) assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) else: - # pyright is able to correctly infer the output type here - str_function_agent = Agent(output_type=foobar_ctx) - assert_type(str_function_agent, Agent[None, str]) + # pyright is able to correctly infer the type here + async_int_function_agent = Agent(output_type=foobar_plain) + assert_type(async_int_function_agent, Agent[None, int]) - int_function_agent = Agent(output_type=foobar_plain) - assert_type(int_function_agent, Agent[None, int]) + two_models_output_agent = Agent(output_type=[Foo, Bar]) + assert_type(two_models_output_agent, Agent[None, Foo | Bar]) - bool_method_agent = Agent(output_type=MyClass().my_method) - assert_type(bool_method_agent, Agent[None, bool]) + # this doesn't work in pyright without the generic parameters specified explicitly + two_scalars_output_agent = Agent[None, int | bool](output_type=[int, bool]) + assert_type(two_scalars_output_agent, Agent[None, int | bool]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker]) From dc39e68133cc3cfa9e0ab9782f519822f9061b33 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 20:16:44 +0000 Subject: [PATCH 19/70] Update output_type typing tests --- tests/typed_agent.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index eeaf8fbd3a..34eb3d1095 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -196,8 +196,8 @@ def my_method(self) -> bool: two_models_output_agent = Agent[None, Foo | Bar](output_type=[Foo, Bar]) assert_type(two_models_output_agent, Agent[None, Foo | Bar]) - two_scalars_output_agent = Agent[None, int | bool](output_type=[int, bool]) - assert_type(two_scalars_output_agent, Agent[None, int | bool]) + two_scalars_output_agent = Agent[None, int | str](output_type=[int, str]) + assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore complex_output_agent = Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]( @@ -212,9 +212,8 @@ def my_method(self) -> bool: two_models_output_agent = Agent(output_type=[Foo, Bar]) assert_type(two_models_output_agent, Agent[None, Foo | Bar]) - # this doesn't work in pyright without the generic parameters specified explicitly - two_scalars_output_agent = Agent[None, int | bool](output_type=[int, bool]) - assert_type(two_scalars_output_agent, Agent[None, int | bool]) + two_scalars_output_agent = Agent(output_type=[int, str]) + assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker]) From a57874598941d2b838712c5767400a88b72106b8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 23 May 2025 23:34:17 +0000 Subject: [PATCH 20/70] Document when automatic output_type type inference may fail --- docs/output.md | 37 ++++++++++++++---------- examples/pydantic_ai_examples/sql_gen.py | 2 +- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/docs/output.md b/docs/output.md index fd903a7e3b..7eac789c4d 100644 --- a/docs/output.md +++ b/docs/output.md @@ -1,9 +1,13 @@ "Output" refers to the final value returned from [running an agent](agents.md#running-agents). This can be either plain text, [structured data](#structured-output), or the result of a [function](#output-functions) called with arguments provided by the model. -The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so you can access other data like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results) +The output is wrapped in [`AgentRunResult`][pydantic_ai.agent.AgentRunResult] or [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] so that you can access other data, like [usage][pydantic_ai.usage.Usage] of the run and [message history](message-history.md#accessing-messages-from-results). Both `AgentRunResult` and `StreamedRunResult` are generic in the data they wrap, so typing information about the data returned by the agent is preserved. +A run ends when a plain text response is received (assuming no output type is specified or `str` is one of the allowed options), or when the model responds with one of the structured output types by calling a special output tool. A run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits). + +Here's an example using a Pydantic model as the `output_type`, forcing the model to respond with data matching our specification: + ```python {title="olympics.py" line_length="90"} from pydantic import BaseModel @@ -25,12 +29,12 @@ print(result.usage()) _(This example is complete, it can be run "as is")_ -A run ends when a plain text response is received (assuming no output type is specified or `str` is one of the allowed options), or when the model responds with one of the structured output types by calling a special output tool. A run can also be cancelled if usage limits are exceeded, see [Usage Limits](agents.md#usage-limits). - ## Output data {#structured-output} +The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argument that takes one or more types or [output functions](#output-functions). It supports both type unions and lists of types and functions. + When no output type is specified, or when the output type is `str` or a union or list of types including `str`, the model is allowed to respond with plain text, and this text is used as the output data. -If `str` is not among the allowed output types, the model is not allowed to respond with plain text and is forced to use an output tool to return structured data. +If `str` is not among the allowed output types, the model is not allowed to respond with plain text and is forced to return structured data (or arguments to an output function). If the output type is a union or list with multiple members, each member (except for `str`, if it is a member) is registered with the model as a separate output tool in order to reduce the complexity of the tool schemas and maximise the chances a model will respond correctly. @@ -38,14 +42,17 @@ If the output type schema is not of type `"object"` (e.g. it's `int` or `list[in Structured outputs (like tools) use Pydantic to build the JSON schema used for the tool, and to validate the data returned by the model. -!!! note "Bring on PEP-747" - Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands, type checkers will not consider unions a valid value for `output_type`, even though PydanticAI supports them. +!!! note "Type checking considerations" + The Agent class is generic in its output type, and this type is carried through to `AgentRunResult.output` and `StreamedRunResult.output` so that your IDE or static type checker can warn you when your code doesn't properly take into account all the possible values those outputs could have. - To work around this, we can use a list of types instead of a union, which is supported by type checkers. + Static type checkers like pyright and mypy will do their best the infer the agent's output type from the `output_type` you've specified, but they're not always able to do so correctly when you provide functions or multiple types in a union or list, even though PydanticAI will behave correctly. When this happens, your type checker will complain even when you're confident you've passed a valid `output_type`, and you'll need to help the type checker by explicitly specifying the generic parameters on the `Agent` constructor. This is shown in the second example below and the output functions example further down. - Alternatively, we can add `# type: ignore` to the `output_type` argument when creating the agent, and add an explicit type hint to the agent's variable to inform the type checker. This is shown in the second example. + Specifically, there are three valid uses of `output_type` where you'll need to do this: + 1. When using a union of types, e.g. `output_type=Foo | Bar` or in older Python, `output_type=Union[Foo, Bar]`. Until [PEP-747](https://peps.python.org/pep-0747/) "Annotating Type Forms" lands in Python 3.15, type checkers do not consider these a valid value for `output_type`. In addition to the generic parameters on the `Agent` constructor, you'll need to add `# type: ignore` to the line that passes the union to `output_type`. + 2. With mypy: When using a list, as a functionally equivalent alternative to a union, or because you're passing in [output functions](#output-functions). Pyright does handle this correctly, and we've filed [an issue](https://github.com/python/mypy/issues/19142) with mypy to try and get this fixed. + 3. With mypy: when using an async output function. Pyright does handle this correctly, and we've filed [an issue](https://github.com/python/mypy/issues/19143) with mypy to try and get this fixed. -Here's an example of returning either text or a structured value: +Here's an example of returning either text or structured data: ```python {title="box_or_error.py"} @@ -81,14 +88,14 @@ print(result.output) _(This example is complete, it can be run "as is")_ -Here's an example of using a union return type which registers multiple tools, and wraps non-object schemas in an object: +Here's an example of using a union return type, for which PydanticAI will register multiple tools and wraps non-object schemas in an object: ```python {title="colors_or_sizes.py"} from typing import Union from pydantic_ai import Agent -agent: Agent[None, Union[list[str], list[int]]] = Agent( +agent = Agent[None, Union[list[str], list[int]]]( 'openai:gpt-4o-mini', output_type=Union[list[str], list[int]], # type: ignore system_prompt='Extract either colors or sizes from the shapes provided.', @@ -111,7 +118,7 @@ Instead of plain text or structured data, you may want the output of your agent Output functions are similar to [function tools](tools.md), but the model is forced to call one of them, the call ends the agent run, and the result is not passed back to the model. -As with tool functions, output function arguments provided by the model are validated using Pydantic, they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with updated arguments. +As with tool functions, output function arguments provided by the model are validated using Pydantic, they can optionally take [`RunContext`][pydantic_ai.tools.RunContext] as the first argument, and they can raise [`ModelRetry`][pydantic_ai.exceptions.ModelRetry] to ask the model to try again with modified arguments (or with a different output type). To specify output functions, you set the agent's `output_type` to either a single function (or bound instance method), or a list of functions. The list can also contain other output types like simple scalars or entire Pydantic models. You typically do not want to also register your output function as a tool (using the `@agent.tool` decorator or `tools` argument), as this could confuse the model about which it should be calling. @@ -168,7 +175,7 @@ def run_sql_query(query: str) -> list[Row]: raise ModelRetry(f"Unsupported query: '{query}'.") -sql_agent: Agent[None, Union[list[Row], SQLFailure]] = Agent( +sql_agent = Agent[None, Union[list[Row], SQLFailure]]( 'openai:gpt-4o', output_type=[run_sql_query, SQLFailure], instructions='You are a SQL agent that can run SQL queries on a database.', @@ -200,7 +207,7 @@ class RouterFailure(BaseModel): explanation: str -router_agent: Agent[None, Union[list[Row], RouterFailure]] = Agent( +router_agent = Agent[None, Union[list[Row], RouterFailure]]( 'openai:gpt-4o', output_type=[hand_off_to_sql_agent, RouterFailure], instructions='You are a router to other agents. Never try to solve a problem yourself, just pass it on.', @@ -254,7 +261,7 @@ class InvalidRequest(BaseModel): Output = Union[Success, InvalidRequest] -agent: Agent[DatabaseConn, Output] = Agent( +agent = Agent[DatabaseConn, Output]( 'google-gla:gemini-1.5-flash', output_type=Output, # type: ignore deps_type=DatabaseConn, diff --git a/examples/pydantic_ai_examples/sql_gen.py b/examples/pydantic_ai_examples/sql_gen.py index 28b5459fb7..fdf8c5ff3d 100644 --- a/examples/pydantic_ai_examples/sql_gen.py +++ b/examples/pydantic_ai_examples/sql_gen.py @@ -92,7 +92,7 @@ class InvalidRequest(BaseModel): Response: TypeAlias = Union[Success, InvalidRequest] -agent: Agent[Deps, Response] = Agent( +agent = Agent[Deps, Response]( 'google-gla:gemini-1.5-flash', # Type ignore while we wait for PEP-0747, nonetheless unions will work fine everywhere else output_type=Response, # type: ignore From fc42d69c9b284a25250b7313f5f4563fe03ddc3d Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Mon, 26 May 2025 11:30:16 +0200 Subject: [PATCH 21/70] Suggested code for from_langchain This does work however the style and approach are very much open for discussion. In particular: * this creates a function just so it can be inspected, as part of that it has a very bad mapping from json schema types back to python. * the Tool class current has no base class that defines the behaviour. If it did then creating a separate class that doesn't need the round trip might be appropriate --- pydantic_ai_slim/pydantic_ai/tools.py | 89 +++++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 7d174ba7d6..b14bbeaacb 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -312,6 +312,95 @@ async def prep_my_tool( self._validator = f['validator'] self._base_parameters_json_schema = f['json_schema'] + @staticmethod + def from_langchain(langchain_tool: "langchain_core.tools.base.BaseTool") -> Tool[None]: + """ + Creates a Pydantic tool proxy from a LangChain tool. + + Args: + langchain_tool: The LangChain tool to wrap. + + Returns: + A Pydantic tool that corresponds to the LangChain tool. + """ + import inspect + + _JSON_SCHEMA_TO_PYTHON = { + "array": list, + "boolean": bool, + "null": type(None), + "number": float, + "object": dict, + "string": str, + } + + function_name = langchain_tool.name + function_description = langchain_tool.description + # inputs are like: + # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, + # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} + inputs = langchain_tool.args.copy() + defaults = { + name: detail["default"] + for name, detail in inputs.items() + if "default" in detail + } + # need to reorder the inputs so that the default ones are last + inputs = dict( + [(name, detail) for name, detail in inputs.items() if "default" not in detail] + + [(name, detail) for name, detail in inputs.items() if "default" in detail] + ) + output_type = str + + # restructures the arguments to match langchain tool run + def proxy(*args, **kwargs): + tool_input = kwargs.copy() + for argument, key in zip(args, inputs.keys()): + tool_input[key] = argument + for name, default_value in defaults.items(): + if name in kwargs: + continue + kwargs[name] = default_value + return langchain_tool.run(tool_input) + + proxy.__name__ = function_name + + # Generate the docstring for the proxy + input_descriptions = [ + f"{name} ({detail['type']}): {detail['description']}" + for name, detail in inputs.items() + ] + input_description_str = "\n ".join(input_descriptions) + args_section = f"Args:\n {input_description_str}" + docstring = f"{function_description}\n\n{args_section}" + proxy.__doc__ = docstring + + # Replace the proxy signature and annotations with the arguments from the tool + parameters = [ + inspect.Parameter( + name=parameter_name, + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=parameter_details.get("default", inspect.Parameter.empty), + annotation=_JSON_SCHEMA_TO_PYTHON[parameter_details["type"]], + ) + for parameter_name, parameter_details in inputs.items() + ] + signature = inspect.Signature(parameters=parameters, return_annotation=output_type) + proxy.__signature__ = signature + + annotations = { + parameter_name: _JSON_SCHEMA_TO_PYTHON[parameter_details["type"]] + for parameter_name, parameter_details in inputs.items() + } + proxy.__annotations__ = annotations + + return Tool( + proxy, + takes_ctx=False, + name=function_name, + description=docstring, + ) + async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. From 7f897b165e2eee1a0213617a26304bd40d509885 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Mon, 26 May 2025 13:44:17 +0200 Subject: [PATCH 22/70] Add optional group of langchain to allow type checking If the optional group is not installed then a protocol is used to define the tool type used. --- docs/install.md | 1 + pydantic_ai_slim/pydantic_ai/tools.py | 23 +++++++++++++++++++---- pydantic_ai_slim/pyproject.toml | 2 ++ 3 files changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/install.md b/docs/install.md index 6d621ada5f..f115334855 100644 --- a/docs/install.md +++ b/docs/install.md @@ -56,6 +56,7 @@ pip/uv-add "pydantic-ai-slim[openai]" * `cohere` - installs `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"} * `duckduckgo` - installs `duckduckgo-search` [PyPI ↗](https://pypi.org/project/duckduckgo-search){:target="_blank"} * `tavily` - installs `tavily-python` [PyPI ↗](https://pypi.org/project/tavily-python){:target="_blank"} +* `langchain` - installs `langchain-core` [PyPI ↗](https://pypi.org/project/langchain-core){:target="_blank"} See the [models](models/index.md) documentation for information on which optional dependencies are required for each model. diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index b14bbeaacb..7afd398e77 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -19,6 +19,23 @@ if TYPE_CHECKING: from .result import Usage + try: + from langchain_core.tools import BaseTool as LangChainTool # type: ignore + except ImportError: + from typing import Protocol + + class LangChainTool(Protocol): + @property + def args(self) -> dict[str, Any]: ... + + @property + def name(self) -> str: ... + + @property + def description(self) -> str: ... + + def run(self, *args, **kwargs) -> str: ... + __all__ = ( 'AgentDepsT', 'DocstringFormat', @@ -313,7 +330,7 @@ async def prep_my_tool( self._base_parameters_json_schema = f['json_schema'] @staticmethod - def from_langchain(langchain_tool: "langchain_core.tools.base.BaseTool") -> Tool[None]: + def from_langchain(langchain_tool: LangChainTool) -> Tool[None]: """ Creates a Pydantic tool proxy from a LangChain tool. @@ -323,8 +340,6 @@ def from_langchain(langchain_tool: "langchain_core.tools.base.BaseTool") -> Tool Returns: A Pydantic tool that corresponds to the LangChain tool. """ - import inspect - _JSON_SCHEMA_TO_PYTHON = { "array": list, "boolean": bool, @@ -386,7 +401,7 @@ def proxy(*args, **kwargs): for parameter_name, parameter_details in inputs.items() ] signature = inspect.Signature(parameters=parameters, return_annotation=output_type) - proxy.__signature__ = signature + proxy.__signature__ = signature # type: ignore annotations = { parameter_name: _JSON_SCHEMA_TO_PYTHON[parameter_details["type"]] diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 32c30eaa65..408e8d86ef 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -80,6 +80,8 @@ mcp = ["mcp>=1.6.0; python_version >= '3.10'"] evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a=={{ version }}"] +# LangChain Tools +langchain = ["langchain-core>=0.3.61"] [dependency-groups] dev = [ From e3b3c7fea03bafa58a91596df7f34b2375bc2464 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Mon, 26 May 2025 14:15:53 +0200 Subject: [PATCH 23/70] Fix linting problems There are a bunch of unstable pre-commit hooks that I can't satisfy. The format one in particular wants to change unrelated code. --- pydantic_ai_slim/pydantic_ai/tools.py | 59 ++++++++++++--------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 7afd398e77..76f9f20a65 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -20,13 +20,16 @@ from .result import Usage try: - from langchain_core.tools import BaseTool as LangChainTool # type: ignore + from langchain_core.tools import BaseTool as LangChainTool except ImportError: from typing import Protocol class LangChainTool(Protocol): + # args are like + # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, + # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} @property - def args(self) -> dict[str, Any]: ... + def args(self) -> dict[str, dict[str, str]]: ... @property def name(self) -> str: ... @@ -34,7 +37,8 @@ def name(self) -> str: ... @property def description(self) -> str: ... - def run(self, *args, **kwargs) -> str: ... + def run(self, *args: Any, **kwargs: Any) -> str: ... + __all__ = ( 'AgentDepsT', @@ -331,8 +335,7 @@ async def prep_my_tool( @staticmethod def from_langchain(langchain_tool: LangChainTool) -> Tool[None]: - """ - Creates a Pydantic tool proxy from a LangChain tool. + """Creates a Pydantic tool proxy from a LangChain tool. Args: langchain_tool: The LangChain tool to wrap. @@ -341,34 +344,27 @@ def from_langchain(langchain_tool: LangChainTool) -> Tool[None]: A Pydantic tool that corresponds to the LangChain tool. """ _JSON_SCHEMA_TO_PYTHON = { - "array": list, - "boolean": bool, - "null": type(None), - "number": float, - "object": dict, - "string": str, + 'array': list, + 'boolean': bool, + 'null': type(None), + 'number': float, + 'object': dict, + 'string': str, } function_name = langchain_tool.name function_description = langchain_tool.description - # inputs are like: - # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, - # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} inputs = langchain_tool.args.copy() - defaults = { - name: detail["default"] - for name, detail in inputs.items() - if "default" in detail - } + defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail} # need to reorder the inputs so that the default ones are last inputs = dict( - [(name, detail) for name, detail in inputs.items() if "default" not in detail] - + [(name, detail) for name, detail in inputs.items() if "default" in detail] + [(name, detail) for name, detail in inputs.items() if 'default' not in detail] + + [(name, detail) for name, detail in inputs.items() if 'default' in detail] ) output_type = str # restructures the arguments to match langchain tool run - def proxy(*args, **kwargs): + def proxy(*args: Any, **kwargs: Any) -> str: tool_input = kwargs.copy() for argument, key in zip(args, inputs.keys()): tool_input[key] = argument @@ -381,13 +377,10 @@ def proxy(*args, **kwargs): proxy.__name__ = function_name # Generate the docstring for the proxy - input_descriptions = [ - f"{name} ({detail['type']}): {detail['description']}" - for name, detail in inputs.items() - ] - input_description_str = "\n ".join(input_descriptions) - args_section = f"Args:\n {input_description_str}" - docstring = f"{function_description}\n\n{args_section}" + input_descriptions = [f'{name} ({detail["type"]}): {detail["description"]}' for name, detail in inputs.items()] + input_description_str = '\n '.join(input_descriptions) + args_section = f'Args:\n {input_description_str}' + docstring = f'{function_description}\n\n{args_section}' proxy.__doc__ = docstring # Replace the proxy signature and annotations with the arguments from the tool @@ -395,16 +388,16 @@ def proxy(*args, **kwargs): inspect.Parameter( name=parameter_name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=parameter_details.get("default", inspect.Parameter.empty), - annotation=_JSON_SCHEMA_TO_PYTHON[parameter_details["type"]], + default=parameter_details.get('default', inspect.Parameter.empty), + annotation=_JSON_SCHEMA_TO_PYTHON[parameter_details['type']], ) for parameter_name, parameter_details in inputs.items() ] signature = inspect.Signature(parameters=parameters, return_annotation=output_type) - proxy.__signature__ = signature # type: ignore + proxy.__signature__ = signature # type: ignore annotations = { - parameter_name: _JSON_SCHEMA_TO_PYTHON[parameter_details["type"]] + parameter_name: _JSON_SCHEMA_TO_PYTHON[parameter_details['type']] for parameter_name, parameter_details in inputs.items() } proxy.__annotations__ = annotations From 8597ae4cbfebac14f44a344a0717af6e5fbfe6c5 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Mon, 26 May 2025 14:36:56 +0200 Subject: [PATCH 24/70] Add a test for tool conversion --- tests/test_tools.py | 54 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_tools.py b/tests/test_tools.py index ec2d9cfaf4..9b8f57d203 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1048,3 +1048,57 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: result = agent.run_sync('', deps=1) assert result.output == snapshot('{"foobar":"1 0 a"}') + + +def test_langchain_tool_conversion(): + call_args: list[int] = [] + + @dataclass + class SimulatedLangChainTool: + name: str + description: str + args: dict[str, dict[str, str]] + + def run( + self, + tool_input: Union[str, dict[str, Any]], + verbose: bool | None = None, + start_color: str | None = 'green', + color: str | None = 'green', + callbacks: Any = None, + *, + tags: list[str] | None = None, + metadata: dict[str, Any] | None = None, + run_name: str | None = None, + run_id: Any | None = None, + config: Any | None = None, + tool_call_id: str | None = None, + **kwargs: Any, + ) -> Any: + return 'I was called' + + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot('{"file_search":"I was called"}') + assert call_args == snapshot([]) + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 From 2a92ba0e54913023daaa92a21913246191db7950 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Mon, 26 May 2025 14:37:47 +0200 Subject: [PATCH 25/70] Drop attempted import of langchain tool Just makes the checks unhappy --- pydantic_ai_slim/pydantic_ai/tools.py | 29 ++++++++++++--------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 76f9f20a65..a6fe603e15 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -17,27 +17,24 @@ from .exceptions import ModelRetry, UnexpectedModelBehavior if TYPE_CHECKING: - from .result import Usage + from typing import Protocol - try: - from langchain_core.tools import BaseTool as LangChainTool - except ImportError: - from typing import Protocol + from .result import Usage - class LangChainTool(Protocol): - # args are like - # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, - # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} - @property - def args(self) -> dict[str, dict[str, str]]: ... + class LangChainTool(Protocol): + # args are like + # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, + # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} + @property + def args(self) -> dict[str, dict[str, str]]: ... - @property - def name(self) -> str: ... + @property + def name(self) -> str: ... - @property - def description(self) -> str: ... + @property + def description(self) -> str: ... - def run(self, *args: Any, **kwargs: Any) -> str: ... + def run(self, *args: Any, **kwargs: Any) -> str: ... __all__ = ( From 2d22c3242410564c4db52e992e5f01dd9e19d38e Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Mon, 26 May 2025 14:46:06 +0200 Subject: [PATCH 26/70] Older pythons don't enjoy type unions with | --- tests/test_tools.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 9b8f57d203..570e806706 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1062,17 +1062,17 @@ class SimulatedLangChainTool: def run( self, tool_input: Union[str, dict[str, Any]], - verbose: bool | None = None, - start_color: str | None = 'green', - color: str | None = 'green', + verbose: Union[bool, None] = None, + start_color: Union[str, None] = 'green', + color: Union[str, None] = 'green', callbacks: Any = None, *, - tags: list[str] | None = None, - metadata: dict[str, Any] | None = None, - run_name: str | None = None, - run_id: Any | None = None, - config: Any | None = None, - tool_call_id: str | None = None, + tags: Union[list[str], None] = None, + metadata: Union[dict[str, Any], None] = None, + run_name: Union[str, None] = None, + run_id: Union[Any, None] = None, + config: Union[Any, None] = None, + tool_call_id: Union[str, None] = None, **kwargs: Any, ) -> Any: return 'I was called' From 2d6cc1adad1ad6accbf24f1fbc7954b050982231 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 26 May 2025 13:09:46 +0000 Subject: [PATCH 27/70] Drop unnecessary explicit generic parameter on constructor in typing test --- tests/typed_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 34eb3d1095..180ce2b0dc 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -216,7 +216,7 @@ def my_method(self) -> bool: assert_type(two_scalars_output_agent, Agent[None, int | str]) marker: ToolOutput[bool | tuple[str, int]] = ToolOutput(bool | tuple[str, int]) # type: ignore - complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput[int](foobar_plain), marker]) + complex_output_agent = Agent(output_type=[Foo, Bar, foobar_ctx, ToolOutput(foobar_plain), marker]) assert_type(complex_output_agent, Agent[None, Foo | Bar | str | int | bool | tuple[str, int]]) From 6ef816ab6b089ee09a1b9cc360354524420e013e Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 21:16:37 +0200 Subject: [PATCH 28/70] hack in the new FunctionSchema approach --- pydantic_ai_slim/pydantic_ai/tools.py | 79 ++++++++++++--------------- 1 file changed, 35 insertions(+), 44 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 1a29350dce..7d9b27505f 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,7 +1,6 @@ from __future__ import annotations as _annotations import dataclasses -import inspect import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field @@ -335,25 +334,19 @@ def from_langchain(langchain_tool: LangChainTool) -> Tool[None]: Returns: A Pydantic tool that corresponds to the LangChain tool. """ - _JSON_SCHEMA_TO_PYTHON = { - 'array': list, - 'boolean': bool, - 'null': type(None), - 'number': float, - 'object': dict, - 'string': str, - } - function_name = langchain_tool.name function_description = langchain_tool.description inputs = langchain_tool.args.copy() + required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) + schema = { + 'type': 'object', + 'properties': inputs, + 'additionalProperties': False, + } + if required: + schema['required'] = required + defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail} - # need to reorder the inputs so that the default ones are last - inputs = dict( - [(name, detail) for name, detail in inputs.items() if 'default' not in detail] - + [(name, detail) for name, detail in inputs.items() if 'default' in detail] - ) - output_type = str # restructures the arguments to match langchain tool run def proxy(*args: Any, **kwargs: Any) -> str: @@ -366,39 +359,37 @@ def proxy(*args: Any, **kwargs: Any) -> str: kwargs[name] = default_value return langchain_tool.run(tool_input) - proxy.__name__ = function_name - - # Generate the docstring for the proxy - input_descriptions = [f'{name} ({detail["type"]}): {detail["description"]}' for name, detail in inputs.items()] - input_description_str = '\n '.join(input_descriptions) - args_section = f'Args:\n {input_description_str}' - docstring = f'{function_description}\n\n{args_section}' - proxy.__doc__ = docstring - - # Replace the proxy signature and annotations with the arguments from the tool - parameters = [ - inspect.Parameter( - name=parameter_name, - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=parameter_details.get('default', inspect.Parameter.empty), - annotation=_JSON_SCHEMA_TO_PYTHON[parameter_details['type']], - ) - for parameter_name, parameter_details in inputs.items() - ] - signature = inspect.Signature(parameters=parameters, return_annotation=output_type) - proxy.__signature__ = signature # type: ignore - - annotations = { - parameter_name: _JSON_SCHEMA_TO_PYTHON[parameter_details['type']] - for parameter_name, parameter_details in inputs.items() - } - proxy.__annotations__ = annotations + class AnySchemaValidator: + def validate_python( + self, + input: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: Any | None = None, + self_instance: Any | None = None, + allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, + ) -> Any: + return input + + function_schema = _function_schema.FunctionSchema( + function=proxy, + description=function_description, + validator=AnySchemaValidator(), + json_schema=schema, + takes_ctx=False, + is_async=False, + single_arg_name=None, + positional_fields=[], + var_positional_field=None, + ) return Tool( proxy, takes_ctx=False, name=function_name, - description=docstring, + description=function_description, + function_schema=function_schema, ) async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: From aefad98beaeb62ce060a6463ea30b42ab23106ab Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 21:28:44 +0200 Subject: [PATCH 29/70] split into from_function and from_langchain --- pydantic_ai_slim/pydantic_ai/tools.py | 81 +++++++++++++++++---------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 7d9b27505f..c75f45bd45 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -325,39 +325,18 @@ async def prep_my_tool( self.strict = strict @staticmethod - def from_langchain(langchain_tool: LangChainTool) -> Tool[None]: - """Creates a Pydantic tool proxy from a LangChain tool. + def from_function(function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[None]: + """Creates a Pydantic tool from a function and a JSON schema. Args: - langchain_tool: The LangChain tool to wrap. + function: The function to call + json_schema: The schema for the function arguments Returns: - A Pydantic tool that corresponds to the LangChain tool. + A Pydantic tool that calls the function """ - function_name = langchain_tool.name - function_description = langchain_tool.description - inputs = langchain_tool.args.copy() - required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) - schema = { - 'type': 'object', - 'properties': inputs, - 'additionalProperties': False, - } - if required: - schema['required'] = required - - defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail} - - # restructures the arguments to match langchain tool run - def proxy(*args: Any, **kwargs: Any) -> str: - tool_input = kwargs.copy() - for argument, key in zip(args, inputs.keys()): - tool_input[key] = argument - for name, default_value in defaults.items(): - if name in kwargs: - continue - kwargs[name] = default_value - return langchain_tool.run(tool_input) + function_name = function.__name__ + function_description = function.__doc__ or '' class AnySchemaValidator: def validate_python( @@ -373,10 +352,10 @@ def validate_python( return input function_schema = _function_schema.FunctionSchema( - function=proxy, + function=function, description=function_description, validator=AnySchemaValidator(), - json_schema=schema, + json_schema=json_schema, takes_ctx=False, is_async=False, single_arg_name=None, @@ -385,13 +364,53 @@ def validate_python( ) return Tool( - proxy, + function, takes_ctx=False, name=function_name, description=function_description, function_schema=function_schema, ) + @classmethod + def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[None]: + """Creates a Pydantic tool proxy from a LangChain tool. + + Args: + langchain_tool: The LangChain tool to wrap. + + Returns: + A Pydantic tool that corresponds to the LangChain tool. + """ + function_name = langchain_tool.name + function_description = langchain_tool.description + inputs = langchain_tool.args.copy() + required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) + schema = { + 'type': 'object', + 'properties': inputs, + 'additionalProperties': False, + } + if required: + schema['required'] = required + + defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail} + + # restructures the arguments to match langchain tool run + def proxy(*args: Any, **kwargs: Any) -> str: + tool_input = kwargs.copy() + for argument, key in zip(args, inputs.keys()): + tool_input[key] = argument + for name, default_value in defaults.items(): + if name in kwargs: + continue + kwargs[name] = default_value + return langchain_tool.run(tool_input) + + proxy.__name__ = function_name + proxy.__doc__ = function_description + + return cls.from_function(function=proxy, json_schema=schema) + async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. From 1614418e1e1a4cfa516b27dee1690fb11eec13fa Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 21:36:53 +0200 Subject: [PATCH 30/70] Use the proper schema validator --- pydantic_ai_slim/pydantic_ai/tools.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index c75f45bd45..7c160a93e4 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -9,7 +9,7 @@ from opentelemetry.trace import Tracer from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue -from pydantic_core import core_schema +from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar from . import _function_schema, _utils, messages as _messages @@ -338,23 +338,10 @@ def from_function(function: Callable[..., Any], json_schema: JsonSchemaValue) -> function_name = function.__name__ function_description = function.__doc__ or '' - class AnySchemaValidator: - def validate_python( - self, - input: Any, - *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: Any | None = None, - self_instance: Any | None = None, - allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, - ) -> Any: - return input - function_schema = _function_schema.FunctionSchema( function=function, description=function_description, - validator=AnySchemaValidator(), + validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, is_async=False, From 5269059f3059781187c62ceb4ace9539b79496dc Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 21:54:58 +0200 Subject: [PATCH 31/70] test the function conversion directly --- tests/test_tools.py | 47 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 570e806706..d77d1cd148 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1050,9 +1050,51 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: assert result.output == snapshot('{"foobar":"1 0 a"}') -def test_langchain_tool_conversion(): - call_args: list[int] = [] +def test_function_tool_consistent_with_schema(): + def function(*args: Any, **kwargs: Any) -> str: + assert len(args) == 0 + assert set(kwargs) == {'one', 'two'} + return 'I like being called like this' + + json_schema = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'one': {'description': 'first argument', 'type': 'string'}, + 'two': {'description': 'second argument', 'type': 'object'}, + }, + 'required': ['one', 'two'], + } + pydantic_tool = Tool.from_function(function, json_schema=json_schema) + + agent = Agent('test', tools=[pydantic_tool], retries=0) + result = agent.run_sync('foobar') + assert result.output == snapshot('{"function":"I like being called like this"}') + assert agent._function_tools['function'].takes_ctx is False + assert agent._function_tools['function'].max_retries == 0 + +def test_function_tool_inconsistent_with_schema(): + def function(three: str, four: int) -> str: + return 'How did you even manage this?' + + json_schema = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'one': {'description': 'first argument', 'type': 'string'}, + 'two': {'description': 'second argument', 'type': 'object'}, + }, + 'required': ['one', 'two'], + } + pydantic_tool = Tool.from_function(function, json_schema=json_schema) + + agent = Agent('test', tools=[pydantic_tool], retries=0) + with pytest.raises(TypeError, match=".* got an unexpected keyword argument 'one'"): + agent.run_sync('foobar') + + +def test_langchain_tool_conversion(): @dataclass class SimulatedLangChainTool: name: str @@ -1099,6 +1141,5 @@ def run( agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') assert result.output == snapshot('{"file_search":"I was called"}') - assert call_args == snapshot([]) assert agent._function_tools['file_search'].takes_ctx is False assert agent._function_tools['file_search'].max_retries == 7 From b7478a3d0715866a2f54db783af27a325f99c021 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 22:02:34 +0200 Subject: [PATCH 32/70] Support async functions for from_function Should think about the langchain version - problem is there is a default proxy method for arun --- pydantic_ai_slim/pydantic_ai/tools.py | 3 ++- tests/test_tools.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 7c160a93e4..c08f710bf2 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import asyncio import dataclasses import json from collections.abc import Awaitable, Sequence @@ -344,7 +345,7 @@ def from_function(function: Callable[..., Any], json_schema: JsonSchemaValue) -> validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, - is_async=False, + is_async=asyncio.iscoroutinefunction(function), single_arg_name=None, positional_fields=[], var_positional_field=None, diff --git a/tests/test_tools.py b/tests/test_tools.py index d77d1cd148..c8250d38e5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1094,6 +1094,30 @@ def function(three: str, four: int) -> str: agent.run_sync('foobar') +def test_async_function_tool_consistent_with_schema(): + async def function(*args: Any, **kwargs: Any) -> str: + assert len(args) == 0 + assert set(kwargs) == {'one', 'two'} + return 'I like being called like this' + + json_schema = { + 'type': 'object', + 'additionalProperties': False, + 'properties': { + 'one': {'description': 'first argument', 'type': 'string'}, + 'two': {'description': 'second argument', 'type': 'object'}, + }, + 'required': ['one', 'two'], + } + pydantic_tool = Tool.from_function(function, json_schema=json_schema) + + agent = Agent('test', tools=[pydantic_tool], retries=0) + result = agent.run_sync('foobar') + assert result.output == snapshot('{"function":"I like being called like this"}') + assert agent._function_tools['function'].takes_ctx is False + assert agent._function_tools['function'].max_retries == 0 + + def test_langchain_tool_conversion(): @dataclass class SimulatedLangChainTool: From 880c629ca8adce7590b338a0f526bf00d3c18617 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 22:07:21 +0200 Subject: [PATCH 33/70] Fix the type of the schema --- pydantic_ai_slim/pydantic_ai/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index c08f710bf2..df1187165c 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -373,7 +373,7 @@ def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[None]: function_description = langchain_tool.description inputs = langchain_tool.args.copy() required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) - schema = { + schema: JsonSchemaValue = { 'type': 'object', 'properties': inputs, 'additionalProperties': False, From 8407f3b1d3be3ca8f17a332167412cd7d14c6def Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 22:24:39 +0200 Subject: [PATCH 34/70] test the langchain function directly and find a bug! oh no! --- pydantic_ai_slim/pydantic_ai/tools.py | 5 +- tests/test_tools.py | 102 +++++++++++++++++++------- 2 files changed, 77 insertions(+), 30 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index df1187165c..5ccacf0290 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -385,14 +385,13 @@ def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[None]: # restructures the arguments to match langchain tool run def proxy(*args: Any, **kwargs: Any) -> str: - tool_input = kwargs.copy() for argument, key in zip(args, inputs.keys()): - tool_input[key] = argument + kwargs[key] = argument for name, default_value in defaults.items(): if name in kwargs: continue kwargs[name] = default_value - return langchain_tool.run(tool_input) + return langchain_tool.run(kwargs) proxy.__name__ = function_name proxy.__doc__ = function_description diff --git a/tests/test_tools.py b/tests/test_tools.py index c8250d38e5..7e61113f4a 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1075,8 +1075,7 @@ def function(*args: Any, **kwargs: Any) -> str: def test_function_tool_inconsistent_with_schema(): - def function(three: str, four: int) -> str: - return 'How did you even manage this?' + def function(three: str, four: int) -> None: ... json_schema = { 'type': 'object', @@ -1118,31 +1117,32 @@ async def function(*args: Any, **kwargs: Any) -> str: assert agent._function_tools['function'].max_retries == 0 -def test_langchain_tool_conversion(): - @dataclass - class SimulatedLangChainTool: - name: str - description: str - args: dict[str, dict[str, str]] - - def run( - self, - tool_input: Union[str, dict[str, Any]], - verbose: Union[bool, None] = None, - start_color: Union[str, None] = 'green', - color: Union[str, None] = 'green', - callbacks: Any = None, - *, - tags: Union[list[str], None] = None, - metadata: Union[dict[str, Any], None] = None, - run_name: Union[str, None] = None, - run_id: Union[Any, None] = None, - config: Union[Any, None] = None, - tool_call_id: Union[str, None] = None, - **kwargs: Any, - ) -> Any: - return 'I was called' +@dataclass +class SimulatedLangChainTool: + name: str + description: str + args: dict[str, dict[str, str]] + + def run( + self, + tool_input: Union[str, dict[str, Any]], + verbose: Union[bool, None] = None, + start_color: Union[str, None] = 'green', + color: Union[str, None] = 'green', + callbacks: Any = None, + *, + tags: Union[list[str], None] = None, + metadata: Union[dict[str, Any], None] = None, + run_name: Union[str, None] = None, + run_id: Union[Any, None] = None, + config: Union[Any, None] = None, + tool_call_id: Union[str, None] = None, + **kwargs: Any, + ) -> Any: + return f'I was called with {tool_input}' + +def test_langchain_tool_conversion(): langchain_tool = SimulatedLangChainTool( name='file_search', description='Recursively search for files in a subdirectory that match the regex pattern', @@ -1164,6 +1164,54 @@ def run( agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') - assert result.output == snapshot('{"file_search":"I was called"}') + assert result.output == snapshot("{\"file_search\":\"I was called with {'pattern': 'a', 'dir_path': '.'}\"}") assert agent._function_tools['file_search'].takes_ctx is False assert agent._function_tools['file_search'].max_retries == 7 + + +def test_langchain_tool_defaults(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + result = pydantic_tool.function(pattern='something') + assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") + + +def test_langchain_tool_positional(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + }, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + result = pydantic_tool.function('something') + assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") From ba9023775745ea23dd12b4674a51205b54a2c70a Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 22:28:00 +0200 Subject: [PATCH 35/70] type: ignore the funny calls the signature generated for the function is quite odd --- tests/test_tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 7e61113f4a..b3b18a0cc9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1189,7 +1189,7 @@ def test_langchain_tool_defaults(): ) pydantic_tool = Tool.from_langchain(langchain_tool) - result = pydantic_tool.function(pattern='something') + result = pydantic_tool.function(pattern='something') # type: ignore assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") @@ -1213,5 +1213,5 @@ def test_langchain_tool_positional(): ) pydantic_tool = Tool.from_langchain(langchain_tool) - result = pydantic_tool.function('something') + result = pydantic_tool.function('something') # type: ignore assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") From 6351ffaa01ee48e6acbd894780576de8379281df Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 22:46:29 +0200 Subject: [PATCH 36/70] call the broken function to get coverage --- tests/test_tools.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index b3b18a0cc9..5f1b943c4c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1075,7 +1075,8 @@ def function(*args: Any, **kwargs: Any) -> str: def test_function_tool_inconsistent_with_schema(): - def function(three: str, four: int) -> None: ... + def function(three: str, four: int) -> str: + return 'Coverage made me call this' json_schema = { 'type': 'object', @@ -1092,6 +1093,9 @@ def function(three: str, four: int) -> None: ... with pytest.raises(TypeError, match=".* got an unexpected keyword argument 'one'"): agent.run_sync('foobar') + result = function('three', 4) + assert result == 'Coverage made me call this' + def test_async_function_tool_consistent_with_schema(): async def function(*args: Any, **kwargs: Any) -> str: From 40b5a5ccccc74aa2d805f1425f998916a1141de7 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 22:56:06 +0200 Subject: [PATCH 37/70] branch coverage for required/default --- tests/test_tools.py | 54 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/tests/test_tools.py b/tests/test_tools.py index 5f1b943c4c..ce4e42f115 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1173,6 +1173,60 @@ def test_langchain_tool_conversion(): assert agent._function_tools['file_search'].max_retries == 7 +def test_langchain_tool_conversion_no_defaults(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': 'a', 'pattern': 'a'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + +def test_langchain_tool_conversion_no_required(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'default': '*', + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': '*'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + def test_langchain_tool_defaults(): langchain_tool = SimulatedLangChainTool( name='file_search', From 25df8774b79e5a397a23eb1e32b723aecddbf7a3 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 27 May 2025 23:06:19 +0200 Subject: [PATCH 38/70] test that the default is overridden --- tests/test_tools.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_tools.py b/tests/test_tools.py index ce4e42f115..354a6235b8 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1273,3 +1273,27 @@ def test_langchain_tool_positional(): result = pydantic_tool.function('something') # type: ignore assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") + + +def test_langchain_tool_default_override(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + result = pydantic_tool.function(pattern='something', dir_path='somewhere') # type: ignore + assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': 'somewhere'}") From 8121577e54e6402ac67b5044be0810ddda2c94b3 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Wed, 28 May 2025 16:21:32 +0200 Subject: [PATCH 39/70] Change name of function --- pydantic_ai_slim/pydantic_ai/tools.py | 4 +- tests/test_tools.py | 62 +++++++++++++++++++++------ 2 files changed, 51 insertions(+), 15 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 5ccacf0290..0980f74b8d 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -326,7 +326,7 @@ async def prep_my_tool( self.strict = strict @staticmethod - def from_function(function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[None]: + def from_schema(function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[None]: """Creates a Pydantic tool from a function and a JSON schema. Args: @@ -396,7 +396,7 @@ def proxy(*args: Any, **kwargs: Any) -> str: proxy.__name__ = function_name proxy.__doc__ = function_description - return cls.from_function(function=proxy, json_schema=schema) + return cls.from_schema(function=proxy, json_schema=schema) async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. diff --git a/tests/test_tools.py b/tests/test_tools.py index 354a6235b8..638270d95f 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,7 +12,14 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext, Tool, ToolOutput, UserError -from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturnPart, +) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition @@ -326,7 +333,11 @@ def test_only_returns_type(): The result as a string. \ """, - 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, + 'parameters_json_schema': { + 'additionalProperties': False, + 'properties': {}, + 'type': 'object', + }, 'outer_typed_dict_key': None, 'strict': None, } @@ -502,7 +513,12 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int: # pyright: reportPrivateUsage=false def test_init_tool_ctx(): - agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7) + agent = Agent( + 'test', + tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], + deps_type=int, + retries=7, + ) result = agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') assert agent._function_tools['ctx_tool'].takes_ctx is True @@ -572,7 +588,12 @@ def test_tool_return_conflict(): Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) # this raises an error with pytest.raises(UserError, match="Tool name conflicts with result schema name: 'ctx_tool'"): - Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(type_=int, name='ctx_tool')) + Agent( + 'test', + tools=[ctx_tool], + deps_type=int, + output_type=ToolOutput(type_=int, name='ctx_tool'), + ) def test_init_ctx_tool_invalid(): @@ -585,7 +606,10 @@ def plain_tool(x: int) -> int: # pragma: no cover def test_init_plain_tool_invalid(): - with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'): + with pytest.raises( + UserError, + match='RunContext annotations can only be used with tools that take context', + ): Tool(ctx_tool, takes_ctx=False) @@ -632,7 +656,10 @@ def test_return_bytes_invalid(): def return_pydantic_model() -> bytes: return b'\00 \x81' - with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'): + with pytest.raises( + PydanticSerializationError, + match='invalid utf-8 sequence of 1 bytes from index 2', + ): agent.run_sync('') @@ -767,7 +794,8 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: def test_future_run_context(create_module: Callable[[str], Any]): - mod = create_module(""" + mod = create_module( + """ from __future__ import annotations from pydantic_ai import Agent, RunContext @@ -776,7 +804,8 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int: return x + ctx.deps agent = Agent('test', tools=[ctx_tool], deps_type=int) - """) + """ + ) result = mod.agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') @@ -800,7 +829,11 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture): 'description': "A tool that documents what it returns but doesn't have a return annotation in the docstring.", 'name': 'tool_without_return_annotation_in_docstring', 'outer_typed_dict_key': None, - 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, + 'parameters_json_schema': { + 'additionalProperties': False, + 'properties': {}, + 'type': 'object', + }, 'strict': None, } ) @@ -948,7 +981,10 @@ def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaVa agent = Agent(FunctionModel(get_json_schema)) - def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = None, **kwargs: Any): + def my_tool( + x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = None, + **kwargs: Any, + ): return x # pragma: no cover agent.tool_plain(name='my_tool_1')(my_tool) @@ -1065,7 +1101,7 @@ def function(*args: Any, **kwargs: Any) -> str: }, 'required': ['one', 'two'], } - pydantic_tool = Tool.from_function(function, json_schema=json_schema) + pydantic_tool = Tool.from_schema(function, json_schema=json_schema) agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') @@ -1087,7 +1123,7 @@ def function(three: str, four: int) -> str: }, 'required': ['one', 'two'], } - pydantic_tool = Tool.from_function(function, json_schema=json_schema) + pydantic_tool = Tool.from_schema(function, json_schema=json_schema) agent = Agent('test', tools=[pydantic_tool], retries=0) with pytest.raises(TypeError, match=".* got an unexpected keyword argument 'one'"): @@ -1112,7 +1148,7 @@ async def function(*args: Any, **kwargs: Any) -> str: }, 'required': ['one', 'two'], } - pydantic_tool = Tool.from_function(function, json_schema=json_schema) + pydantic_tool = Tool.from_schema(function, json_schema=json_schema) agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') From ffb78eb2b6029a5a82c1b5a184f26b99b6b6cb3e Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Wed, 28 May 2025 17:33:36 +0200 Subject: [PATCH 40/70] fix import, address PR feedback --- pydantic_ai_slim/pydantic_ai/tools.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 296c47d904..aa0e6f3918 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -5,12 +5,12 @@ import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, Self from opentelemetry.trace import Tracer from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue -from pydantic_core import core_schema +from pydantic_core import core_schema, SchemaValidator from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar from . import _function_schema, _utils, messages as _messages @@ -325,8 +325,8 @@ async def prep_my_tool( self.require_parameter_descriptions = require_parameter_descriptions self.strict = strict - @staticmethod - def from_schema(function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[None]: + @classmethod + def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Self: """Creates a Pydantic tool from a function and a JSON schema. Args: @@ -351,7 +351,7 @@ def from_schema(function: Callable[..., Any], json_schema: JsonSchemaValue) -> T var_positional_field=None, ) - return Tool( + return cls( function, takes_ctx=False, name=function_name, @@ -360,7 +360,7 @@ def from_schema(function: Callable[..., Any], json_schema: JsonSchemaValue) -> T ) @classmethod - def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[None]: + def from_langchain(cls, langchain_tool: LangChainTool) -> Self: """Creates a Pydantic tool proxy from a LangChain tool. Args: From 8024a8061938f8a75cb51fa0e362c77d157d890a Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Wed, 28 May 2025 17:39:51 +0200 Subject: [PATCH 41/70] Self was introduced in python 3.11 --- pydantic_ai_slim/pydantic_ai/tools.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index aa0e6f3918..0fca46d628 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -5,12 +5,12 @@ import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, Self +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union from opentelemetry.trace import Tracer from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue -from pydantic_core import core_schema, SchemaValidator +from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar from . import _function_schema, _utils, messages as _messages @@ -326,7 +326,7 @@ async def prep_my_tool( self.strict = strict @classmethod - def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Self: + def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[None]: """Creates a Pydantic tool from a function and a JSON schema. Args: @@ -360,7 +360,7 @@ def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) ) @classmethod - def from_langchain(cls, langchain_tool: LangChainTool) -> Self: + def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[None]: """Creates a Pydantic tool proxy from a LangChain tool. Args: From 98ffd4e70d89f3f996b9374d260d555a2b84f4b7 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Wed, 28 May 2025 17:45:39 +0200 Subject: [PATCH 42/70] Use AgentDepsT --- pydantic_ai_slim/pydantic_ai/tools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 0fca46d628..c8e48d7caa 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -326,7 +326,7 @@ async def prep_my_tool( self.strict = strict @classmethod - def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[None]: + def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[AgentDepsT]: """Creates a Pydantic tool from a function and a JSON schema. Args: @@ -360,7 +360,7 @@ def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) ) @classmethod - def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[None]: + def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[AgentDepsT]: """Creates a Pydantic tool proxy from a LangChain tool. Args: From 44f285e72d7ef7501a78c19de4025190cbdee13c Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Wed, 28 May 2025 17:50:22 +0200 Subject: [PATCH 43/70] Import Self from typing_extensions --- pydantic_ai_slim/pydantic_ai/tools.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index c8e48d7caa..6caa40ee98 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -11,7 +11,7 @@ from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import SchemaValidator, core_schema -from typing_extensions import Concatenate, ParamSpec, TypeAlias, TypeVar +from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar from . import _function_schema, _utils, messages as _messages from .exceptions import ModelRetry, UnexpectedModelBehavior @@ -326,7 +326,7 @@ async def prep_my_tool( self.strict = strict @classmethod - def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Tool[AgentDepsT]: + def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Self: """Creates a Pydantic tool from a function and a JSON schema. Args: @@ -360,7 +360,7 @@ def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) ) @classmethod - def from_langchain(cls, langchain_tool: LangChainTool) -> Tool[AgentDepsT]: + def from_langchain(cls, langchain_tool: LangChainTool) -> Self: """Creates a Pydantic tool proxy from a LangChain tool. Args: From 4af3e4e8a10e56efabf9426e6936eea903ec95ee Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 11:11:44 +0200 Subject: [PATCH 44/70] remove langchain optional group --- docs/install.md | 1 - pydantic_ai_slim/pyproject.toml | 2 -- 2 files changed, 3 deletions(-) diff --git a/docs/install.md b/docs/install.md index f115334855..6d621ada5f 100644 --- a/docs/install.md +++ b/docs/install.md @@ -56,7 +56,6 @@ pip/uv-add "pydantic-ai-slim[openai]" * `cohere` - installs `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"} * `duckduckgo` - installs `duckduckgo-search` [PyPI ↗](https://pypi.org/project/duckduckgo-search){:target="_blank"} * `tavily` - installs `tavily-python` [PyPI ↗](https://pypi.org/project/tavily-python){:target="_blank"} -* `langchain` - installs `langchain-core` [PyPI ↗](https://pypi.org/project/langchain-core){:target="_blank"} See the [models](models/index.md) documentation for information on which optional dependencies are required for each model. diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index bdf1c10c39..631cc196d0 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -80,8 +80,6 @@ mcp = ["mcp>=1.9.0; python_version >= '3.10'"] evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a=={{ version }}"] -# LangChain Tools -langchain = ["langchain-core>=0.3.61"] [dependency-groups] dev = [ From 8bf590bae30dbc439b68f7fa09e4635b9358451f Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 11:21:24 +0200 Subject: [PATCH 45/70] Make the FunctionSchema arguments defaults --- pydantic_ai_slim/pydantic_ai/_function_schema.py | 8 ++++---- pydantic_ai_slim/pydantic_ai/tools.py | 3 --- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_function_schema.py b/pydantic_ai_slim/pydantic_ai/_function_schema.py index facca89aa6..64201a3790 100644 --- a/pydantic_ai_slim/pydantic_ai/_function_schema.py +++ b/pydantic_ai_slim/pydantic_ai/_function_schema.py @@ -7,7 +7,7 @@ import inspect from collections.abc import Awaitable -from dataclasses import dataclass +from dataclasses import dataclass, field from inspect import Parameter, signature from typing import TYPE_CHECKING, Any, Callable, cast @@ -43,9 +43,9 @@ class FunctionSchema: # if not None, the function takes a single by that name (besides potentially `info`) takes_ctx: bool is_async: bool - single_arg_name: str | None - positional_fields: list[str] - var_positional_field: str | None + single_arg_name: str | None = None + positional_fields: list[str] = field(default_factory=list) + var_positional_field: str | None = None async def call(self, args_dict: dict[str, Any], ctx: RunContext[Any]) -> Any: args, kwargs = self._call_args(args_dict, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 6caa40ee98..68294f937c 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -346,9 +346,6 @@ def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) json_schema=json_schema, takes_ctx=False, is_async=asyncio.iscoroutinefunction(function), - single_arg_name=None, - positional_fields=[], - var_positional_field=None, ) return cls( From cd5e6c4baec0053741d8ccf49d25506e3dc70bcc Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 12:02:37 +0200 Subject: [PATCH 46/70] sort the items before showing them makes the test assertion more deterministic --- tests/test_tools.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_tools.py b/tests/test_tools.py index 84330f0c89..fab7695102 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1174,6 +1174,8 @@ def run( tool_call_id: Union[str, None] = None, **kwargs: Any, ) -> Any: + if isinstance(tool_input, dict): + tool_input = dict(sorted(tool_input.items())) return f'I was called with {tool_input}' @@ -1199,7 +1201,7 @@ def test_langchain_tool_conversion(): agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') - assert result.output == snapshot("{\"file_search\":\"I was called with {'pattern': 'a', 'dir_path': '.'}\"}") + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") assert agent._function_tools['file_search'].takes_ctx is False assert agent._function_tools['file_search'].max_retries == 7 @@ -1279,7 +1281,7 @@ def test_langchain_tool_defaults(): pydantic_tool = Tool.from_langchain(langchain_tool) result = pydantic_tool.function(pattern='something') # type: ignore - assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") + assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}") def test_langchain_tool_positional(): @@ -1303,7 +1305,7 @@ def test_langchain_tool_positional(): pydantic_tool = Tool.from_langchain(langchain_tool) result = pydantic_tool.function('something') # type: ignore - assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': '.'}") + assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}") def test_langchain_tool_default_override(): @@ -1327,4 +1329,4 @@ def test_langchain_tool_default_override(): pydantic_tool = Tool.from_langchain(langchain_tool) result = pydantic_tool.function(pattern='something', dir_path='somewhere') # type: ignore - assert result == snapshot("I was called with {'pattern': 'something', 'dir_path': 'somewhere'}") + assert result == snapshot("I was called with {'dir_path': 'somewhere', 'pattern': 'something'}") From e1a556278a483a892aca18970eedbbbe9327ed1a Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 12:03:24 +0200 Subject: [PATCH 47/70] use dictionary merging --- pydantic_ai_slim/pydantic_ai/tools.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 68294f937c..8d2ef15f29 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -384,10 +384,7 @@ def from_langchain(cls, langchain_tool: LangChainTool) -> Self: def proxy(*args: Any, **kwargs: Any) -> str: for argument, key in zip(args, inputs.keys()): kwargs[key] = argument - for name, default_value in defaults.items(): - if name in kwargs: - continue - kwargs[name] = default_value + kwargs = defaults | kwargs return langchain_tool.run(kwargs) proxy.__name__ = function_name From 91949306aab3da05331237d9dcd14d5fc62322f9 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 12:06:20 +0200 Subject: [PATCH 48/70] Use JsonSchemaValue to describe arguments --- pydantic_ai_slim/pydantic_ai/tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 8d2ef15f29..8436244529 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -27,7 +27,7 @@ class LangChainTool(Protocol): # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} @property - def args(self) -> dict[str, dict[str, str]]: ... + def args(self) -> dict[str, JsonSchemaValue]: ... @property def name(self) -> str: ... From 9f89e02435d3e7b46656bd183c2adfc91b0da0ca Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 12:18:04 +0200 Subject: [PATCH 49/70] Make name and description arguments --- pydantic_ai_slim/pydantic_ai/tools.py | 21 ++++++++++----------- tests/test_tools.py | 18 +++++++++--------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 8436244529..3bbca6a707 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -326,22 +326,24 @@ async def prep_my_tool( self.strict = strict @classmethod - def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) -> Self: + def from_schema( + cls, function: Callable[..., Any], name: str, description: str, json_schema: JsonSchemaValue + ) -> Self: """Creates a Pydantic tool from a function and a JSON schema. Args: function: The function to call + name: The unique name of the tool that clearly communicates its purpose + description: Used to tell the model how/when/why to use the tool. + You can provide few-shot examples as a part of the description. json_schema: The schema for the function arguments Returns: A Pydantic tool that calls the function """ - function_name = function.__name__ - function_description = function.__doc__ or '' - function_schema = _function_schema.FunctionSchema( function=function, - description=function_description, + description=description, validator=SchemaValidator(schema=core_schema.any_schema()), json_schema=json_schema, takes_ctx=False, @@ -351,8 +353,8 @@ def from_schema(cls, function: Callable[..., Any], json_schema: JsonSchemaValue) return cls( function, takes_ctx=False, - name=function_name, - description=function_description, + name=name, + description=description, function_schema=function_schema, ) @@ -387,10 +389,7 @@ def proxy(*args: Any, **kwargs: Any) -> str: kwargs = defaults | kwargs return langchain_tool.run(kwargs) - proxy.__name__ = function_name - proxy.__doc__ = function_description - - return cls.from_schema(function=proxy, json_schema=schema) + return cls.from_schema(function=proxy, name=function_name, description=function_description, json_schema=schema) async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. diff --git a/tests/test_tools.py b/tests/test_tools.py index fab7695102..2c1829c510 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1096,13 +1096,13 @@ def function(*args: Any, **kwargs: Any) -> str: }, 'required': ['one', 'two'], } - pydantic_tool = Tool.from_schema(function, json_schema=json_schema) + pydantic_tool = Tool.from_schema(function, name='foobar', description='does foobar stuff', json_schema=json_schema) agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') - assert result.output == snapshot('{"function":"I like being called like this"}') - assert agent._function_tools['function'].takes_ctx is False - assert agent._function_tools['function'].max_retries == 0 + assert result.output == snapshot('{"foobar":"I like being called like this"}') + assert agent._function_tools['foobar'].takes_ctx is False + assert agent._function_tools['foobar'].max_retries == 0 def test_function_tool_inconsistent_with_schema(): @@ -1118,7 +1118,7 @@ def function(three: str, four: int) -> str: }, 'required': ['one', 'two'], } - pydantic_tool = Tool.from_schema(function, json_schema=json_schema) + pydantic_tool = Tool.from_schema(function, name='foobar', description='does foobar stuff', json_schema=json_schema) agent = Agent('test', tools=[pydantic_tool], retries=0) with pytest.raises(TypeError, match=".* got an unexpected keyword argument 'one'"): @@ -1143,13 +1143,13 @@ async def function(*args: Any, **kwargs: Any) -> str: }, 'required': ['one', 'two'], } - pydantic_tool = Tool.from_schema(function, json_schema=json_schema) + pydantic_tool = Tool.from_schema(function, name='foobar', description='does foobar stuff', json_schema=json_schema) agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') - assert result.output == snapshot('{"function":"I like being called like this"}') - assert agent._function_tools['function'].takes_ctx is False - assert agent._function_tools['function'].max_retries == 0 + assert result.output == snapshot('{"foobar":"I like being called like this"}') + assert agent._function_tools['foobar'].takes_ctx is False + assert agent._function_tools['foobar'].max_retries == 0 @dataclass From d2462d88026dc5098da1534f8cafbc101ccbb86d Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 12:25:18 +0200 Subject: [PATCH 50/70] Revert unrelated formatting changes --- pydantic_ai_slim/pydantic_ai/tools.py | 9 ++--- tests/test_tools.py | 49 +++++---------------------- 2 files changed, 11 insertions(+), 47 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 3bbca6a707..2e80d0a731 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -82,9 +82,7 @@ class RunContext(Generic[AgentDepsT]): """The current step in the run.""" def replace_with( - self, - retry: int | None = None, - tool_name: str | None | _utils.Unset = _utils.UNSET, + self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET, ) -> RunContext[AgentDepsT]: # Create a new `RunContext` a new `retry` value and `tool_name`. kwargs = {} @@ -412,10 +410,7 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition return tool_def async def run( - self, - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - tracer: Tracer, + self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer, ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: """Run the tool function asynchronously. diff --git a/tests/test_tools.py b/tests/test_tools.py index 2c1829c510..8057e0363c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,14 +12,7 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext, Tool, ToolOutput, UserError -from pydantic_ai.messages import ( - ModelMessage, - ModelRequest, - ModelResponse, - TextPart, - ToolCallPart, - ToolReturnPart, -) +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition @@ -333,11 +326,7 @@ def test_only_returns_type(): The result as a string. \ """, - 'parameters_json_schema': { - 'additionalProperties': False, - 'properties': {}, - 'type': 'object', - }, + 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, } @@ -513,12 +502,7 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int: # pyright: reportPrivateUsage=false def test_init_tool_ctx(): - agent = Agent( - 'test', - tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], - deps_type=int, - retries=7, - ) + agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7) result = agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') assert agent._function_tools['ctx_tool'].takes_ctx is True @@ -601,10 +585,7 @@ def plain_tool(x: int) -> int: # pragma: no cover def test_init_plain_tool_invalid(): - with pytest.raises( - UserError, - match='RunContext annotations can only be used with tools that take context', - ): + with pytest.raises(UserError, match='RunContext annotations can only be used with tools that take context'): Tool(ctx_tool, takes_ctx=False) @@ -651,10 +632,7 @@ def test_return_bytes_invalid(): def return_pydantic_model() -> bytes: return b'\00 \x81' - with pytest.raises( - PydanticSerializationError, - match='invalid utf-8 sequence of 1 bytes from index 2', - ): + with pytest.raises(PydanticSerializationError, match='invalid utf-8 sequence of 1 bytes from index 2'): agent.run_sync('') @@ -789,8 +767,7 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: def test_future_run_context(create_module: Callable[[str], Any]): - mod = create_module( - """ + mod = create_module(""" from __future__ import annotations from pydantic_ai import Agent, RunContext @@ -799,8 +776,7 @@ def ctx_tool(ctx: RunContext[int], x: int) -> int: return x + ctx.deps agent = Agent('test', tools=[ctx_tool], deps_type=int) - """ - ) + """) result = mod.agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') @@ -824,11 +800,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture): 'description': "A tool that documents what it returns but doesn't have a return annotation in the docstring.", 'name': 'tool_without_return_annotation_in_docstring', 'outer_typed_dict_key': None, - 'parameters_json_schema': { - 'additionalProperties': False, - 'properties': {}, - 'type': 'object', - }, + 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'strict': None, } ) @@ -976,10 +948,7 @@ def typed_dict_schema(self, schema: core_schema.TypedDictSchema) -> JsonSchemaVa agent = Agent(FunctionModel(get_json_schema)) - def my_tool( - x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = None, - **kwargs: Any, - ): + def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = None, **kwargs: Any): return x # pragma: no cover agent.tool_plain(name='my_tool_1')(my_tool) From 1c99adfda0403b0a0b728afed617dd75396d2845 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 12:26:47 +0200 Subject: [PATCH 51/70] uv run ruff format --- pydantic_ai_slim/pydantic_ai/tools.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 2e80d0a731..3bbca6a707 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -82,7 +82,9 @@ class RunContext(Generic[AgentDepsT]): """The current step in the run.""" def replace_with( - self, retry: int | None = None, tool_name: str | None | _utils.Unset = _utils.UNSET, + self, + retry: int | None = None, + tool_name: str | None | _utils.Unset = _utils.UNSET, ) -> RunContext[AgentDepsT]: # Create a new `RunContext` a new `retry` value and `tool_name`. kwargs = {} @@ -410,7 +412,10 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition return tool_def async def run( - self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], tracer: Tracer, + self, + message: _messages.ToolCallPart, + run_context: RunContext[AgentDepsT], + tracer: Tracer, ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: """Run the tool function asynchronously. From c7094026e6426a60d9d8a4d93028dce035913905 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 20:11:05 +0200 Subject: [PATCH 52/70] Handle $refs in the langchain schema --- pydantic_ai_slim/pydantic_ai/tools.py | 23 ++++++++++++++++------- tests/test_tools.py | 7 +++++++ 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 3bbca6a707..c84e53a7cf 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -29,6 +29,8 @@ class LangChainTool(Protocol): @property def args(self) -> dict[str, JsonSchemaValue]: ... + def get_input_jsonschema(self) -> JsonSchemaValue: ... + @property def name(self) -> str: ... @@ -327,7 +329,11 @@ async def prep_my_tool( @classmethod def from_schema( - cls, function: Callable[..., Any], name: str, description: str, json_schema: JsonSchemaValue + cls, + function: Callable[..., Any], + name: str, + description: str, + json_schema: JsonSchemaValue, ) -> Self: """Creates a Pydantic tool from a function and a JSON schema. @@ -372,11 +378,9 @@ def from_langchain(cls, langchain_tool: LangChainTool) -> Self: function_description = langchain_tool.description inputs = langchain_tool.args.copy() required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) - schema: JsonSchemaValue = { - 'type': 'object', - 'properties': inputs, - 'additionalProperties': False, - } + schema: JsonSchemaValue = langchain_tool.get_input_jsonschema() + if 'additionalProperties' not in schema: + schema['additionalProperties'] = False if required: schema['required'] = required @@ -389,7 +393,12 @@ def proxy(*args: Any, **kwargs: Any) -> str: kwargs = defaults | kwargs return langchain_tool.run(kwargs) - return cls.from_schema(function=proxy, name=function_name, description=function_description, json_schema=schema) + return cls.from_schema( + function=proxy, + name=function_name, + description=function_description, + json_schema=schema, + ) async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. diff --git a/tests/test_tools.py b/tests/test_tools.py index 8057e0363c..99e0862b95 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1147,6 +1147,13 @@ def run( tool_input = dict(sorted(tool_input.items())) return f'I was called with {tool_input}' + def get_input_jsonschema(self) -> JsonSchemaValue: + return { + 'type': 'object', + 'properties': self.args, + 'additionalProperties': False, + } + def test_langchain_tool_conversion(): langchain_tool = SimulatedLangChainTool( From f075d7d22947f02e8a3493dcbe4c378e30979947 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 20:20:21 +0200 Subject: [PATCH 53/70] Replace arg iteration with assertion --- pydantic_ai_slim/pydantic_ai/tools.py | 3 +-- tests/test_tools.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index c84e53a7cf..9b5b3a809a 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -388,8 +388,7 @@ def from_langchain(cls, langchain_tool: LangChainTool) -> Self: # restructures the arguments to match langchain tool run def proxy(*args: Any, **kwargs: Any) -> str: - for argument, key in zip(args, inputs.keys()): - kwargs[key] = argument + assert not args, 'This should always be called with kwargs' kwargs = defaults | kwargs return langchain_tool.run(kwargs) diff --git a/tests/test_tools.py b/tests/test_tools.py index 99e0862b95..ec907af1f5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1280,8 +1280,8 @@ def test_langchain_tool_positional(): ) pydantic_tool = Tool.from_langchain(langchain_tool) - result = pydantic_tool.function('something') # type: ignore - assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}") + with pytest.raises(AssertionError, match='This should always be called with kwargs'): + pydantic_tool.function('something') # type: ignore def test_langchain_tool_default_override(): From 65a04bec1d421f1f53d76fd08e0bbb859d23b45e Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 21:03:51 +0200 Subject: [PATCH 54/70] Add flag to allow testing additional properties branch --- tests/test_tools.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/test_tools.py b/tests/test_tools.py index ec907af1f5..82fac6a676 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1126,6 +1126,7 @@ class SimulatedLangChainTool: name: str description: str args: dict[str, dict[str, str]] + additional_properties_missing: bool = False def run( self, @@ -1148,6 +1149,11 @@ def run( return f'I was called with {tool_input}' def get_input_jsonschema(self) -> JsonSchemaValue: + if self.additional_properties_missing: + return { + 'type': 'object', + 'properties': self.args, + } return { 'type': 'object', 'properties': self.args, @@ -1182,6 +1188,34 @@ def test_langchain_tool_conversion(): assert agent._function_tools['file_search'].max_retries == 7 +def test_langchain_tool_no_additional_properties(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + additional_properties_missing=True, + ) + pydantic_tool = Tool.from_langchain(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + def test_langchain_tool_conversion_no_defaults(): langchain_tool = SimulatedLangChainTool( name='file_search', From b91cd66de3be8d2b119e2d66d863965f5ed93cc5 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 21:04:14 +0200 Subject: [PATCH 55/70] Write langchain tool section HEAVILY BASED ON SMOLAGENTS DOCS https://smolagents.org/docs/tools-of-smolagents-in-depth-guide/#3-toc-title They write well and in a positive way. --- docs/tools.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/tools.md b/docs/tools.md index 921edd95e1..68e0d020fe 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -624,3 +624,30 @@ def my_flaky_tool(query: str) -> str: return 'Success!' ``` Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception message, which is sent back to the LLM to guide its next attempt. Both `ValidationError` and `ModelRetry` respect the `retries` setting configured on the `Tool` or `Agent`. + +## Use LangChain Tools {#langchain-tools} + +We love LangChain and think it has a very compelling suite of tools. To import a tool from LangChain, use the `from_langchain()` method. + +Here is how you can use it to augment model responses using a LangChain web search tool. This tool will need you to install the `langchain-community` and `duckduckgo-search` dependencies to work properly. + +```python +from pydantic_ai import Tool +from langchain_community.tools import DuckDuckGoSearchRun + +search = DuckDuckGoSearchRun() +search_tool = Tool.from_langchain(search) + +agent = Agent( + 'google-gla:gemini-2.0-flash', # (1)! + tools=[search_tool], +) + +agent.run("What is the release date of Elden Ring Nightreign?") # (2)! +#> Elden Ring Nightreign is planned to be released on May 30, 2025. ... # (3)! +``` + + +1. While this task is simple Gemini 1.5 didn't want to use the provided tool. Gemini 2.0 is still fast and cheap. +2. The release date of this game is the 30th of May 2025, which was confirmed after the knowledge cutoff for Gemini 2.0 (August 2024). +3. Without the tool you get the answer: _There is no Elden Ring expansion called "Nightreign."..._ From 5300f3f29d6d867d051fbbc4fa6c6b14793e8c33 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 21:23:19 +0200 Subject: [PATCH 56/70] formatting for docs --- docs/tools.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index 68e0d020fe..1b603896db 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -632,8 +632,8 @@ We love LangChain and think it has a very compelling suite of tools. To import a Here is how you can use it to augment model responses using a LangChain web search tool. This tool will need you to install the `langchain-community` and `duckduckgo-search` dependencies to work properly. ```python -from pydantic_ai import Tool from langchain_community.tools import DuckDuckGoSearchRun +from pydantic_ai import Tool search = DuckDuckGoSearchRun() search_tool = Tool.from_langchain(search) @@ -643,8 +643,8 @@ agent = Agent( tools=[search_tool], ) -agent.run("What is the release date of Elden Ring Nightreign?") # (2)! -#> Elden Ring Nightreign is planned to be released on May 30, 2025. ... # (3)! +agent.run('What is the release date of Elden Ring Nightreign?') # (2)! +#> Elden Ring Nightreign is planned to be released on May 30, 2025. ... # (3)! ``` From a06f8148daaf628e9e26ca4178abda3fb1b0a1cf Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 21:46:34 +0200 Subject: [PATCH 57/70] Testing imports in docs is neat --- docs/tools.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tools.md b/docs/tools.md index 1b603896db..5049b2bd61 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -633,7 +633,7 @@ Here is how you can use it to augment model responses using a LangChain web sear ```python from langchain_community.tools import DuckDuckGoSearchRun -from pydantic_ai import Tool +from pydantic_ai import Agent, Tool search = DuckDuckGoSearchRun() search_tool = Tool.from_langchain(search) From 24507f99bea35c3bc14cfc2fbd97af6653d3b09f Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Thu, 29 May 2025 21:53:55 +0200 Subject: [PATCH 58/70] import sorting --- docs/tools.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tools.md b/docs/tools.md index 5049b2bd61..f38f03eb57 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -632,8 +632,8 @@ We love LangChain and think it has a very compelling suite of tools. To import a Here is how you can use it to augment model responses using a LangChain web search tool. This tool will need you to install the `langchain-community` and `duckduckgo-search` dependencies to work properly. ```python -from langchain_community.tools import DuckDuckGoSearchRun from pydantic_ai import Agent, Tool +from langchain_community.tools import DuckDuckGoSearchRun search = DuckDuckGoSearchRun() search_tool = Tool.from_langchain(search) From 3ab281f3b142881e31a632cefca7b0382ab0de9c Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Fri, 30 May 2025 10:29:56 +0200 Subject: [PATCH 59/70] Rephrase and skip tests over code The langchain-community dependency is not installed so it cannot be run --- docs/tools.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index f38f03eb57..561ee6267e 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -627,14 +627,15 @@ Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception ## Use LangChain Tools {#langchain-tools} -We love LangChain and think it has a very compelling suite of tools. To import a tool from LangChain, use the `from_langchain()` method. +If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with PydanticAI, you can use the `Tool.from_langchain` convenience method. Note that PydanticAI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. Here is how you can use it to augment model responses using a LangChain web search tool. This tool will need you to install the `langchain-community` and `duckduckgo-search` dependencies to work properly. -```python -from pydantic_ai import Agent, Tool +```python {test="skip"} from langchain_community.tools import DuckDuckGoSearchRun +from pydantic_ai import Agent, Tool + search = DuckDuckGoSearchRun() search_tool = Tool.from_langchain(search) From e67b0abd17526c55b94f00619881e41475efb3e9 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Fri, 30 May 2025 10:30:23 +0200 Subject: [PATCH 60/70] Clarify the requirements over the function --- pydantic_ai_slim/pydantic_ai/tools.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 9b5b3a809a..94623d82ac 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -338,7 +338,9 @@ def from_schema( """Creates a Pydantic tool from a function and a JSON schema. Args: - function: The function to call + function: The function to call. + This will be called with keywords only, and no validation of + the arguments will be performed. name: The unique name of the tool that clearly communicates its purpose description: Used to tell the model how/when/why to use the tool. You can provide few-shot examples as a part of the description. From d96fb5b9369e43a459f6e8685c93b705607fbd40 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Fri, 30 May 2025 10:52:01 +0200 Subject: [PATCH 61/70] Add section on from_schema --- docs/tools.md | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/tools.md b/docs/tools.md index 561ee6267e..8c0a032b44 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -409,6 +409,33 @@ print(test_model.last_model_request_parameters.function_tools) _(This example is complete, it can be run "as is")_ +If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the `Tool.from_schema` function. With this you provide the name, description and JSON schema for the function directly: + +```python +from pydantic_ai import Tool + +def foobar(**kwargs) -> str: + return kwargs['a'] + kwargs['b'] + +tool = Tool.from_schema( + function=foobar, + name="sum", + description="Sum two numbers." + json_schema={ + 'additionalProperties': False, + 'properties': { + 'a': {'description': 'the first number', 'type': 'integer'}, + 'b': {'description': 'the second number', 'type': 'integer'}, + }, + 'required': ['a', 'b'], + 'type': 'object', + } +) +``` + + +Please note that validation of the tool arguments will not be performed, and this will pass all arguments as keyword arguments. + ## Dynamic Function tools {#tool-prepare} Tools can optionally be defined with another function: `prepare`, which is called at each step of a run to From 47e9d7bbb59c33531067d986abe9814f359aeefb Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Fri, 30 May 2025 10:57:09 +0200 Subject: [PATCH 62/70] replace double quotes with single quotes, add comma --- docs/tools.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index 8c0a032b44..886d783a2d 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -419,8 +419,8 @@ def foobar(**kwargs) -> str: tool = Tool.from_schema( function=foobar, - name="sum", - description="Sum two numbers." + name='sum', + description='Sum two numbers.', json_schema={ 'additionalProperties': False, 'properties': { From 17969fbda21126cdaab90dee377c169fa54dca9a Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Fri, 30 May 2025 10:57:49 +0200 Subject: [PATCH 63/70] docs/tools.md:415:1: I001 [*] Import block is un-sorted or un-formatted I am fed up with this linter. It's one import. It still fails. It doesn't say how to fix it and there isn't a way to automatically apply the fix. --- docs/tools.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/tools.md b/docs/tools.md index 886d783a2d..619ced0795 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -412,7 +412,7 @@ _(This example is complete, it can be run "as is")_ If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the `Tool.from_schema` function. With this you provide the name, description and JSON schema for the function directly: ```python -from pydantic_ai import Tool +from pydantic_ai.tools import Tool def foobar(**kwargs) -> str: return kwargs['a'] + kwargs['b'] From 77f6a98399112b981bd925867dbdb543936cc607 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 3 Jun 2025 09:51:40 +0100 Subject: [PATCH 64/70] Move langchain tool converter --- docs/tools.md | 7 +- pydantic_ai_slim/pydantic_ai/ext/__init__.py | 0 pydantic_ai_slim/pydantic_ai/ext/langchain.py | 61 +++++ pydantic_ai_slim/pydantic_ai/tools.py | 55 ---- tests/ext/__init__.py | 0 tests/ext/test_langchain.py | 252 ++++++++++++++++++ tests/test_tools.py | 221 --------------- 7 files changed, 317 insertions(+), 279 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/ext/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/ext/langchain.py create mode 100644 tests/ext/__init__.py create mode 100644 tests/ext/test_langchain.py diff --git a/docs/tools.md b/docs/tools.md index 619ced0795..88a097c298 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -654,17 +654,18 @@ Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception ## Use LangChain Tools {#langchain-tools} -If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with PydanticAI, you can use the `Tool.from_langchain` convenience method. Note that PydanticAI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. +If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with PydanticAI, you can use the `pydancic_ai.ext.langchain.from_langchain_tool` convenience method. Note that PydanticAI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. Here is how you can use it to augment model responses using a LangChain web search tool. This tool will need you to install the `langchain-community` and `duckduckgo-search` dependencies to work properly. ```python {test="skip"} from langchain_community.tools import DuckDuckGoSearchRun -from pydantic_ai import Agent, Tool +from pydantic_ai import Agent +from pydantic_ai.ext.langchain import from_langchain_tool search = DuckDuckGoSearchRun() -search_tool = Tool.from_langchain(search) +search_tool = from_langchain_tool(search) agent = Agent( 'google-gla:gemini-2.0-flash', # (1)! diff --git a/pydantic_ai_slim/pydantic_ai/ext/__init__.py b/pydantic_ai_slim/pydantic_ai/ext/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py new file mode 100644 index 0000000000..9bf1e8e985 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -0,0 +1,61 @@ +from typing import Any, Protocol + +from pydantic.json_schema import JsonSchemaValue + +from pydantic_ai.tools import Tool + + +class LangChainTool(Protocol): + # args are like + # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, + # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} + @property + def args(self) -> dict[str, JsonSchemaValue]: ... + + def get_input_jsonschema(self) -> JsonSchemaValue: ... + + @property + def name(self) -> str: ... + + @property + def description(self) -> str: ... + + def run(self, *args: Any, **kwargs: Any) -> str: ... + + +__all__ = ('from_langchain_tool',) + + +def from_langchain_tool(langchain_tool: LangChainTool) -> Tool: + """Creates a Pydantic tool proxy from a LangChain tool. + + Args: + langchain_tool: The LangChain tool to wrap. + + Returns: + A Pydantic tool that corresponds to the LangChain tool. + """ + function_name = langchain_tool.name + function_description = langchain_tool.description + inputs = langchain_tool.args.copy() + required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) + schema: JsonSchemaValue = langchain_tool.get_input_jsonschema() + if 'additionalProperties' not in schema: + schema['additionalProperties'] = False + if required: + schema['required'] = required + + defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail} + + # restructures the arguments to match langchain tool run + def proxy(*args: Any, **kwargs: Any) -> str: + assert not args, 'This should always be called with kwargs' + kwargs = defaults | kwargs + return langchain_tool.run(kwargs) + + return Tool.from_schema( + function=proxy, + name=function_name, + description=function_description, + json_schema=schema, + ) diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 94623d82ac..ae7427a10b 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -17,29 +17,9 @@ from .exceptions import ModelRetry, UnexpectedModelBehavior if TYPE_CHECKING: - from typing import Protocol - from .models import Model from .result import Usage - class LangChainTool(Protocol): - # args are like - # {'dir_path': {'default': '.', 'description': 'Subdirectory to search in.', 'title': 'Dir Path', 'type': 'string'}, - # 'pattern': {'description': 'Unix shell regex, where * matches everything.', 'title': 'Pattern', 'type': 'string'}} - @property - def args(self) -> dict[str, JsonSchemaValue]: ... - - def get_input_jsonschema(self) -> JsonSchemaValue: ... - - @property - def name(self) -> str: ... - - @property - def description(self) -> str: ... - - def run(self, *args: Any, **kwargs: Any) -> str: ... - - __all__ = ( 'AgentDepsT', 'DocstringFormat', @@ -366,41 +346,6 @@ def from_schema( function_schema=function_schema, ) - @classmethod - def from_langchain(cls, langchain_tool: LangChainTool) -> Self: - """Creates a Pydantic tool proxy from a LangChain tool. - - Args: - langchain_tool: The LangChain tool to wrap. - - Returns: - A Pydantic tool that corresponds to the LangChain tool. - """ - function_name = langchain_tool.name - function_description = langchain_tool.description - inputs = langchain_tool.args.copy() - required = sorted({name for name, detail in inputs.items() if 'default' not in detail}) - schema: JsonSchemaValue = langchain_tool.get_input_jsonschema() - if 'additionalProperties' not in schema: - schema['additionalProperties'] = False - if required: - schema['required'] = required - - defaults = {name: detail['default'] for name, detail in inputs.items() if 'default' in detail} - - # restructures the arguments to match langchain tool run - def proxy(*args: Any, **kwargs: Any) -> str: - assert not args, 'This should always be called with kwargs' - kwargs = defaults | kwargs - return langchain_tool.run(kwargs) - - return cls.from_schema( - function=proxy, - name=function_name, - description=function_description, - json_schema=schema, - ) - async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. diff --git a/tests/ext/__init__.py b/tests/ext/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py new file mode 100644 index 0000000000..0c79aa72b7 --- /dev/null +++ b/tests/ext/test_langchain.py @@ -0,0 +1,252 @@ +from dataclasses import dataclass +from typing import Any, Union + +import pytest +from inline_snapshot import snapshot +from pydantic.json_schema import JsonSchemaValue + +from pydantic_ai import Agent +from pydantic_ai.ext.langchain import from_langchain_tool + + +@dataclass +class SimulatedLangChainTool: + name: str + description: str + args: dict[str, dict[str, str]] + additional_properties_missing: bool = False + + def run( + self, + tool_input: Union[str, dict[str, Any]], + verbose: Union[bool, None] = None, + start_color: Union[str, None] = 'green', + color: Union[str, None] = 'green', + callbacks: Any = None, + *, + tags: Union[list[str], None] = None, + metadata: Union[dict[str, Any], None] = None, + run_name: Union[str, None] = None, + run_id: Union[Any, None] = None, + config: Union[Any, None] = None, + tool_call_id: Union[str, None] = None, + **kwargs: Any, + ) -> Any: + if isinstance(tool_input, dict): + tool_input = dict(sorted(tool_input.items())) + return f'I was called with {tool_input}' + + def get_input_jsonschema(self) -> JsonSchemaValue: + if self.additional_properties_missing: + return { + 'type': 'object', + 'properties': self.args, + } + return { + 'type': 'object', + 'properties': self.args, + 'additionalProperties': False, + } + + +def test_langchain_tool_conversion(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + +def test_langchain_tool_no_additional_properties(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + additional_properties_missing=True, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + +def test_langchain_tool_conversion_no_defaults(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': 'a', 'pattern': 'a'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + +def test_langchain_tool_conversion_no_required(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'default': '*', + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + agent = Agent('test', tools=[pydantic_tool], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': '*'}\"}") + assert agent._function_tools['file_search'].takes_ctx is False + assert agent._function_tools['file_search'].max_retries == 7 + + +def test_langchain_tool_defaults(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + result = pydantic_tool.function(pattern='something') # type: ignore + assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}") + + +def test_langchain_tool_positional(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + }, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + with pytest.raises(AssertionError, match='This should always be called with kwargs'): + pydantic_tool.function('something') # type: ignore + + +def test_langchain_tool_default_override(): + langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + pydantic_tool = from_langchain_tool(langchain_tool) + + result = pydantic_tool.function(pattern='something', dir_path='somewhere') # type: ignore + assert result == snapshot("I was called with {'dir_path': 'somewhere', 'pattern': 'something'}") + + +def test_simulated_tool_string_input(): + tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', + }, + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, + ) + result = tool.run('this string argument') + assert result == snapshot('I was called with this string argument') diff --git a/tests/test_tools.py b/tests/test_tools.py index 82fac6a676..4c09b31e81 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1119,224 +1119,3 @@ async def function(*args: Any, **kwargs: Any) -> str: assert result.output == snapshot('{"foobar":"I like being called like this"}') assert agent._function_tools['foobar'].takes_ctx is False assert agent._function_tools['foobar'].max_retries == 0 - - -@dataclass -class SimulatedLangChainTool: - name: str - description: str - args: dict[str, dict[str, str]] - additional_properties_missing: bool = False - - def run( - self, - tool_input: Union[str, dict[str, Any]], - verbose: Union[bool, None] = None, - start_color: Union[str, None] = 'green', - color: Union[str, None] = 'green', - callbacks: Any = None, - *, - tags: Union[list[str], None] = None, - metadata: Union[dict[str, Any], None] = None, - run_name: Union[str, None] = None, - run_id: Union[Any, None] = None, - config: Union[Any, None] = None, - tool_call_id: Union[str, None] = None, - **kwargs: Any, - ) -> Any: - if isinstance(tool_input, dict): - tool_input = dict(sorted(tool_input.items())) - return f'I was called with {tool_input}' - - def get_input_jsonschema(self) -> JsonSchemaValue: - if self.additional_properties_missing: - return { - 'type': 'object', - 'properties': self.args, - } - return { - 'type': 'object', - 'properties': self.args, - 'additionalProperties': False, - } - - -def test_langchain_tool_conversion(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - }, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - agent = Agent('test', tools=[pydantic_tool], retries=7) - result = agent.run_sync('foobar') - assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 - - -def test_langchain_tool_no_additional_properties(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - }, - additional_properties_missing=True, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - agent = Agent('test', tools=[pydantic_tool], retries=7) - result = agent.run_sync('foobar') - assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 - - -def test_langchain_tool_conversion_no_defaults(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - }, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - agent = Agent('test', tools=[pydantic_tool], retries=7) - result = agent.run_sync('foobar') - assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': 'a', 'pattern': 'a'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 - - -def test_langchain_tool_conversion_no_required(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'default': '*', - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - }, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - agent = Agent('test', tools=[pydantic_tool], retries=7) - result = agent.run_sync('foobar') - assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': '*'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 - - -def test_langchain_tool_defaults(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - }, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - result = pydantic_tool.function(pattern='something') # type: ignore - assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}") - - -def test_langchain_tool_positional(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - }, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - with pytest.raises(AssertionError, match='This should always be called with kwargs'): - pydantic_tool.function('something') # type: ignore - - -def test_langchain_tool_default_override(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, - }, - ) - pydantic_tool = Tool.from_langchain(langchain_tool) - - result = pydantic_tool.function(pattern='something', dir_path='somewhere') # type: ignore - assert result == snapshot("I was called with {'dir_path': 'somewhere', 'pattern': 'something'}") From 01428581b0109aa5355db2ebeab859726ab61dde Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 3 Jun 2025 09:58:22 +0100 Subject: [PATCH 65/70] Add example of using tool from Tool.from_schema --- docs/tools.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/docs/tools.md b/docs/tools.md index 88a097c298..56fb02ce9e 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -412,7 +412,8 @@ _(This example is complete, it can be run "as is")_ If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the `Tool.from_schema` function. With this you provide the name, description and JSON schema for the function directly: ```python -from pydantic_ai.tools import Tool +from pydantic_ai import Agent, Tool +from pydantic_ai.models.test import TestModel def foobar(**kwargs) -> str: return kwargs['a'] + kwargs['b'] @@ -431,6 +432,13 @@ tool = Tool.from_schema( 'type': 'object', } ) + +test_model = TestModel() +agent = Agent(test_model, tools=[tool]) + +result = agent.run_sync('testing...') +print(result.output) +#> {"sum":0} ``` From 77a7beea4b41384de6106393b449a444fa8d04e7 Mon Sep 17 00:00:00 2001 From: Matthew Franglen Date: Tue, 3 Jun 2025 10:15:20 +0100 Subject: [PATCH 66/70] remove assertions over _function_tools --- tests/ext/test_langchain.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py index 0c79aa72b7..0cd583885f 100644 --- a/tests/ext/test_langchain.py +++ b/tests/ext/test_langchain.py @@ -72,8 +72,6 @@ def test_langchain_tool_conversion(): agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 def test_langchain_tool_no_additional_properties(): @@ -100,8 +98,6 @@ def test_langchain_tool_no_additional_properties(): agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 def test_langchain_tool_conversion_no_defaults(): @@ -126,8 +122,6 @@ def test_langchain_tool_conversion_no_defaults(): agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': 'a', 'pattern': 'a'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 def test_langchain_tool_conversion_no_required(): @@ -154,8 +148,6 @@ def test_langchain_tool_conversion_no_required(): agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': '*'}\"}") - assert agent._function_tools['file_search'].takes_ctx is False - assert agent._function_tools['file_search'].max_retries == 7 def test_langchain_tool_defaults(): From 669b299b6acc5c9005e89e8995fe18960ec52330 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 6 Jun 2025 18:32:32 +0000 Subject: [PATCH 67/70] Satisfy linter --- docs/tools.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/tools.md b/docs/tools.md index 56fb02ce9e..eeaa5b657f 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -415,6 +415,7 @@ If you have a function that lacks appropriate documentation (i.e. poorly named, from pydantic_ai import Agent, Tool from pydantic_ai.models.test import TestModel + def foobar(**kwargs) -> str: return kwargs['a'] + kwargs['b'] From d83ef10dbdf4a8884089a4d2463c6838a7c74683 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 6 Jun 2025 18:40:31 +0000 Subject: [PATCH 68/70] Rename helper to tool_from_langchain --- docs/tools.md | 6 +++--- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 4 ++-- tests/ext/test_langchain.py | 16 ++++++++-------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index eeaa5b657f..affd100412 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -663,7 +663,7 @@ Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception ## Use LangChain Tools {#langchain-tools} -If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with PydanticAI, you can use the `pydancic_ai.ext.langchain.from_langchain_tool` convenience method. Note that PydanticAI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. +If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with PydanticAI, you can use the `pydancic_ai.ext.langchain.tool_from_langchain` convenience method. Note that PydanticAI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. Here is how you can use it to augment model responses using a LangChain web search tool. This tool will need you to install the `langchain-community` and `duckduckgo-search` dependencies to work properly. @@ -671,10 +671,10 @@ Here is how you can use it to augment model responses using a LangChain web sear from langchain_community.tools import DuckDuckGoSearchRun from pydantic_ai import Agent -from pydantic_ai.ext.langchain import from_langchain_tool +from pydantic_ai.ext.langchain import tool_from_langchain search = DuckDuckGoSearchRun() -search_tool = from_langchain_tool(search) +search_tool = tool_from_langchain(search) agent = Agent( 'google-gla:gemini-2.0-flash', # (1)! diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 9bf1e8e985..9d13adda07 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -23,10 +23,10 @@ def description(self) -> str: ... def run(self, *args: Any, **kwargs: Any) -> str: ... -__all__ = ('from_langchain_tool',) +__all__ = ('tool_from_langchain',) -def from_langchain_tool(langchain_tool: LangChainTool) -> Tool: +def tool_from_langchain(langchain_tool: LangChainTool) -> Tool: """Creates a Pydantic tool proxy from a LangChain tool. Args: diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py index 0cd583885f..73e7cc0504 100644 --- a/tests/ext/test_langchain.py +++ b/tests/ext/test_langchain.py @@ -6,7 +6,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai import Agent -from pydantic_ai.ext.langchain import from_langchain_tool +from pydantic_ai.ext.langchain import tool_from_langchain @dataclass @@ -67,7 +67,7 @@ def test_langchain_tool_conversion(): }, }, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') @@ -93,7 +93,7 @@ def test_langchain_tool_no_additional_properties(): }, additional_properties_missing=True, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') @@ -117,7 +117,7 @@ def test_langchain_tool_conversion_no_defaults(): }, }, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') @@ -143,7 +143,7 @@ def test_langchain_tool_conversion_no_required(): }, }, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) result = agent.run_sync('foobar') @@ -168,7 +168,7 @@ def test_langchain_tool_defaults(): }, }, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) result = pydantic_tool.function(pattern='something') # type: ignore assert result == snapshot("I was called with {'dir_path': '.', 'pattern': 'something'}") @@ -192,7 +192,7 @@ def test_langchain_tool_positional(): }, }, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) with pytest.raises(AssertionError, match='This should always be called with kwargs'): pydantic_tool.function('something') # type: ignore @@ -216,7 +216,7 @@ def test_langchain_tool_default_override(): }, }, ) - pydantic_tool = from_langchain_tool(langchain_tool) + pydantic_tool = tool_from_langchain(langchain_tool) result = pydantic_tool.function(pattern='something', dir_path='somewhere') # type: ignore assert result == snapshot("I was called with {'dir_path': 'somewhere', 'pattern': 'something'}") From fe14acda4f821efd9615a5bc4002be7316333472 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 6 Jun 2025 18:58:07 +0000 Subject: [PATCH 69/70] Update LangChain example --- docs/tools.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index affd100412..99090447f3 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -681,11 +681,11 @@ agent = Agent( tools=[search_tool], ) -agent.run('What is the release date of Elden Ring Nightreign?') # (2)! -#> Elden Ring Nightreign is planned to be released on May 30, 2025. ... # (3)! +result = agent.run_sync("What is the release date of Elden Ring Nightreign?") # (2)! +print(result.output) +# > Elden Ring Nightreign is planned to be released on May 30, 2025. ``` 1. While this task is simple Gemini 1.5 didn't want to use the provided tool. Gemini 2.0 is still fast and cheap. 2. The release date of this game is the 30th of May 2025, which was confirmed after the knowledge cutoff for Gemini 2.0 (August 2024). -3. Without the tool you get the answer: _There is no Elden Ring expansion called "Nightreign."..._ From e558fc740b3f093f1522b5f0e31a47a7dabab046 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 6 Jun 2025 19:02:38 +0000 Subject: [PATCH 70/70] Update LangChain example --- docs/tools.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/tools.md b/docs/tools.md index 99090447f3..8f8871bcac 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -681,9 +681,9 @@ agent = Agent( tools=[search_tool], ) -result = agent.run_sync("What is the release date of Elden Ring Nightreign?") # (2)! +result = agent.run_sync('What is the release date of Elden Ring Nightreign?') # (2)! print(result.output) -# > Elden Ring Nightreign is planned to be released on May 30, 2025. +#> Elden Ring Nightreign is planned to be released on May 30, 2025. ```