Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .._threads import AgentThread
from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
from ._const import WORKFLOW_RUN_KWARGS_KEY
from ._conversation_state import encode_chat_messages
from ._events import (
AgentRunEvent,
Expand Down Expand Up @@ -309,9 +310,12 @@ async def _run_agent(self, ctx: WorkflowContext) -> AgentRunResponse | None:
Returns:
The complete AgentRunResponse, or None if waiting for user input.
"""
run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)

response = await self._agent.run(
self._cache,
thread=self._agent_thread,
**run_kwargs,
)
await ctx.add_event(AgentRunEvent(self.id, response))

Expand All @@ -333,11 +337,14 @@ async def _run_agent_streaming(self, ctx: WorkflowContext) -> AgentRunResponse |
Returns:
The complete AgentRunResponse, or None if waiting for user input.
"""
run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)

updates: list[AgentRunResponseUpdate] = []
user_input_requests: list[FunctionApprovalRequestContent] = []
async for update in self._agent.run_stream(
self._cache,
thread=self._agent_thread,
**run_kwargs,
):
updates.append(update)
await ctx.add_event(AgentRunUpdateEvent(self.id, update))
Expand Down
5 changes: 5 additions & 0 deletions python/packages/core/agent_framework/_workflows/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
# Source identifier for internal workflow messages.
INTERNAL_SOURCE_PREFIX = "internal"

# SharedState key for storing run kwargs that should be passed to agent invocations.
# Used by all orchestration patterns (Sequential, Concurrent, GroupChat, Handoff, Magentic)
# to pass kwargs from workflow.run_stream() through to agent.run_stream() and @ai_function tools.
WORKFLOW_RUN_KWARGS_KEY = "_workflow_run_kwargs"


