Skip to content

Commit

Permalink
Use .iter() API to fully replace existing streaming implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 22, 2025
1 parent 8bf5bee commit 55ab7f0
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 2 deletions.
20 changes: 19 additions & 1 deletion pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,30 @@ async def run(

return await self._make_request(ctx)

@asynccontextmanager
async def stream(
    self,
    ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
) -> AsyncIterator[result.AgentStream[DepsT, T]]:
    """Open the model request as an [`AgentStream`][pydantic_ai.result.AgentStream] context manager.

    The yielded stream proxies the raw streamed model response. On normal exit, any
    events the caller did not consume are drained so that usage is counted for the
    complete response.
    """
    async with self._stream(ctx) as raw_response:
        run_context = build_run_context(ctx)
        wrapped = result.AgentStream[DepsT, T](
            raw_response,
            ctx.deps.result_schema,
            ctx.deps.result_validators,
            run_context,
            ctx.deps.usage_limits,
        )
        yield wrapped
        # Drain whatever the caller left unconsumed; otherwise usage would be under-counted.
        async for _ in wrapped:
            pass

@asynccontextmanager
async def _stream(
self,
ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]],
) -> AsyncIterator[models.StreamedResponse]:
# TODO: Consider changing this to return something more similar to a `StreamedRunResult`, then make it public
assert not self._did_stream, 'stream() should only be called once per node'

model_settings, model_request_parameters = await self._prepare_request(ctx)
Expand Down
15 changes: 15 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,9 +444,24 @@ class PartDeltaEvent:
"""Event type identifier, used as a discriminator."""


@dataclass
class FinalResultEvent:
    """An event indicating the response to the current model request matches the result schema.

    Emitted into the agent event stream as soon as the first matching part is found —
    either a tool-call part matching a result tool, or text content when plain-text
    results are allowed.
    """

    tool_name: str | None
    """The name of the result tool that was called. `None` if the result is from text content and not from a tool."""
    event_kind: Literal['final_result'] = 'final_result'
    """Event type identifier, used as a discriminator."""


ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""

AgentStreamEvent = Annotated[
Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
]
"""An event in the agent stream."""


@dataclass
class FunctionToolCallEvent:
Expand Down
122 changes: 121 additions & 1 deletion pydantic_ai_slim/pydantic_ai/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from typing import Generic, Union, cast

import logfire_api
from typing_extensions import TypeVar
from typing_extensions import TypeVar, assert_type

from . import _result, _utils, exceptions, messages as _messages, models
from .messages import AgentStreamEvent, FinalResultEvent
from .tools import AgentDepsT, RunContext
from .usage import Usage, UsageLimits

Expand Down Expand Up @@ -51,6 +52,125 @@
_logfire = logfire_api.Logfire(otel_scope='pydantic-ai')


