diff --git a/python/packages/core/agent_framework/_workflows/_agent.py b/python/packages/core/agent_framework/_workflows/_agent.py index d4f6c1411d..e85c14a2f7 100644 --- a/python/packages/core/agent_framework/_workflows/_agent.py +++ b/python/packages/core/agent_framework/_workflows/_agent.py @@ -141,7 +141,19 @@ async def run( checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, used to load and restore the checkpoint. When provided without checkpoint_id, enables checkpointing for this run. - **kwargs: Additional keyword arguments. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These are stored in SharedState and accessible in @ai_function tools + via the **kwargs parameter. + + With custom context for ai_functions: + + .. code-block:: python + + result = await workflow_agent.run( + "analyze data", + custom_data={"endpoint": "https://api.example.com"}, + user_token={"user": "alice"}, + ) Returns: The final workflow response as an AgentRunResponse. @@ -153,7 +165,12 @@ async def run( response_id = str(uuid.uuid4()) async for update in self._run_stream_impl( - input_messages, response_id, thread, checkpoint_id, checkpoint_storage + input_messages, + response_id, + thread, + checkpoint_id, + checkpoint_storage, + run_kwargs=kwargs if kwargs else None, ): response_updates.append(update) @@ -187,7 +204,20 @@ async def run_stream( checkpoint_storage: Runtime checkpoint storage. When provided with checkpoint_id, used to load and restore the checkpoint. When provided without checkpoint_id, enables checkpointing for this run. - **kwargs: Additional keyword arguments. + **kwargs: Additional keyword arguments to pass through to agent invocations. + These are stored in SharedState and accessible in @ai_function tools + via the **kwargs parameter. + + With custom context for ai_functions: + + .. code-block:: python + + async for event in workflow_agent.run_stream( + "analyze data", + custom_data={"endpoint": "https://api.example.com"}, + user_token={"user": "alice"}, + ): + process(event) Yields: AgentRunResponseUpdate objects representing the workflow execution progress. @@ -198,7 +228,12 @@ async def run_stream( response_id = str(uuid.uuid4()) async for update in self._run_stream_impl( - input_messages, response_id, thread, checkpoint_id, checkpoint_storage + input_messages, + response_id, + thread, + checkpoint_id, + checkpoint_storage, + run_kwargs=kwargs if kwargs else None, ): response_updates.append(update) yield update @@ -216,6 +251,7 @@ async def _run_stream_impl( thread: AgentThread, checkpoint_id: str | None = None, checkpoint_storage: CheckpointStorage | None = None, + run_kwargs: dict[str, Any] | None = None, ) -> AsyncIterable[AgentRunResponseUpdate]: """Internal implementation of streaming execution. @@ -225,6 +261,7 @@ async def _run_stream_impl( thread: The conversation thread containing message history. checkpoint_id: ID of checkpoint to restore from. checkpoint_storage: Runtime checkpoint storage. + run_kwargs: Optional kwargs to store in SharedState for agent invocations Yields: AgentRunResponseUpdate objects representing the workflow execution progress. @@ -255,6 +292,7 @@ async def _run_stream_impl( message=None, checkpoint_id=checkpoint_id, checkpoint_storage=checkpoint_storage, + **(run_kwargs if run_kwargs else {}), ) else: # Execute workflow with streaming (initial run or no function responses) @@ -268,6 +306,7 @@ async def _run_stream_impl( event_stream = self.workflow.run_stream( message=conversation_messages, checkpoint_storage=checkpoint_storage, + **(run_kwargs if run_kwargs else {}), ) # Process events from the stream