def INTERNAL_SOURCE_ID(executor_id: str) -> str:
"""Generate an internal source ID for a given executor."""
Expand Down
60 changes: 44 additions & 16 deletions python/packages/core/agent_framework/_workflows/_magentic.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
from ._const import EXECUTOR_STATE_KEY
from ._const import EXECUTOR_STATE_KEY, WORKFLOW_RUN_KWARGS_KEY
from ._events import AgentRunUpdateEvent, WorkflowEvent
from ._executor import Executor, handler
from ._group_chat import (
Expand Down Expand Up @@ -286,19 +286,22 @@ class _MagenticStartMessage(DictConvertible):
"""Internal: A message to start a magentic workflow."""

messages: list[ChatMessage] = field(default_factory=_new_chat_message_list)
run_kwargs: dict[str, Any] = field(default_factory=dict)

def __init__(
self,
messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None,
*,
task: ChatMessage | None = None,
run_kwargs: dict[str, Any] | None = None,
) -> None:
normalized = normalize_messages_input(messages)
if task is not None:
normalized += normalize_messages_input(task)
if not normalized:
raise ValueError("MagenticStartMessage requires at least one message input.")
self.messages: list[ChatMessage] = normalized
self.run_kwargs: dict[str, Any] = run_kwargs or {}

@property
def task(self) -> ChatMessage:
Expand Down Expand Up @@ -1179,6 +1182,10 @@ async def handle_start_message(
return
logger.info("Magentic Orchestrator: Received start message")

# Store run_kwargs in SharedState so agent executors can access them
# Always store (even empty dict) so retrieval is deterministic
await context.set_shared_state(WORKFLOW_RUN_KWARGS_KEY, message.run_kwargs or {})

self._context = MagenticContext(
task=message.task,
participant_descriptions=self._participants,
Expand Down Expand Up @@ -2004,10 +2011,12 @@ async def _invoke_agent(
"""
logger.debug(f"Agent {self._agent_id}: Running with {len(self._chat_history)} messages")

run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY)

updates: list[AgentRunResponseUpdate] = []
# The wrapped participant is guaranteed to be an BaseAgent when this is called.
agent = cast("AgentProtocol", self._agent)
async for update in agent.run_stream(messages=self._chat_history): # type: ignore[attr-defined]
async for update in agent.run_stream(messages=self._chat_history, **run_kwargs): # type: ignore[attr-defined]
updates.append(update)
await self._emit_agent_delta_event(ctx, update)

Expand Down Expand Up @@ -2604,52 +2613,68 @@ def workflow(self) -> Workflow:
"""Access the underlying workflow."""
return self._workflow

async def run_streaming_with_string(self, task_text: str) -> AsyncIterable[WorkflowEvent]:
async def run_streaming_with_string(self, task_text: str, **kwargs: Any) -> AsyncIterable[WorkflowEvent]:
"""Run the workflow with a task string.

Args:
task_text: The task description as a string.
**kwargs: Additional keyword arguments to pass through to agent invocations.
These kwargs will be available in @ai_function tools via **kwargs.

Yields:
WorkflowEvent: The events generated during the workflow execution.
"""
start_message = _MagenticStartMessage.from_string(task_text)
start_message.run_kwargs = kwargs
async for event in self._workflow.run_stream(start_message):
yield event

async def run_streaming_with_message(self, task_message: ChatMessage) -> AsyncIterable[WorkflowEvent]:
async def run_streaming_with_message(
self, task_message: ChatMessage, **kwargs: Any
) -> AsyncIterable[WorkflowEvent]:
"""Run the workflow with a ChatMessage.

Args:
task_message: The task as a ChatMessage.
**kwargs: Additional keyword arguments to pass through to agent invocations.
These kwargs will be available in @ai_function tools via **kwargs.

Yields:
WorkflowEvent: The events generated during the workflow execution.
"""
start_message = _MagenticStartMessage(task_message)
start_message = _MagenticStartMessage(task_message, run_kwargs=kwargs)
async for event in self._workflow.run_stream(start_message):
yield event

async def run_stream(self, message: Any | None = None) -> AsyncIterable[WorkflowEvent]:
async def run_stream(self, message: Any | None = None, **kwargs: Any) -> AsyncIterable[WorkflowEvent]:
"""Run the workflow with either a message object or the preset task string.

Args:
message: The message to send. If None and task_text was provided during construction,
uses the preset task string.
**kwargs: Additional keyword arguments to pass through to agent invocations.
These kwargs will be available in @ai_function tools via **kwargs.
Example: workflow.run_stream("task", user_id="123", custom_data={...})

Yields:
WorkflowEvent: The events generated during the workflow execution.
"""
if message is None:
if self._task_text is None:
raise ValueError("No message provided and no preset task text available")
message = _MagenticStartMessage.from_string(self._task_text)
start_message = _MagenticStartMessage.from_string(self._task_text)
elif isinstance(message, str):
message = _MagenticStartMessage.from_string(message)
start_message = _MagenticStartMessage.from_string(message)
elif isinstance(message, (ChatMessage, list)):
message = _MagenticStartMessage(message) # type: ignore[arg-type]
start_message = _MagenticStartMessage(message) # type: ignore[arg-type]
else:
start_message = message

async for event in self._workflow.run_stream(message):
# Attach kwargs to the start message
if isinstance(start_message, _MagenticStartMessage):
start_message.run_kwargs = kwargs

async for event in self._workflow.run_stream(start_message):
yield event

async def _validate_checkpoint_participants(
Expand Down Expand Up @@ -2730,46 +2755,49 @@ async def _validate_checkpoint_participants(
f"Missing names: {missing}; unexpected names: {unexpected}."
)

async def run_with_string(self, task_text: str) -> WorkflowRunResult:
async def run_with_string(self, task_text: str, **kwargs: Any) -> WorkflowRunResult:
"""Run the workflow with a task string and return all events.

Args:
task_text: The task description as a string.
**kwargs: Additional keyword arguments to pass through to agent invocations.

Returns:
WorkflowRunResult: All events generated during the workflow execution.
"""
events: list[WorkflowEvent] = []
async for event in self.run_streaming_with_string(task_text):
async for event in self.run_streaming_with_string(task_text, **kwargs):
events.append(event)
return WorkflowRunResult(events)

async def run_with_message(self, task_message: ChatMessage) -> WorkflowRunResult:
async def run_with_message(self, task_message: ChatMessage, **kwargs: Any) -> WorkflowRunResult:
"""Run the workflow with a ChatMessage and return all events.

Args:
task_message: The task as a ChatMessage.
**kwargs: Additional keyword arguments to pass through to agent invocations.

Returns:
WorkflowRunResult: All events generated during the workflow execution.
"""
events: list[WorkflowEvent] = []
async for event in self.run_streaming_with_message(task_message):
async for event in self.run_streaming_with_message(task_message, **kwargs):
events.append(event)
return WorkflowRunResult(events)

async def run(self, message: Any | None = None) -> WorkflowRunResult:
async def run(self, message: Any | None = None, **kwargs: Any) -> WorkflowRunResult:
"""Run the workflow and return all events.

Args:
message: The message to send. If None and task_text was provided during construction,
uses the preset task string.
**kwargs: Additional keyword arguments to pass through to agent invocations.

Returns:
WorkflowRunResult: All events generated during the workflow execution.
"""
events: list[WorkflowEvent] = []
async for event in self.run_stream(message):
async for event in self.run_stream(message, **kwargs):
events.append(event)
return WorkflowRunResult(events)

Expand Down
39 changes: 38 additions & 1 deletion python/packages/core/agent_framework/_workflows/_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent import WorkflowAgent
from ._checkpoint import CheckpointStorage
from ._const import DEFAULT_MAX_ITERATIONS
from ._const import DEFAULT_MAX_ITERATIONS, WORKFLOW_RUN_KWARGS_KEY
from ._edge import (
EdgeGroup,
FanOutEdgeGroup,
Expand Down Expand Up @@ -291,6 +291,7 @@ async def _run_workflow_with_tracing(
initial_executor_fn: Callable[[], Awaitable[None]] | None = None,
reset_context: bool = True,
streaming: bool = False,
run_kwargs: dict[str, Any] | None = None,
) -> AsyncIterable[WorkflowEvent]:
"""Private method to run workflow with proper tracing.

Expand All @@ -301,6 +302,7 @@ async def _run_workflow_with_tracing(
initial_executor_fn: Optional function to execute initial executor
reset_context: Whether to reset the context for a new run
streaming: Whether to enable streaming mode for agents
run_kwargs: Optional kwargs to store in SharedState for agent invocations

Yields:
WorkflowEvent: The events generated during the workflow execution.
Expand Down Expand Up @@ -335,6 +337,10 @@ async def _run_workflow_with_tracing(
self._runner.context.reset_for_new_run()
await self._shared_state.clear()

# Store run kwargs in SharedState so executors can access them
# Always store (even empty dict) so retrieval is deterministic
await self._shared_state.set(WORKFLOW_RUN_KWARGS_KEY, run_kwargs or {})

# Set streaming mode after reset
self._runner_context.set_streaming(streaming)

Expand Down Expand Up @@ -442,6 +448,7 @@ async def run_stream(
*,
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
**kwargs: Any,
) -> AsyncIterable[WorkflowEvent]:
"""Run the workflow and stream events.

Expand All @@ -457,6 +464,9 @@ async def run_stream(
- With checkpoint_id: Used to load and restore the specified checkpoint
- Without checkpoint_id: Enables checkpointing for this run, overriding
build-time configuration
**kwargs: Additional keyword arguments to pass through to agent invocations.
These are stored in SharedState and accessible in @ai_function tools
via the **kwargs parameter.

Yields:
WorkflowEvent: Events generated during workflow execution.
Expand All @@ -475,6 +485,17 @@ async def run_stream(
async for event in workflow.run_stream("start message"):
process(event)

With custom context for ai_functions:

.. code-block:: python

async for event in workflow.run_stream(
"analyze data",
custom_data={"endpoint": "https://api.example.com"},
user_token={"user": "alice"},
):
process(event)

Enable checkpointing at runtime:

.. code-block:: python
Expand Down Expand Up @@ -524,6 +545,7 @@ async def run_stream(
),
reset_context=reset_context,
streaming=True,
run_kwargs=kwargs if kwargs else None,
):
yield event
finally:
Expand Down Expand Up @@ -559,6 +581,7 @@ async def run(
checkpoint_id: str | None = None,
checkpoint_storage: CheckpointStorage | None = None,
include_status_events: bool = False,
**kwargs: Any,
) -> WorkflowRunResult:
"""Run the workflow to completion and return all events.

Expand All @@ -575,6 +598,9 @@ async def run(
- Without checkpoint_id: Enables checkpointing for this run, overriding
build-time configuration
include_status_events: Whether to include WorkflowStatusEvent instances in the result list.
**kwargs: Additional keyword arguments to pass through to agent invocations.
These are stored in SharedState and accessible in @ai_function tools
via the **kwargs parameter.

Returns:
A WorkflowRunResult instance containing events generated during workflow execution.
Expand All @@ -593,6 +619,16 @@ async def run(
result = await workflow.run("start message")
outputs = result.get_outputs()

With custom context for ai_functions:

.. code-block:: python

result = await workflow.run(
"analyze data",
custom_data={"endpoint": "https://api.example.com"},
user_token={"user": "alice"},
)

Enable checkpointing at runtime:

.. code-block:: python
Expand Down Expand Up @@ -637,6 +673,7 @@ async def run(
self._execute_with_message_or_checkpoint, message, checkpoint_id, checkpoint_storage
),
reset_context=reset_context,
run_kwargs=kwargs if kwargs else None,
)
]
finally:
Expand Down
Loading
Loading