Use .iter() API to fully replace existing streaming implementation #951

Draft: wants to merge 1 commit into base: main
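For orientation, the rough end-to-end flow this PR targets (a hedged sketch; the exact `agent.iter()` signature and node handling are assumptions based on the title, not shown in this diff):

# Hypothetical overall flow: agent.iter() drives the run graph node by node,
# and request nodes expose the new stream() context manager added below.
async with agent.iter('user prompt') as agent_run:
    async for node in agent_run:
        ...  # request nodes can then be streamed via node.stream(agent_run.ctx)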
pydantic_ai_slim/pydantic_ai/_agent_graph.py (19 additions & 1 deletion)
@@ -233,12 +233,30 @@ async def run(

        return await self._make_request(ctx)

    @asynccontextmanager
    async def stream(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[result.AgentStream[DepsT, T]]:
        async with self._stream(ctx) as streamed_response:
            agent_stream = result.AgentStream[DepsT, T](
                streamed_response,
                ctx.deps.result_schema,
                ctx.deps.result_validators,
                build_run_context(ctx),
                ctx.deps.usage_limits,
            )
            yield agent_stream
            # In case the user didn't manually consume the full stream, ensure it is fully consumed here,
            # otherwise usage won't be properly counted:
            async for _ in agent_stream:
                pass

    @asynccontextmanager
    async def _stream(
        self,
        ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
    ) -> AsyncIterator[models.StreamedResponse]:
        # TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
        assert not self._did_stream, 'stream() should only be called once per node'

        model_settings, model_request_parameters = await self._prepare_request(ctx)
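A minimal sketch of how the node-level `stream()` context manager above might be consumed (illustrative only; `node` and `ctx` stand in for the request node and its GraphRunContext):

# On exit the context manager drains any unread events,
# so usage is still counted even if the caller stops iterating early.
async with node.stream(ctx) as agent_stream:
    async for event in agent_stream:
        print(event)  # PartStartEvent, PartDeltaEvent, or FinalResultEvent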
pydantic_ai_slim/pydantic_ai/messages.py (15 additions & 0 deletions)
@@ -444,9 +444,24 @@ class PartDeltaEvent:
"""Event type identifier, used as a discriminator."""


@dataclass
class FinalResultEvent:
"""An event indicating the response to the current model request matches the result schema."""

tool_name: str | None
"""The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
event_kind: Literal['final_result'] = 'final_result'
"""Event type identifier, used as a discriminator."""


ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""

AgentStreamEvent = Annotated[
Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
]
"""An event in the agent stream."""


@dataclass
class FunctionToolCallEvent:
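Because the new `AgentStreamEvent` union is discriminated on `event_kind`, serialized events can be validated back into the right class with pydantic. A small sketch, not part of this diff:

import pydantic
from pydantic_ai.messages import AgentStreamEvent

event_adapter = pydantic.TypeAdapter(AgentStreamEvent)
event = event_adapter.validate_python({'event_kind': 'final_result', 'tool_name': None})
# event is FinalResultEvent(tool_name=None)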
pydantic_ai_slim/pydantic_ai/result.py (121 additions & 1 deletion)
@@ -7,9 +7,10 @@
from typing import Generic, Union, cast

import logfire_api
from typing_extensions import TypeVar
from typing_extensions import TypeVar, assert_type

from . import _result, _utils, exceptions, messages as _messages, models
from .messages import AgentStreamEvent, FinalResultEvent
from .tools import AgentDepsT, RunContext
from .usage import Usage, UsageLimits

@@ -51,6 +52,125 @@
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


@dataclass
class AgentStream(Generic[AgentDepsT, ResultDataT]):
    _raw_stream_response: models.StreamedResponse
    _result_schema: _result.ResultSchema[ResultDataT] | None
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    _run_ctx: RunContext[AgentDepsT]
    _usage_limits: UsageLimits | None

    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
    _initial_run_ctx_usage: Usage = field(init=False)

    def __post_init__(self):
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)

    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Asynchronously stream the (validated) agent outputs."""
        async for response in self.stream_responses(debounce_by=debounce_by):
            if self._final_result_event is not None:
                yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
        if self._final_result_event is not None:
            yield await self._validate_response(
                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
            )

    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
        """Asynchronously stream the (unvalidated) model responses for the agent."""
        # if the message currently has any parts with content, yield before streaming
        msg = self._raw_stream_response.get()
        for part in msg.parts:
            if part.has_content():
                yield msg
                break

        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
            async for _items in group_iter:
                yield self._raw_stream_response.get()  # current state of the response

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        return self._initial_run_ctx_usage + self._raw_stream_response.usage()

    async def _validate_response(
        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message."""
        if self._result_schema is not None and result_tool_name is not None:
            match = self._result_schema.find_named_tool(message.parts, result_tool_name)
            if match is None:
                raise exceptions.UnexpectedModelBehavior(
                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
                )

            call, result_tool = match
            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

            for validator in self._result_validators:
                result_data = await validator.validate(result_data, call, self._run_ctx)
            return result_data
        else:
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
            for validator in self._result_validators:
                text = await validator.validate(
                    text,
                    None,
                    self._run_ctx,
                )
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.

        This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for matches
        on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
        first match is found.
        """
        if self._agent_stream_iterator is not None:
            return self._agent_stream_iterator

        async def aiter():
            result_schema = self._result_schema
            allow_text_result = result_schema is None or result_schema.allow_text_result

            def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
                """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
                if isinstance(e, _messages.PartStartEvent):
                    new_part = e.part
                    if isinstance(new_part, _messages.ToolCallPart):
                        if result_schema is not None and (match := result_schema.find_tool([new_part])):
                            call, _ = match
                            return _messages.FinalResultEvent(tool_name=call.tool_name)
                    elif allow_text_result:
                        assert_type(e, _messages.PartStartEvent)
                        return _messages.FinalResultEvent(tool_name=None)

            usage_checking_stream = _get_usage_checking_stream_response(
                self._raw_stream_response, self._usage_limits, self.usage
            )
            async for event in usage_checking_stream:
                yield event
                if (final_result_event := _get_final_result_event(event)) is not None:
                    self._final_result_event = final_result_event
                    yield final_result_event
                    break

            # If we broke out of the above loop, we need to yield the rest of the events
            # If we didn't, this will just be a no-op
            async for event in usage_checking_stream:
                yield event

        self._agent_stream_iterator = aiter()
        return self._agent_stream_iterator


@dataclass
class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
"""Result of a streamed run that returns structured data via a tool call."""
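Finally, a rough sketch of consuming the new `AgentStream` (the `agent_stream` variable is assumed to come from the node-level `stream()` context manager shown in _agent_graph.py above; not part of this diff):

# Stream validated (possibly partial) outputs with the default 0.1s debounce,
# then read the accumulated usage once the stream has been fully consumed.
async for output in agent_stream.stream_output(debounce_by=0.1):
    print(output)
print(agent_stream.usage())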