@dataclass
class AgentStream(Generic[AgentDepsT, ResultDataT]):
    """A stream of events and (optionally validated) outputs for a single model request in an agent run.

    Wraps a raw `models.StreamedResponse` and exposes:

    * `__aiter__` — the low-level event stream, augmented with a
      [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] when the first part matching
      the result schema is found;
    * `stream_responses` / `stream_output` — debounced snapshots of the accumulating response,
      unvalidated and validated respectively;
    * `usage` — run-wide usage accounting (only complete once the stream is exhausted).
    """

    # The raw streamed model response being proxied.
    _raw_stream_response: models.StreamedResponse
    # Schema used to detect and validate the final result; `None` means plain-text results.
    _result_schema: _result.ResultSchema[ResultDataT] | None
    # Validators applied to each (partial or final) result before it is yielded.
    _result_validators: list[_result.ResultValidator[AgentDepsT, ResultDataT]]
    # Run context passed through to the result validators.
    _run_ctx: RunContext[AgentDepsT]
    # Optional usage limits enforced while consuming the stream.
    _usage_limits: UsageLimits | None

    # Cached iterator so repeated `aiter(self)` calls share one underlying event stream.
    _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
    # Set (once) when the first part matching the result schema is seen; `None` until then.
    _final_result_event: FinalResultEvent | None = field(default=None, init=False)
    # Usage accumulated by the run *before* this request started; snapshot taken in `__post_init__`.
    _initial_run_ctx_usage: Usage = field(init=False)

    def __post_init__(self) -> None:
        # Snapshot pre-request usage so `usage()` can add this request's usage on top.
        self._initial_run_ctx_usage = copy(self._run_ctx.usage)

    async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[ResultDataT]:
        """Asynchronously stream the (validated) agent outputs.

        Partially-validated outputs are yielded as the response accumulates, but only once a
        final result has been identified (i.e. `_final_result_event` is set by `__aiter__`).
        After the underlying stream is exhausted, one fully-validated output built from the
        complete response is yielded.

        Args:
            debounce_by: seconds to debounce stream events by before yielding a snapshot;
                `None` disables debouncing.
        """
        async for response in self.stream_responses(debounce_by=debounce_by):
            if self._final_result_event is not None:
                yield await self._validate_response(response, self._final_result_event.tool_name, allow_partial=True)
        if self._final_result_event is not None:
            yield await self._validate_response(
                self._raw_stream_response.get(), self._final_result_event.tool_name, allow_partial=False
            )

    async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]:
        """Asynchronously stream the (unvalidated) model responses for the agent.

        Each yielded value is a snapshot of the accumulating `ModelResponse`, taken after
        each debounced batch of stream events.

        Args:
            debounce_by: seconds to debounce stream events by before yielding a snapshot;
                `None` disables debouncing.
        """
        # if the message currently has any parts with content, yield before streaming
        msg = self._raw_stream_response.get()
        for part in msg.parts:
            if part.has_content():
                yield msg
                break

        # Iterating `self` drives `__aiter__` below, which both forwards events and detects
        # the final result; each debounced group of events triggers a fresh snapshot.
        async with _utils.group_by_temporal(self, debounce_by) as group_iter:
            async for _items in group_iter:
                yield self._raw_stream_response.get()  # current state of the response

    def usage(self) -> Usage:
        """Return the usage of the whole run.

        !!! note
            This won't return the full usage until the stream is finished.
        """
        # Pre-request usage snapshot plus whatever this request has accumulated so far.
        return self._initial_run_ctx_usage + self._raw_stream_response.usage()

    async def _validate_response(
        self, message: _messages.ModelResponse, result_tool_name: str | None, *, allow_partial: bool = False
    ) -> ResultDataT:
        """Validate a structured result message.

        Args:
            message: the (possibly partial) model response to validate.
            result_tool_name: name of the result tool whose call carries the data;
                `None` means the result comes from text content.
            allow_partial: when `True`, tolerate incomplete data while the stream is in flight.

        Raises:
            exceptions.UnexpectedModelBehavior: if the named result tool is absent from the message.
        """
        if self._result_schema is not None and result_tool_name is not None:
            match = self._result_schema.find_named_tool(message.parts, result_tool_name)
            if match is None:
                raise exceptions.UnexpectedModelBehavior(
                    f'Invalid response, unable to find tool: {self._result_schema.tool_names()}'
                )

            call, result_tool = match
            result_data = result_tool.validate(call, allow_partial=allow_partial, wrap_validation_errors=False)

            # Run user-supplied validators in order, each receiving the previous result.
            for validator in self._result_validators:
                result_data = await validator.validate(result_data, call, self._run_ctx)
            return result_data
        else:
            # Text result: join all text parts, then run the validators on the combined string.
            text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart))
            for validator in self._result_validators:
                text = await validator.validate(
                    text,
                    None,
                    self._run_ctx,
                )
            # Since there is no result tool, we can assume that str is compatible with ResultDataT
            return cast(ResultDataT, text)

    def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
        """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.

        This proxies the _raw_stream_response and sends all events to the agent stream, while also checking for
        matches on the result schema and emitting a
        [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the first match is found.
        """
        if self._agent_stream_iterator is not None:
            # Reuse the cached iterator so the underlying response is only consumed once,
            # even if `aiter(self)` is called multiple times (e.g. by `stream_responses`).
            return self._agent_stream_iterator

        async def aiter() -> AsyncIterator[AgentStreamEvent]:
            result_schema = self._result_schema
            allow_text_result = result_schema is None or result_schema.allow_text_result

            def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages.FinalResultEvent | None:
                """Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result."""
                if isinstance(e, _messages.PartStartEvent):
                    new_part = e.part
                    if isinstance(new_part, _messages.ToolCallPart):
                        if result_schema is not None and (match := result_schema.find_tool([new_part])):
                            call, _ = match
                            return _messages.FinalResultEvent(tool_name=call.tool_name)
                    elif allow_text_result:
                        # Any non-tool-call part start counts as a text result here.
                        assert_type(e, _messages.PartStartEvent)
                        return _messages.FinalResultEvent(tool_name=None)

            usage_checking_stream = _get_usage_checking_stream_response(
                self._raw_stream_response, self._usage_limits, self.usage
            )
            async for event in usage_checking_stream:
                yield event
                if (final_result_event := _get_final_result_event(event)) is not None:
                    # Record and emit the final-result marker, then stop checking:
                    # only the first match is reported.
                    self._final_result_event = final_result_event
                    yield final_result_event
                    break

            # If we broke out of the above loop, we need to yield the rest of the events
            # If we didn't, this will just be a no-op
            async for event in usage_checking_stream:
                yield event

        self._agent_stream_iterator = aiter()
        return self._agent_stream_iterator


@dataclass
class StreamedRunResult(Generic[AgentDepsT, ResultDataT]):
"""Result of a streamed run that returns structured data via a tool call."""
Expand Down

0 comments on commit 55ab7f0

Please sign in to comment.