diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 4fbffb27b2..874ffdd0be 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -6,6 +6,13 @@ AgentExecutorRequest, AgentExecutorResponse, ) +from ._agent_utils import resolve_agent_id +from ._base_group_chat_orchestrator import ( + BaseGroupChatOrchestrator, + GroupChatRequestMessage, + GroupChatRequestSentEvent, + GroupChatResponseReceivedEvent, +) from ._checkpoint import ( CheckpointStorage, FileCheckpointStorage, @@ -56,37 +63,30 @@ ) from ._function_executor import FunctionExecutor, executor from ._group_chat import ( - DEFAULT_MANAGER_INSTRUCTIONS, - DEFAULT_MANAGER_STRUCTURED_OUTPUT_PROMPT, + AgentBasedGroupChatOrchestrator, GroupChatBuilder, - GroupChatDirective, - GroupChatStateSnapshot, - ManagerDirectiveModel, - ManagerSelectionRequest, - ManagerSelectionResponse, + GroupChatState, ) -from ._handoff import HandoffBuilder, HandoffUserInputRequest +from ._handoff import HandoffAgentUserRequest, HandoffBuilder, HandoffSentEvent from ._magentic import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, ORCH_MSG_KIND_INSTRUCTION, ORCH_MSG_KIND_NOTICE, ORCH_MSG_KIND_TASK_LEDGER, ORCH_MSG_KIND_USER_TASK, MagenticBuilder, MagenticContext, - MagenticHumanInputRequest, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, MagenticManagerBase, - MagenticStallInterventionDecision, - MagenticStallInterventionReply, - MagenticStallInterventionRequest, + MagenticOrchestrator, + MagenticOrchestratorEvent, + MagenticOrchestratorEventType, + MagenticPlanReviewRequest, + MagenticPlanReviewResponse, + MagenticProgressLedger, + MagenticProgressLedgerItem, + MagenticResetSignal, StandardMagenticManager, ) -from ._orchestration_request_info import AgentInputRequest, AgentResponseReviewRequest, RequestInfoInterceptor +from ._orchestration_request_info import AgentRequestInfoResponse from ._orchestration_state import OrchestrationState from ._request_info_mixin import response_handler from ._runner import Runner @@ -112,22 +112,19 @@ from ._workflow_executor import SubWorkflowRequestMessage, SubWorkflowResponseMessage, WorkflowExecutor __all__ = [ - "DEFAULT_MANAGER_INSTRUCTIONS", - "DEFAULT_MANAGER_STRUCTURED_OUTPUT_PROMPT", "DEFAULT_MAX_ITERATIONS", - "MAGENTIC_EVENT_TYPE_AGENT_DELTA", - "MAGENTIC_EVENT_TYPE_ORCHESTRATOR", "ORCH_MSG_KIND_INSTRUCTION", "ORCH_MSG_KIND_NOTICE", "ORCH_MSG_KIND_TASK_LEDGER", "ORCH_MSG_KIND_USER_TASK", + "AgentBasedGroupChatOrchestrator", "AgentExecutor", "AgentExecutorRequest", "AgentExecutorResponse", - "AgentInputRequest", - "AgentResponseReviewRequest", + "AgentRequestInfoResponse", "AgentRunEvent", "AgentRunUpdateEvent", + "BaseGroupChatOrchestrator", "Case", "CheckpointStorage", "ConcurrentBuilder", @@ -146,30 +143,29 @@ "FunctionExecutor", "GraphConnectivityError", "GroupChatBuilder", - "GroupChatDirective", - "GroupChatStateSnapshot", + "GroupChatRequestMessage", + "GroupChatRequestSentEvent", + "GroupChatResponseReceivedEvent", + "GroupChatState", + "HandoffAgentUserRequest", "HandoffBuilder", - "HandoffUserInputRequest", + "HandoffSentEvent", "InMemoryCheckpointStorage", "InProcRunnerContext", "MagenticBuilder", "MagenticContext", - "MagenticHumanInputRequest", - "MagenticHumanInterventionDecision", - "MagenticHumanInterventionKind", - 
"MagenticHumanInterventionReply", - "MagenticHumanInterventionRequest", "MagenticManagerBase", - "MagenticStallInterventionDecision", - "MagenticStallInterventionReply", - "MagenticStallInterventionRequest", - "ManagerDirectiveModel", - "ManagerSelectionRequest", - "ManagerSelectionResponse", + "MagenticOrchestrator", + "MagenticOrchestratorEvent", + "MagenticOrchestratorEventType", + "MagenticPlanReviewRequest", + "MagenticPlanReviewResponse", + "MagenticProgressLedger", + "MagenticProgressLedgerItem", + "MagenticResetSignal", "Message", "OrchestrationState", "RequestInfoEvent", - "RequestInfoInterceptor", "Runner", "RunnerContext", "SequentialBuilder", @@ -208,6 +204,7 @@ "executor", "get_checkpoint_summary", "handler", + "resolve_agent_id", "response_handler", "validate_workflow_graph", ] diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index e469929bcc..52e3f76536 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -10,6 +10,7 @@ from .._agents import AgentProtocol, ChatAgent from .._threads import AgentThread from .._types import AgentRunResponse, AgentRunResponseUpdate, ChatMessage +from ._agent_utils import resolve_agent_id from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._const import WORKFLOW_RUN_KWARGS_KEY from ._conversation_state import encode_chat_messages @@ -88,16 +89,20 @@ def __init__( id: A unique identifier for the executor. If None, the agent's name will be used if available. """ # Prefer provided id; else use agent.name if present; else generate deterministic prefix - exec_id = id or agent.name + exec_id = id or resolve_agent_id(agent) if not exec_id: - raise ValueError("Agent must have a name or an explicit id must be provided.") + raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.") super().__init__(exec_id) self._agent = agent self._agent_thread = agent_thread or self._agent.get_new_thread() self._pending_agent_requests: dict[str, FunctionApprovalRequestContent] = {} self._pending_responses_to_agent: list[FunctionApprovalResponseContent] = [] self._output_response = output_response + + # AgentExecutor maintains an internal cache of messages in between runs self._cache: list[ChatMessage] = [] + # This tracks the full conversation after each run + self._full_conversation: list[ChatMessage] = [] @property def output_response(self) -> bool: @@ -227,6 +232,7 @@ async def on_checkpoint_save(self) -> dict[str, Any]: return { "cache": encode_chat_messages(self._cache), + "full_conversation": encode_chat_messages(self._full_conversation), "agent_thread": serialized_thread, "pending_agent_requests": encode_checkpoint_value(self._pending_agent_requests), "pending_responses_to_agent": encode_checkpoint_value(self._pending_responses_to_agent), @@ -251,6 +257,16 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: else: self._cache = [] + full_conversation_payload = state.get("full_conversation") + if full_conversation_payload: + try: + self._full_conversation = decode_chat_messages(full_conversation_payload) + except Exception as exc: + logger.warning("Failed to restore full conversation: %s", exc) + self._full_conversation = [] + else: + self._full_conversation = [] + thread_payload = state.get("agent_thread") if thread_payload: try: @@ -289,6 +305,12 @@ async def 
_run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, # Non-streaming mode: use run() and emit single event response = await self._run_agent(cast(WorkflowContext, ctx)) + # Always extend full conversation with cached messages plus agent outputs + # (agent_run_response.messages) after each run. This is to avoid losing context + # when agent did not complete and the cache is cleared when responses come back. + # Do not mutate response.messages so AgentRunEvent remains faithful to the raw output. + self._full_conversation.extend(list(self._cache) + (list(response.messages) if response else [])) + if response is None: # Agent did not complete (e.g., waiting for user input); do not emit response logger.info("AgentExecutor %s: Agent did not complete, awaiting user input", self.id) @@ -297,12 +319,7 @@ async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, if self._output_response: await ctx.yield_output(response) - # Always construct a full conversation snapshot from inputs (cache) - # plus agent outputs (agent_run_response.messages). Do not mutate - # response.messages so AgentRunEvent remains faithful to the raw output. - full_conversation: list[ChatMessage] = list(self._cache) + list(response.messages) - - agent_response = AgentExecutorResponse(self.id, response, full_conversation=full_conversation) + agent_response = AgentExecutorResponse(self.id, response, full_conversation=self._full_conversation) await ctx.send_message(agent_response) self._cache.clear() diff --git a/python/packages/core/agent_framework/_workflows/_agent_utils.py b/python/packages/core/agent_framework/_workflows/_agent_utils.py new file mode 100644 index 0000000000..f296f53ab9 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_agent_utils.py @@ -0,0 +1,17 @@ +# Copyright (c) Microsoft. All rights reserved. + +from .._agents import AgentProtocol + + +def resolve_agent_id(agent: AgentProtocol) -> str: + """Resolve the unique identifier for an agent. + + Prefers the `.name` attribute if set; otherwise falls back to `.id`. + + Args: + agent: The agent whose identifier is to be resolved. + + Returns: + The resolved unique identifier for the agent. 
+ """ + return agent.name if agent.name else agent.id diff --git a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py index 5576246a8e..6e03d9e5c3 100644 --- a/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py +++ b/python/packages/core/agent_framework/_workflows/_base_group_chat_orchestrator.py @@ -2,16 +2,23 @@ """Base class for group chat orchestrators that manages conversation flow and participant selection.""" +import asyncio import inspect import logging import sys -from abc import ABC, abstractmethod +from abc import ABC +from collections import OrderedDict from collections.abc import Awaitable, Callable, Sequence -from typing import Any +from dataclasses import dataclass +from typing import Any, ClassVar, TypeAlias -from .._types import ChatMessage -from ._executor import Executor -from ._orchestrator_helpers import ParticipantRegistry +from typing_extensions import Never + +from .._types import ChatMessage, Role +from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from ._events import WorkflowEvent +from ._executor import Executor, handler +from ._orchestration_request_info import AgentApprovalExecutor from ._workflow_context import WorkflowContext if sys.version_info >= (3, 12): @@ -23,6 +30,129 @@ logger = logging.getLogger(__name__) +@dataclass +class GroupChatRequestMessage: + """Request envelope sent from the orchestrator to a participant.""" + + additional_instruction: str | None = None + metadata: dict[str, Any] | None = None + + +@dataclass +class GroupChatParticipantMessage: + """Message envelop containing messages generated by a participant. + + This message envelope is used to broadcast messages from one participant + to other participants in the group chat to keep them synchronized. + """ + + messages: list[ChatMessage] + + +@dataclass +class GroupChatResponseMessage: + """Response envelope emitted by participants back to the orchestrator.""" + + message: ChatMessage + + +TerminationCondition: TypeAlias = Callable[[list[ChatMessage]], bool | Awaitable[bool]] +GroupChatWorkflowContext_T_Out: TypeAlias = AgentExecutorRequest | GroupChatRequestMessage | GroupChatParticipantMessage + + +# region Group chat events +class GroupChatEvent(WorkflowEvent): + """Base class for group chat workflow events.""" + + def __init__(self, round_index: int, data: Any | None = None) -> None: + """Initialize group chat event. + + Args: + round_index: Current round index + data: Optional event-specific data + """ + super().__init__(data) + self.round_index = round_index + + +class GroupChatResponseReceivedEvent(GroupChatEvent): + """Event emitted when a participant response is received.""" + + def __init__(self, round_index: int, participant_name: str, data: Any | None = None) -> None: + """Initialize response received event. + + Args: + round_index: Current round index + participant_name: Name of the participant who sent the response + data: Optional event-specific data + """ + super().__init__(round_index, data) + self.participant_name = participant_name + + +class GroupChatRequestSentEvent(GroupChatEvent): + """Event emitted when a request is sent to a participant.""" + + def __init__(self, round_index: int, participant_name: str, data: Any | None = None) -> None: + """Initialize request sent event. 
+
+ Args:
+ round_index: Current round index
+ participant_name: Name of the participant to whom the request was sent
+ data: Optional event-specific data
+ """
+ super().__init__(round_index, data)
+ self.participant_name = participant_name
+
+
+# endregion
+
+
+# region Participant registry
+class ParticipantRegistry:
+ """Simple registry for tracking group chat participants, their types, and related properties."""
+
+ EMPTY_DESCRIPTION_PLACEHOLDER: ClassVar[str] = (
+ ""
+ )
+
+ def __init__(self, participants: Sequence[Executor]) -> None:
+ """Initialize the registry and validate participant IDs.
+
+ Args:
+ participants: List of executors (agents or custom executors) to register
+
+ Raises:
+ ValueError: If there are duplicate participant IDs
+ """
+ self._agents: set[str] = set()
+ self._participants: OrderedDict[str, str] = OrderedDict()
+ self._resolve_participants(participants)
+
+ def _resolve_participants(self, participants: Sequence[Executor]) -> None:
+ """Register participants and validate IDs."""
+ for participant in participants:
+ if participant.id in self._participants:
+ raise ValueError(f"Duplicate participant ID: '{participant.id}' is already registered.")
+
+ if isinstance(participant, AgentExecutor | AgentApprovalExecutor):
+ self._agents.add(participant.id)
+ self._participants[participant.id] = participant.description or self.EMPTY_DESCRIPTION_PLACEHOLDER
+ else:
+ self._participants[participant.id] = self.EMPTY_DESCRIPTION_PLACEHOLDER
+
+ def is_agent(self, name: str) -> bool:
+ """Check if a participant is an agent (vs custom executor)."""
+ return name in self._agents
+
+ @property
+ def participants(self) -> OrderedDict[str, str]:
+ """Get all registered participant names and descriptions in an ordered dictionary."""
+ return self._participants
+
+
+# endregion
+
+
 class BaseGroupChatOrchestrator(Executor, ABC):
 """Abstract base class for group chat orchestrators.
@@ -33,36 +163,159 @@
 inheriting the common participant management infrastructure.
 """
 
- def __init__(self, executor_id: str) -> None:
+ TERMINATION_CONDITION_MET_MESSAGE: ClassVar[str] = "The group chat has reached its termination condition."
+ MAX_ROUNDS_MET_MESSAGE: ClassVar[str] = "The group chat has reached the maximum number of rounds."
+
+ def __init__(
+ self,
+ id: str,
+ participant_registry: ParticipantRegistry,
+ *,
+ name: str | None = None,
+ max_rounds: int | None = None,
+ termination_condition: TerminationCondition | None = None,
+ ) -> None:
 """Initialize base orchestrator.
 
 Args:
- executor_id: Unique identifier for this orchestrator executor
+ id: Unique identifier for this orchestrator executor
+ participant_registry: Registry of group chat participants that tracks their types (agents
+ vs custom executors)
+ name: Optional display name for orchestrator messages
+ max_rounds: Optional maximum number of conversation rounds.
+ Values smaller than 1 are coerced to 1.
+ termination_condition: Optional callable to determine conversation termination
 """
- super().__init__(executor_id)
- self._registry = ParticipantRegistry()
- # Shared conversation state management
- self._conversation: list[ChatMessage] = []
+ super().__init__(id)
+ self._name = name or id
+ self._max_rounds = max(1, max_rounds) if max_rounds is not None else None
+ self._termination_condition = termination_condition
 self._round_index: int = 0
- self._max_rounds: int | None = None
- self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None
+ self._participant_registry = participant_registry
+ # Shared conversation state management
+ self._full_conversation: list[ChatMessage] = []
+
+ # region Handlers
+
+ @handler
+ async def handle_str(
+ self,
+ task: str,
+ ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+ ) -> None:
+ """Handler for string input as workflow entry point.
+
+ Wraps the string in a USER role ChatMessage and delegates to _handle_messages.
+
+ Args:
+ task: Plain text task description from user
+ ctx: Workflow context
 
- def register_participant_entry(
- self, name: str, *, entry_id: str, is_agent: bool, exit_id: str | None = None
+ Usage:
+ workflow.run("Write a blog post about AI agents")
+ """
+ await self._handle_messages([ChatMessage(role=Role.USER, text=task)], ctx)
+
+ @handler
+ async def handle_message(
+ self,
+ task: ChatMessage,
+ ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
 ) -> None:
- """Record routing details for a participant's entry executor.
+ """Handler for single ChatMessage input as workflow entry point.
 
- This method provides a unified interface for registering participants
- across all orchestrator patterns, whether they are agents or custom executors.
+ Wraps the message in a list and delegates to _handle_messages.
 
 Args:
- name: Participant name (used for selection and tracking)
- entry_id: Executor ID for this participant's entry point
- is_agent: Whether this is an AgentExecutor (True) or custom Executor (False)
- exit_id: Executor ID for this participant's exit point (where responses come from).
- If None, defaults to entry_id.
+ task: ChatMessage from user
+ ctx: Workflow context
+
+ Usage:
+ workflow.run(ChatMessage(role=Role.USER, text="Write a blog post about AI agents"))
 """
- self._registry.register(name, entry_id=entry_id, is_agent=is_agent, exit_id=exit_id)
+ await self._handle_messages([task], ctx)
+
+ @handler
+ async def handle_messages(
+ self,
+ task: list[ChatMessage],
+ ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+ ) -> None:
+ """Handler for list of ChatMessages as workflow entry point.
+
+ Delegates to _handle_messages.
+
+ Args:
+ task: List of ChatMessages from user
+ ctx: Workflow context
+
+ Usage:
+ workflow.run([
+ ChatMessage(role=Role.USER, text="Write a blog post about AI agents"),
+ ChatMessage(role=Role.USER, text="Make it engaging and informative.")
+ ])
+ """
+ if not task:
+ raise ValueError("At least one ChatMessage is required to start the group chat workflow.")
+ await self._handle_messages(task, ctx)
+
+ @handler
+ async def handle_participant_response(
+ self,
+ response: AgentExecutorResponse | GroupChatResponseMessage,
+ ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
+ ) -> None:
+ """Handler for participant responses.
+
+ This method can be overridden by subclasses if specific response handling is needed.
+ + Args: + response: Response from a participant + ctx: Workflow context + """ + await ctx.add_event( + GroupChatResponseReceivedEvent( + round_index=self._round_index, + participant_name=ctx.source_executor_ids[0] if ctx.source_executor_ids else "unknown", + data=response, + ) + ) + await self._handle_response(response, ctx) + + # endregion + + # region Handler methods subclasses must implement + + async def _handle_messages( + self, + messages: list[ChatMessage], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], + ) -> None: + """Handle task messages from users as workflow entry point. + + Subclasses must implement this method to define pattern-specific orchestration logic. + + Args: + messages: Task messages from user + ctx: Workflow context + """ + raise NotImplementedError("_handle_messages must be implemented by subclasses.") + + async def _handle_response( + self, + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], + ) -> None: + """Handle a participant response. + + Subclasses must implement this method to define pattern-specific response handling logic. + + Args: + response: Response from a participant + ctx: Workflow context + """ + raise NotImplementedError("_handle_response must be implemented by subclasses.") + + # endregion # Conversation state management (shared across all patterns) @@ -72,7 +325,7 @@ def _append_messages(self, messages: Sequence[ChatMessage]) -> None: Args: messages: Messages to append """ - self._conversation.extend(messages) + self._full_conversation.extend(messages) def _get_conversation(self) -> list[ChatMessage]: """Get a copy of the current conversation. @@ -80,11 +333,27 @@ def _get_conversation(self) -> list[ChatMessage]: Returns: Cloned conversation list """ - return list(self._conversation) + return list(self._full_conversation) + + def _process_participant_response( + self, response: AgentExecutorResponse | GroupChatResponseMessage + ) -> list[ChatMessage]: + """Extract ChatMessage from participant response. + + Args: + response: Response from participant + Returns: + List of ChatMessages extracted from the response + """ + if isinstance(response, AgentExecutorResponse): + return response.agent_run_response.messages + if isinstance(response, GroupChatResponseMessage): + return [response.message] + raise TypeError(f"Unsupported response type: {type(response)}") def _clear_conversation(self) -> None: """Clear the conversation history.""" - self._conversation.clear() + self._full_conversation.clear() def _increment_round(self) -> None: """Increment the round counter.""" @@ -102,97 +371,121 @@ async def _check_termination(self) -> bool: return False result = self._termination_condition(self._get_conversation()) - if inspect.iscoroutine(result) or inspect.isawaitable(result): + if inspect.isawaitable(result): result = await result - return bool(result) + return result - @abstractmethod - def _get_author_name(self) -> str: - """Get the author name for orchestrator-generated messages. + async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool: + """Check termination conditions and yield completion if met. - Subclasses must implement this to provide a stable author name - for completion messages and other orchestrator-generated content. 
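Taken together, `_handle_messages` and `_handle_response` are the entire subclass contract; everything else on this base class is shared plumbing. A minimal sketch of a custom orchestrator built on this base (the class name and the single-participant flow are illustrative only; the helper methods are the ones defined in this file):

    class SingleAgentOrchestrator(BaseGroupChatOrchestrator):
        """Illustrative orchestrator: forwards the task to one participant, then finishes."""

        async def _handle_messages(self, messages, ctx):
            self._append_messages(messages)
            # Ask the first registered participant to respond.
            first = next(iter(self._participant_registry.participants))
            await self._send_request_to_participant(first, ctx)

        async def _handle_response(self, response, ctx):
            # Fold the participant's reply into the shared history, then finish.
            self._append_messages(self._process_participant_response(response))
            await ctx.yield_output(self._get_conversation())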
+ Args: + ctx: Workflow context for yielding output Returns: - Author name to use for messages generated by this orchestrator + True if termination condition met and output yielded, False otherwise """ - ... + terminate = await self._check_termination() + if terminate: + self._append_messages([self._create_completion_message(self.TERMINATION_CONDITION_MET_MESSAGE)]) + await ctx.yield_output(self._full_conversation) + return True - def _create_completion_message( - self, - text: str | None = None, - reason: str = "completed", - ) -> ChatMessage: + return False + + def _create_completion_message(self, message: str) -> ChatMessage: """Create a standardized completion message. Args: - text: Optional message text (auto-generated if None) - reason: Completion reason for default text + message: Completion text Returns: ChatMessage with completion content """ - from .._types import Role + return ChatMessage(role=Role.ASSISTANT, text=message, author_name=self._name) + + # Participant routing (shared across all patterns) + + async def _broadcast_messages_to_participants( + self, + messages: list[ChatMessage], + ctx: WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], + participants: Sequence[str] | None = None, + ) -> None: + """Broadcast messages to participants. - message_text = text or f"Conversation {reason}." - return ChatMessage( - role=Role.ASSISTANT, - text=message_text, - author_name=self._get_author_name(), + This method sends the given messages to all registered participants + or a specified subset. This acts as a message broadcast mechanism for + participants in the group chat to stay synchronized. + + Args: + messages: Messages to send + ctx: Workflow context for message broadcasting + participants: Optional list of participant names to route to. + If None, routes to all registered participants. + """ + target_participants = ( + participants if participants is not None else list(self._participant_registry.participants) ) - # Participant routing (shared across all patterns) + async def _send_messages(target: str) -> None: + if self._participant_registry.is_agent(target): + # Send messages without requesting a response + await ctx.send_message(AgentExecutorRequest(messages=messages, should_respond=False), target_id=target) + else: + # Send messages wrapped in GroupChatParticipantMessage + await ctx.send_message(GroupChatParticipantMessage(messages=messages), target_id=target) - async def _route_to_participant( + await asyncio.gather(*[_send_messages(p) for p in target_participants]) + + async def _send_request_to_participant( self, - participant_name: str, - conversation: list[ChatMessage], - ctx: WorkflowContext[Any, Any], + target: str, + ctx: WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], *, - instruction: str | None = None, - task: ChatMessage | None = None, + additional_instruction: str | None = None, metadata: dict[str, Any] | None = None, ) -> None: - """Route a conversation to a participant. + """Send a request to a participant. 
This method handles the dual envelope pattern: - AgentExecutors receive AgentExecutorRequest (messages only) - Custom executors receive GroupChatRequestMessage (full context) Args: - participant_name: Name of the participant to route to - conversation: Conversation history to send + target: Name of the participant to route to ctx: Workflow context for message routing - instruction: Optional instruction from manager/orchestrator - task: Optional task context + additional_instruction: Optional additional instruction for the participant. + This can be used to provide guidance to steer the participant's response. metadata: Optional metadata dict Raises: ValueError: If participant is not registered """ - from ._agent_executor import AgentExecutorRequest - from ._orchestrator_helpers import prepare_participant_request - - entry_id = self._registry.get_entry_id(participant_name) - if entry_id is None: - raise ValueError(f"No registered entry executor for participant '{participant_name}'.") - - if self._registry.is_agent(participant_name): + if self._participant_registry.is_agent(target): # AgentExecutors receive simple message list - await ctx.send_message( - AgentExecutorRequest(messages=conversation, should_respond=True), - target_id=entry_id, + messages: list[ChatMessage] = [] + if additional_instruction: + messages.append(ChatMessage(role=Role.USER, text=additional_instruction)) + request = AgentExecutorRequest(messages=messages, should_respond=True) + await ctx.send_message(request, target_id=target) + await ctx.add_event( + GroupChatRequestSentEvent( + round_index=self._round_index, + participant_name=target, + data=request, + ) ) else: # Custom executors receive full context envelope - request = prepare_participant_request( - participant_name=participant_name, - conversation=conversation, - instruction=instruction or "", - task=task, - metadata=metadata, + request = GroupChatRequestMessage(additional_instruction=additional_instruction, metadata=metadata) # type: ignore[assignment] + await ctx.send_message(request, target_id=target) + await ctx.add_event( + GroupChatRequestSentEvent( + round_index=self._round_index, + participant_name=target, + data=request, + ) ) - await ctx.send_message(request, target_id=entry_id) # Round limit enforcement (shared across all patterns) @@ -217,6 +510,23 @@ def _check_round_limit(self) -> bool: return False + async def _check_round_limit_and_yield(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool: + """Check round limit and yield completion if reached. 
+
+ Args:
+ ctx: Workflow context for yielding output
+
+ Returns:
+ True if round limit reached and output yielded, False otherwise
+ """
+ reach_max_rounds = self._check_round_limit()
+ if reach_max_rounds:
+ self._append_messages([self._create_completion_message(self.MAX_ROUNDS_MET_MESSAGE)])
+ await ctx.yield_output(self._full_conversation)
+ return True
+
+ return False
+
 # State persistence (shared across all patterns)
 
@@ -234,8 +544,9 @@ async def on_checkpoint_save(self) -> dict[str, Any]:
 from ._orchestration_state import OrchestrationState
 
 state = OrchestrationState(
- conversation=list(self._conversation),
+ conversation=list(self._full_conversation),
 round_index=self._round_index,
+ orchestrator_name=self._name,
 metadata=self._snapshot_pattern_metadata(),
 )
 return state.to_dict()
@@ -263,8 +574,9 @@ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
 from ._orchestration_state import OrchestrationState
 
 orch_state = OrchestrationState.from_dict(state)
- self._conversation = list(orch_state.conversation)
+ self._full_conversation = list(orch_state.conversation)
 self._round_index = orch_state.round_index
+ self._name = orch_state.orchestrator_name
 self._restore_pattern_metadata(orch_state.metadata)
 
 def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
diff --git a/python/packages/core/agent_framework/_workflows/_concurrent.py b/python/packages/core/agent_framework/_workflows/_concurrent.py
index 2900254126..dbbccccee2 100644
--- a/python/packages/core/agent_framework/_workflows/_concurrent.py
+++ b/python/packages/core/agent_framework/_workflows/_concurrent.py
@@ -10,11 +10,12 @@
 from agent_framework import AgentProtocol, ChatMessage, Role
 
-from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse
+from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
+from ._agent_utils import resolve_agent_id
 from ._checkpoint import CheckpointStorage
 from ._executor import Executor, handler
 from ._message_utils import normalize_messages_input
-from ._orchestration_request_info import RequestInfoInterceptor
+from ._orchestration_request_info import AgentApprovalExecutor
 from ._workflow import Workflow
 from ._workflow_builder import WorkflowBuilder
 from ._workflow_context import WorkflowContext
@@ -247,6 +248,7 @@ def __init__(self) -> None:
 self._aggregator_factory: Callable[[], Executor] | None = None
 self._checkpoint_storage: CheckpointStorage | None = None
 self._request_info_enabled: bool = False
+ self._request_info_filter: set[str] | None = None
 
 def register_participants(
 self,
@@ -461,25 +463,68 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Concurre
 self._checkpoint_storage = checkpoint_storage
 return self
 
- def with_request_info(self) -> "ConcurrentBuilder":
- """Enable request info before aggregation in the workflow.
+ def with_request_info(
+ self,
+ *,
+ agents: Sequence[str | AgentProtocol] | None = None,
+ ) -> "ConcurrentBuilder":
+ """Enable request info after agent participant responses.
+
+ This enables human-in-the-loop (HIL) scenarios for the concurrent orchestration.
+ When enabled, the workflow pauses after each agent participant runs, emitting
+ a RequestInfoEvent that allows the caller to review the conversation and optionally
+ inject guidance for the agent participant to iterate. The caller provides input via
+ the standard response_handler/request_info pattern.
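As a usage sketch (not part of the docstring): assuming `writer` and `critic` are AgentProtocol instances, and using the `register_participants` method shown earlier in this file, whose exact signature lies outside this hunk:

    builder = ConcurrentBuilder()
    builder.register_participants([writer, critic])  # registration call shape assumed
    # Pause only the writer's branch for human review; the critic runs straight through.
    workflow = builder.with_request_info(agents=[writer]).build()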
+
+ Simulated flow with HIL:
+ Input -> fan-out -> [Agent Participant <-> Request Info] -> fan-in -> Aggregator
+ (each agent participant branch runs in parallel)
 
- When enabled, the workflow pauses after all parallel agents complete,
- emitting a RequestInfoEvent that allows the caller to review and optionally
- modify the combined results before aggregation. The caller provides feedback
- via the standard response_handler/request_info pattern.
+ Note: This is only available for agent participants. Executor participants can incorporate
+ request info handling in their own implementation if desired.
 
- Note:
- Unlike SequentialBuilder and GroupChatBuilder, ConcurrentBuilder does not
- support per-agent filtering since all agents run in parallel and results
- are collected together. The pause occurs once with all agent outputs received.
+ Args:
+ agents: Optional list of agent names or agent instances to enable request info for.
+ If None, enables HIL for all agent participants.
 
 Returns:
- self: The builder instance for fluent chaining.
+ Self for fluent chaining
 """
+ from ._orchestration_request_info import resolve_request_info_filter
+
 self._request_info_enabled = True
+ self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None)
+
 return self
 
+ def _resolve_participants(self) -> list[Executor]:
+ """Resolve participant instances into Executor objects."""
+ participants: list[Executor | AgentProtocol] = []
+ if self._participant_factories:
+ # Resolve the participant factories now. This doesn't break the factory pattern
+ # since the Concurrent builder still creates new instances per workflow build.
+ for factory in self._participant_factories:
+ p = factory()
+ participants.append(p)
+ else:
+ participants = self._participants
+
+ executors: list[Executor] = []
+ for p in participants:
+ if isinstance(p, Executor):
+ executors.append(p)
+ elif isinstance(p, AgentProtocol):
+ if self._request_info_enabled and (
+ not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter
+ ):
+ # Handle request info enabled agents
+ executors.append(AgentApprovalExecutor(p))
+ else:
+ executors.append(AgentExecutor(p))
+ else:
+ raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.")
+
+ return executors
+
 def build(self) -> Workflow:
 r"""Build and validate the concurrent workflow.
@@ -521,29 +566,15 @@
 )
 )
 
- participants: list[Executor | AgentProtocol] = []
- if self._participant_factories:
- # Resolve the participant factories now. This doesn't break the factory pattern
- # since the Concurrent builder still creates new instances per workflow build.
- for factory in self._participant_factories: - p = factory() - participants.append(p) - else: - participants = self._participants + # Resolve participants and participant factories to executors + participants: list[Executor] = self._resolve_participants() builder = WorkflowBuilder() builder.set_start_executor(dispatcher) + # Fan-out for parallel execution builder.add_fan_out_edges(dispatcher, participants) - - if self._request_info_enabled: - # Insert interceptor between fan-in and aggregator - # participants -> fan-in -> interceptor -> aggregator - request_info_interceptor = RequestInfoInterceptor(executor_id="request_info") - builder.add_fan_in_edges(participants, request_info_interceptor) - builder.add_edge(request_info_interceptor, aggregator) - else: - # Direct fan-in to aggregator - builder.add_fan_in_edges(participants, aggregator) + # Direct fan-in to aggregator + builder.add_fan_in_edges(participants, aggregator) if self._checkpoint_storage is not None: builder = builder.with_checkpointing(self._checkpoint_storage) diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 57c600519d..27709ad3a9 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -367,9 +367,9 @@ def __repr__(self) -> str: # pragma: no cover - representation only class AgentRunUpdateEvent(ExecutorEvent): """Event triggered when an agent is streaming messages.""" - data: AgentRunResponseUpdate | None + data: AgentRunResponseUpdate - def __init__(self, executor_id: str, data: AgentRunResponseUpdate | None = None): + def __init__(self, executor_id: str, data: AgentRunResponseUpdate): """Initialize the agent streaming event.""" super().__init__(executor_id, data) @@ -381,9 +381,9 @@ def __repr__(self) -> str: class AgentRunEvent(ExecutorEvent): """Event triggered when an agent run is completed.""" - data: AgentRunResponse | None + data: AgentRunResponse - def __init__(self, executor_id: str, data: AgentRunResponse | None = None): + def __init__(self, executor_id: str, data: AgentRunResponse): """Initialize the agent run event.""" super().__init__(executor_id, data) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index fad1e5f15e..49f3dafd06 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -250,6 +250,8 @@ async def execute( ): # Find the handler and handler spec that matches the message type. handler = self._find_handler(message) + + original_message = message if isinstance(message, Message): # Unwrap raw data for handler call message = message.data @@ -261,6 +263,9 @@ async def execute( runner_context=runner_context, trace_contexts=trace_contexts, source_span_ids=source_span_ids, + request_id=original_message.original_request_info_event.request_id + if isinstance(original_message, Message) and original_message.original_request_info_event + else None, ) # Invoke the handler with the message and context @@ -291,6 +296,7 @@ def _create_context_for_handler( runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, + request_id: str | None = None, ) -> WorkflowContext[Any]: """Create the appropriate WorkflowContext based on the handler's context annotation. 
@@ -300,6 +306,7 @@ def _create_context_for_handler( runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking. + request_id: Optional request ID if this context is for a `handle_response` handler. Returns: WorkflowContext[Any] based on the handler's context annotation. @@ -312,6 +319,7 @@ def _create_context_for_handler( runner_context=runner_context, trace_contexts=trace_contexts, source_span_ids=source_span_ids, + request_id=request_id, ) def _discover_handlers(self) -> None: @@ -356,7 +364,17 @@ def can_handle(self, message: Message) -> bool: True if the executor can handle the message type, False otherwise. """ if message.type == MessageType.RESPONSE: - return any(is_instance_of(message.data, message_type) for message_type in self._response_handlers) + if message.original_request_info_event is None: + logger.warning( + f"Executor {self.__class__.__name__} received a response message without an original request event." + ) + return False + + return any( + is_instance_of(message.original_request_info_event.data, message_type[0]) + and is_instance_of(message.data, message_type[1]) + for message_type in self._response_handlers + ) return any(is_instance_of(message.data, message_type) for message_type in self._handlers) @@ -427,7 +445,7 @@ def workflow_output_types(self) -> list[type[Any]]: output_types: set[type[Any]] = set() # Collect workflow output types from all handlers - for handler_spec in self._handler_specs: + for handler_spec in self._handler_specs + self._response_handler_specs: handler_workflow_output_types = handler_spec.get("workflow_output_types", []) output_types.update(handler_workflow_output_types) @@ -457,11 +475,15 @@ def _find_handler(self, message: Any) -> Callable[[Any, WorkflowContext[Any, Any f"Executor {self.__class__.__name__} cannot handle message of type {type(message.data)}." ) # Response message case - find response handler based on original request and response types - handler = self._find_response_handler(message.original_request, message.data) + if message.original_request_info_event is None: + raise RuntimeError( + f"Executor {self.__class__.__name__} received a response message without an original request event." + ) + handler = self._find_response_handler(message.original_request_info_event.data, message.data) if not handler: raise RuntimeError( f"Executor {self.__class__.__name__} cannot handle request of type " - f"{type(message.original_request)} and response of type {type(message.data)}." + f"{type(message.original_request_info_event.data)} and response of type {type(message.data)}." ) return handler diff --git a/python/packages/core/agent_framework/_workflows/_group_chat.py b/python/packages/core/agent_framework/_workflows/_group_chat.py index 7381e757ad..e1ba5156ee 100644 --- a/python/packages/core/agent_framework/_workflows/_group_chat.py +++ b/python/packages/core/agent_framework/_workflows/_group_chat.py @@ -2,15 +2,15 @@ """Group chat orchestration primitives. -This module introduces a reusable orchestration surface for manager-directed +This module introduces a reusable orchestration surface for orchestrator-directed multi-agent conversations. The key components are: - GroupChatRequestMessage / GroupChatResponseMessage: canonical envelopes used between the orchestrator and participants. 
-- Group chat managers: minimal asynchronous callables for pluggable coordination logic. -- GroupChatOrchestratorExecutor: runtime state machine that delegates to a - manager to select the next participant or complete the task. -- GroupChatBuilder: high-level builder that wires managers and participants +- GroupChatSelectionFunction: asynchronous callable for pluggable speaker selection logic. +- GroupChatOrchestrator: runtime state machine that delegates to a + selection function to select the next participant or complete the task. +- GroupChatBuilder: high-level builder that wires orchestrators and participants into a workflow graph. It mirrors the ergonomics of SequentialBuilder and ConcurrentBuilder while allowing Magentic to reuse the same infrastructure. @@ -19,1727 +19,723 @@ """ import inspect -import itertools import logging -from collections.abc import Awaitable, Callable, Mapping, Sequence -from dataclasses import dataclass, field -from types import MappingProxyType -from typing import Any, TypeAlias, cast -from uuid import uuid4 +import sys +from collections import OrderedDict +from collections.abc import Awaitable, Callable, Sequence +from dataclasses import dataclass +from typing import Any, ClassVar, cast from pydantic import BaseModel, Field +from typing_extensions import Never from .._agents import AgentProtocol, ChatAgent +from .._threads import AgentThread from .._types import ChatMessage, Role -from ._agent_executor import AgentExecutorRequest, AgentExecutorResponse -from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator +from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from ._agent_utils import resolve_agent_id +from ._base_group_chat_orchestrator import ( + BaseGroupChatOrchestrator, + GroupChatParticipantMessage, + GroupChatRequestMessage, + GroupChatResponseMessage, + GroupChatWorkflowContext_T_Out, + ParticipantRegistry, + TerminationCondition, +) from ._checkpoint import CheckpointStorage -from ._conversation_history import ensure_author, latest_user_message -from ._edge import EdgeCondition -from ._executor import Executor, handler -from ._orchestration_request_info import RequestInfoInterceptor -from ._participant_utils import GroupChatParticipantSpec, prepare_participant_metadata, wrap_participant +from ._conversation_state import decode_chat_messages, encode_chat_messages +from ._executor import Executor +from ._orchestration_request_info import AgentApprovalExecutor from ._workflow import Workflow from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext -logger = logging.getLogger(__name__) - - -# region Message primitives - - -@dataclass -class _GroupChatRequestMessage: - """Internal: Request envelope sent from the orchestrator to a participant.""" - - agent_name: str - conversation: list[ChatMessage] = field(default_factory=list) # type: ignore - instruction: str = "" - task: ChatMessage | None = None - metadata: dict[str, Any] | None = None - - -@dataclass -class _GroupChatResponseMessage: - """Internal: Response envelope emitted by participants back to the orchestrator.""" - - agent_name: str - message: ChatMessage - - -@dataclass -class _GroupChatTurn: - """Internal: Represents a single turn in the manager-participant conversation.""" - - speaker: str - role: str - message: ChatMessage - - -@dataclass -class GroupChatDirective: - """Instruction emitted by a group chat manager implementation.""" - - agent_name: str | None = None - instruction: str | None = None - 
metadata: dict[str, Any] | None = None - finish: bool = False - final_message: ChatMessage | None = None - - -@dataclass -class ManagerSelectionRequest: - """Request sent to manager agent for next speaker selection. - - This dataclass packages the full conversation state and task context - for the manager agent to analyze and make a speaker selection decision. - - Attributes: - task: Original user task message - participants: Mapping of participant names to their descriptions - conversation: Full conversation history including all messages - round_index: Number of manager selection rounds completed so far - metadata: Optional metadata for extensibility - """ - - task: ChatMessage - participants: dict[str, str] # type: ignore - conversation: list[ChatMessage] # type: ignore - round_index: int - metadata: dict[str, Any] | None = None - - def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for serialization.""" - return { - "task": self.task.to_dict(), - "participants": dict(self.participants), - "conversation": [msg.to_dict() for msg in self.conversation], - "round_index": self.round_index, - "metadata": self.metadata, - } - - -class ManagerSelectionResponse(BaseModel): - """Response from manager agent with speaker selection decision. - - The manager agent must produce this structure (or compatible dict/JSON) - to communicate its decision back to the orchestrator. - - Attributes: - selected_participant: Name of participant to speak next (None = finish conversation) - instruction: Optional instruction to provide to the selected participant - finish: Whether the conversation should be completed - final_message: Optional final message string when finishing conversation (will be converted to ChatMessage) - """ - - model_config = { - "extra": "forbid", - # OpenAI strict mode requires all properties to be in required array - "json_schema_extra": {"required": ["selected_participant", "instruction", "finish", "final_message"]}, - } - - selected_participant: str | None = None - instruction: str | None = None - finish: bool = False - final_message: str | None = Field(default=None, description="Optional text content for final message") - - @staticmethod - def from_dict(data: dict[str, Any]) -> "ManagerSelectionResponse": - """Create from dictionary representation.""" - return ManagerSelectionResponse( - selected_participant=data.get("selected_participant"), - instruction=data.get("instruction"), - finish=data.get("finish", False), - final_message=data.get("final_message"), - ) - - def get_final_message_as_chat_message(self) -> ChatMessage | None: - """Convert final_message string to ChatMessage if present.""" - if self.final_message: - return ChatMessage(role=Role.ASSISTANT, text=self.final_message) - return None - - -# endregion - - -# region Manager callable - - -GroupChatStateSnapshot = Mapping[str, Any] -_GroupChatManagerFn = Callable[[GroupChatStateSnapshot], Awaitable[GroupChatDirective]] - +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override -async def _maybe_await(value: Any) -> Any: - """Await value if it is awaitable; otherwise return as-is.""" - if inspect.isawaitable(value): - return await value - return value - - -_GroupChatParticipantPipeline: TypeAlias = Sequence[Executor] +logger = logging.getLogger(__name__) -@dataclass -class _GroupChatConfig: - """Internal: Configuration passed to factories during workflow assembly. 
+@dataclass(frozen=True) +class GroupChatState: + """Immutable state of the group chat for the selection function to determine the next speaker. Attributes: - manager: Manager callable for orchestration decisions (used by set_select_speakers_func) - manager_participant: Manager agent/executor instance (used by set_manager) - manager_name: Display name for the manager in conversation history - participants: Mapping of participant names to their specifications - max_rounds: Optional limit on manager selection rounds to prevent infinite loops - termination_condition: Optional callable that halts the conversation when it returns True - orchestrator: Orchestrator executor instance (populated during build) - participant_aliases: Mapping of aliases to executor IDs - participant_executors: Mapping of participant names to their executor instances + current_round: The current round index of the group chat, starting from 0. + participants: A mapping of participant names to their descriptions in the group chat. + conversation: The full conversation history up to this point as a list of ChatMessage. """ - manager: _GroupChatManagerFn | None - manager_participant: AgentProtocol | Executor | None - manager_name: str - participants: Mapping[str, GroupChatParticipantSpec] - max_rounds: int | None = None - termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None - orchestrator: Executor | None = None - participant_aliases: dict[str, str] = field(default_factory=dict) # type: ignore[type-arg] - participant_executors: dict[str, Executor] = field(default_factory=dict) # type: ignore[type-arg] - - -# endregion - - -# region Default participant factory + # Round index, starting from 0 + current_round: int + # participant name to description mapping as a ordered dict + participants: OrderedDict[str, str] + # Full conversation history up to this point + conversation: list[ChatMessage] -_GroupChatOrchestratorFactory: TypeAlias = Callable[[_GroupChatConfig], Executor] -_InterceptorSpec: TypeAlias = tuple[Callable[[_GroupChatConfig], Executor], EdgeCondition] - - -def _default_participant_factory( - spec: GroupChatParticipantSpec, - wiring: _GroupChatConfig, -) -> _GroupChatParticipantPipeline: - """Default factory for constructing participant pipeline nodes in the workflow graph. - - Creates a single AgentExecutor node for AgentProtocol participants or a passthrough executor - for custom participants. Translation between group-chat envelopes and the agent runtime is now - handled inside the orchestrator, removing the need for dedicated ingress/egress adapters. - - Args: - spec: Participant specification containing name, instance, and description - wiring: GroupChatWiring configuration for accessing cached executors - - Returns: - Sequence of executors representing the participant pipeline in execution order - - Behavior: - - AgentProtocol participants are wrapped in AgentExecutor with deterministic IDs - - Executor participants are wired directly without additional adapters - """ - participant = spec.participant - if isinstance(participant, Executor): - return (participant,) - cached = wiring.participant_executors.get(spec.name) - if cached is not None: - return (cached,) +# region Default orchestrator - agent_executor = wrap_participant(participant, executor_id=f"groupchat_agent:{spec.name}") - return (agent_executor,) +# Type alias for the selection function used by the orchestrator to choose the next speaker. 
+GroupChatSelectionFunction = Callable[[GroupChatState], Awaitable[str] | str] -# endregion +class GroupChatOrchestrator(BaseGroupChatOrchestrator): + """Orchestrator that manages a group chat between multiple participants. -# region Default orchestrator + This group chat orchestrator operates under the direction of a selection function + provided at initialization. The selection function receives the current state of + the group chat and returns the name of the next participant to speak. + This orchestrator drives the conversation loop as follows: + 1. Receives initial messages, saves to history, and broadcasts to all participants + 2. Invokes the selection function to determine the next speaker based on the most recent state + 3. Sends a request to the selected participant to generate a response + 4. Receives the participant's response, saves to history, and broadcasts to all participants + except the one that just spoke + 5. Repeats steps 2-4 until the termination conditions are met -class GroupChatOrchestratorExecutor(BaseGroupChatOrchestrator): - """Executor that orchestrates a group chat between multiple participants using a manager. - - This is the central runtime state machine that drives multi-agent conversations. It - maintains conversation state, delegates speaker selection to a manager, routes messages - to participants, and collects responses in a loop until the manager signals completion. - - Core responsibilities: - - Accept initial input as str, ChatMessage, or list[ChatMessage] - - Maintain conversation history and turn tracking - - Query manager for next action (select participant or finish) - - Route requests to selected participants using AgentExecutorRequest or GroupChatRequestMessage - - Collect participant responses and append to conversation - - Enforce optional round limits to prevent infinite loops - - Yield final completion message and transition to idle state - - State management: - - _conversation: Growing list of all messages (user, manager, agents) - - _history: Turn-by-turn record with speaker attribution and roles - - _task_message: Original user task extracted from input - - _pending_agent: Name of agent currently processing a request - - _round_index: Count of manager selection rounds for limit enforcement - - Manager interaction: - The orchestrator builds immutable state snapshots and passes them to the manager - callable. 
The manager returns a GroupChatDirective indicating either: - - Next participant to speak (with optional instruction) - - Finish signal (with optional final message) - - Message flow topology: - User input -> orchestrator -> manager -> orchestrator -> participant -> orchestrator - (loops until manager returns finish directive) - - Why this design: - - Separates orchestration logic (this class) from selection logic (manager) - - Manager is stateless and testable in isolation - - Orchestrator handles all state mutations and message routing - - Broadcast routing to participants keeps graph structure simple - - Args: - manager: Callable that selects the next participant or finishes based on state snapshot - participants: Mapping of participant names to descriptions (for manager context) - manager_name: Display name for manager in conversation history - max_rounds: Optional limit on manager selection rounds (None = unlimited) - termination_condition: Optional callable that halts the conversation when it returns True - executor_id: Optional custom ID for observability (auto-generated if not provided) + This is the most basic orchestrator, great for getting started with multi-agent + conversations. More advanced orchestrators can be built by extending BaseGroupChatOrchestrator + and implementing custom logic in the message and response handlers. """ def __init__( self, - manager: _GroupChatManagerFn, + id: str, + participant_registry: ParticipantRegistry, + selection_func: GroupChatSelectionFunction, *, - participants: Mapping[str, str], - manager_name: str, + name: str | None = None, max_rounds: int | None = None, - termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None, - executor_id: str | None = None, + termination_condition: TerminationCondition | None = None, ) -> None: - super().__init__(executor_id or f"groupchat_orchestrator_{uuid4().hex[:8]}") - self._manager = manager - self._participants = dict(participants) - self._manager_name = manager_name - self._max_rounds = max_rounds - self._termination_condition = termination_condition - self._history: list[_GroupChatTurn] = [] - self._task_message: ChatMessage | None = None - self._pending_agent: str | None = None - self._pending_finalization: bool = False - # Stashes the initial conversation list until _handle_task_message normalizes it into _conversation. - self._pending_initial_conversation: list[ChatMessage] | None = None - - def _get_author_name(self) -> str: - """Get the manager name for orchestrator-generated messages.""" - return self._manager_name - - def _build_state(self) -> GroupChatStateSnapshot: - """Build a snapshot of current orchestration state for the manager. - - Packages conversation history, participant metadata, and round tracking into - an immutable mapping that the manager uses to make speaker selection decisions. 
-
-        Returns:
-            Mapping containing all context needed for manager decision-making
-
-        Raises:
-            RuntimeError: If called before task message initialization (defensive check)
-
-        When this is called:
-        - After initial input is processed (first manager query)
-        - After each participant response (subsequent manager queries)
-        """
-        if self._task_message is None:
-            raise RuntimeError("GroupChatOrchestratorExecutor state not initialized with task message.")
-        snapshot: dict[str, Any] = {
-            "task": self._task_message,
-            "participants": dict(self._participants),
-            "conversation": tuple(self._conversation),
-            "history": tuple(self._history),
-            "pending_agent": self._pending_agent,
-            "round_index": self._round_index,
-        }
-        return MappingProxyType(snapshot)
-
-    def _snapshot_pattern_metadata(self) -> dict[str, Any]:
-        """Serialize GroupChat-specific state for checkpointing.
-
-        Returns:
-            Dict with participants, manager name, history, and pending agent
-        """
-        return {
-            "participants": dict(self._participants),
-            "manager_name": self._manager_name,
-            "pending_agent": self._pending_agent,
-            "task_message": self._task_message.to_dict() if self._task_message else None,
-            "history": [
-                {"speaker": turn.speaker, "role": turn.role, "message": turn.message.to_dict()}
-                for turn in self._history
-            ],
-        }
-
-    def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
-        """Restore GroupChat-specific state from checkpoint.
+        """Initialize the GroupChatOrchestrator.

         Args:
-            metadata: Pattern-specific state dict
-        """
-        if "participants" in metadata:
-            self._participants = dict(metadata["participants"])
-        if "manager_name" in metadata:
-            self._manager_name = metadata["manager_name"]
-        if "pending_agent" in metadata:
-            self._pending_agent = metadata["pending_agent"]
-        task_msg = metadata.get("task_message")
-        if task_msg:
-            self._task_message = ChatMessage.from_dict(task_msg)
-        if "history" in metadata:
-            self._history = [
-                _GroupChatTurn(
-                    speaker=turn["speaker"],
-                    role=turn["role"],
-                    message=ChatMessage.from_dict(turn["message"]),
-                )
-                for turn in metadata["history"]
-            ]
+            id: Unique executor ID for the orchestrator. The ID must be unique within the workflow.
+            participant_registry: Registry of participants in the group chat that tracks executor types
+                (agents vs. executors) and provides resolution utilities.
+            selection_func: Function that selects the next speaker based on conversation state.
+            name: Optional display name for the orchestrator in messages; defaults to the executor ID.
+                A descriptive, human-friendly name (rather than a machine-generated ID) gives models
+                better context about the orchestrator's role in multi-agent conversations.
+            max_rounds: Optional limit on selection rounds to prevent infinite loops.
+            termination_condition: Optional callable that halts the conversation when it returns True.
+
+        Note: If neither `max_rounds` nor `termination_condition` is provided, the conversation
+        will continue indefinitely. It is recommended to always set one of these to ensure proper termination.
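+
+        A termination condition is just a predicate over the conversation gathered so far.
+        As a minimal sketch (assuming `TerminationCondition` is a callable taking the
+        conversation messages and returning a bool, sync or async, as in earlier revisions):
+
+        .. code-block:: python
+
+            def stop_on_approval(conversation: list[ChatMessage]) -> bool:
+                # Halt the group chat once any message contains the marker "APPROVED"
+                return any("APPROVED" in (message.text or "") for message in conversation)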
- async def _complete_on_termination( - self, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> bool: - """Finish the conversation early when the termination condition is met.""" - if not await self._check_termination(): - return False - - if self._is_manager_agent(): - if self._pending_finalization: - return True - - self._pending_finalization = True - termination_prompt = ChatMessage( - role=Role.SYSTEM, - text="Termination condition met. Provide a final manager summary and finish the conversation.", - ) - manager_conversation = [ - self._build_manager_context_message(), - termination_prompt, - *list(self._conversation), - ] - self._pending_agent = self._manager_name - await self._route_to_participant( - participant_name=self._manager_name, - conversation=manager_conversation, - ctx=ctx, - instruction="", - task=self._task_message, - metadata={"termination_condition": True}, - ) - return True - - final_message: ChatMessage | None = None - if self._manager is not None and not self._is_manager_agent(): - try: - directive = await self._manager(self._build_state()) - except Exception: - logger.warning("Manager finalization failed during termination; using default termination message.") - else: - if directive.final_message is not None: - final_message = ensure_author(directive.final_message, self._manager_name) - elif directive.finish: - final_message = ensure_author( - self._create_completion_message( - text="Conversation completed.", - reason="termination_condition_manager_finish", - ), - self._manager_name, - ) - - if final_message is None: - final_message = ensure_author( - self._create_completion_message( - text="Conversation halted after termination condition was met.", - reason="termination_condition", - ), - self._manager_name, - ) - self._conversation.append(final_message) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) - return True + Example: + .. code-block:: python - async def _apply_directive( - self, - directive: GroupChatDirective, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Execute a manager directive by either finishing the workflow or routing to a participant. + from agent_framework import GroupChatOrchestrator - This is the core routing logic that interprets manager decisions. It handles two cases: - 1. Finish directive: append final message, update state, yield output, become idle - 2. 
Agent selection: build request envelope, route to participant, increment round counter
-
-        Args:
-            directive: Manager's decision (finish or select next participant)
-            ctx: Workflow context for sending messages and yielding output
-
-        Behavior for finish directive:
-        - Uses provided final_message or creates default completion message
-        - Ensures author_name is set to manager for attribution
-        - Appends to conversation and history for complete record
-        - Yields message as workflow output
-        - Orchestrator becomes idle (no further processing)
-
-        Behavior for agent selection:
-        - Validates agent_name exists in participants
-        - Optionally appends manager instruction as USER message
-        - Prepares full conversation context for the participant
-        - Routes request directly to the participant entry executor
-        - Increments round counter and enforces max_rounds if configured
-
-        Round limit enforcement:
-            If max_rounds is reached, recursively calls _apply_directive with a finish
-            directive to gracefully terminate the conversation.
+            async def round_robin_selector(state: GroupChatState) -> str:
+                # Simple round-robin selection among participants
+                names = list(state.participants)
+                return names[state.current_round % len(names)]

-        Raises:
-            ValueError: If directive lacks agent_name when finish=False, or if
-                agent_name doesn't match any participant
-        """
-        if directive.finish:
-            final_message = directive.final_message
-            if final_message is None:
-                final_message = self._create_completion_message(
-                    text="Completed without final summary.",
-                    reason="no summary provided",
-                )
-            final_message = ensure_author(final_message, self._manager_name)
-            self._conversation.extend((final_message,))
-            self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message))
-            self._pending_agent = None
-            await ctx.yield_output(list(self._conversation))
-            return
-
-        agent_name = directive.agent_name
-        if not agent_name:
-            raise ValueError("Directive must include agent_name when finish is False.")
-        if agent_name not in self._participants:
-            raise ValueError(f"Manager selected unknown participant '{agent_name}'.")
-
-        instruction = directive.instruction or ""
-        conversation = list(self._conversation)
-        if instruction:
-            manager_message = ensure_author(
-                self._create_completion_message(text=instruction, reason="instruction"),
-                self._manager_name,
+            orchestrator = GroupChatOrchestrator(
+                id="group_chat_orchestrator_1",
+                participant_registry=registry,  # a ParticipantRegistry covering "researcher" and "writer"
+                selection_func=round_robin_selector,
+                name="Coordinator",
+                max_rounds=10,
             )
-            conversation.extend((manager_message,))
-            self._conversation.extend((manager_message,))
-            self._history.append(_GroupChatTurn(self._manager_name, "manager", manager_message))
-
-        if await self._complete_on_termination(ctx):
-            return
-
-        self._pending_agent = agent_name
-        self._increment_round()
-
-        # Use inherited routing method from BaseGroupChatOrchestrator
-        await self._route_to_participant(
-            participant_name=agent_name,
-            conversation=conversation,
-            ctx=ctx,
-            instruction=instruction,
-            task=self._task_message,
-            metadata=directive.metadata,
+        """
+        super().__init__(
+            id,
+            participant_registry,
+            name=name,
+            max_rounds=max_rounds,
+            termination_condition=termination_condition,
         )
+        self._selection_func = selection_func

-        if self._check_round_limit():
-            await self._apply_directive(
-                GroupChatDirective(
-                    finish=True,
-                    final_message=self._create_completion_message(
-                        text="Conversation halted after reaching manager round limit.",
-                        reason="max_rounds reached",
-                    ),
-                ),
-                ctx,
-            )
-
-    async def _ingest_participant_message(
+    @override
+    async def _handle_messages(
         self,
-        participant_name: str,
-        message: ChatMessage,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
-        trailing_messages: list[ChatMessage] | None = None,
+        messages: list[ChatMessage],
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
     ) -> None:
-        """Common response ingestion logic shared by agent and custom participants.
-
-        Args:
-            participant_name: Name of the participant who sent the message
-            message: The participant's response message
-            ctx: Workflow context for routing and output
-            trailing_messages: Optional list of messages to inject after the participant's
-                message (e.g., additional input from the RequestInfoInterceptor)
-        """
-        if participant_name not in self._participants:
-            raise ValueError(f"Received response from unknown participant '{participant_name}'.")
-
-        message = ensure_author(message, participant_name)
-        self._conversation.extend((message,))
-        self._history.append(_GroupChatTurn(participant_name, "agent", message))
-
-        # Inject any trailing messages (e.g., human input) into the conversation
-        if trailing_messages:
-            for trailing_msg in trailing_messages:
-                self._conversation.extend((trailing_msg,))
-                # Record as user input in history
-                author = trailing_msg.author_name or "human"
-                self._history.append(_GroupChatTurn(author, "user", trailing_msg))
-                logger.debug(
-                    f"Injected human input into group chat conversation: "
-                    f"{trailing_msg.text[:50] if trailing_msg.text else '(empty)'}..."
-                )
-
-        self._pending_agent = None
-
-        if await self._complete_on_termination(ctx):
+        """Initialize orchestrator state and start the conversation loop."""
+        self._append_messages(messages)
+        # Termination condition will also be applied to the input messages
+        if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)):
             return

-        if self._check_round_limit():
-            final_message = self._create_completion_message(
-                text="Conversation halted after reaching manager round limit.",
-                reason="max_rounds reached after response",
-            )
-            self._conversation.extend((final_message,))
-            self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message))
-            await ctx.yield_output(list(self._conversation))
-            return
-
-        # Query manager for next speaker selection
-        if self._is_manager_agent():
-            # Agent-based manager: route request through workflow graph
-            # Prepend system message with participant context
-            manager_conversation = [self._build_manager_context_message(), *list(self._conversation)]
-            await self._route_to_participant(
-                participant_name=self._manager_name,
-                conversation=manager_conversation,
-                ctx=ctx,
-                instruction="",
-                task=self._task_message,
-                metadata=None,
-            )
-        else:
-            # Callable manager: invoke directly
-            directive = await self._manager(self._build_state())
-            await self._apply_directive(directive, ctx)
+        next_speaker = await self._get_next_speaker()

-    def _is_manager_agent(self) -> bool:
-        """Check if orchestrator is using an agent-based manager (vs callable manager)."""
-        return self._registry.is_participant_registered(self._manager_name)
-
-    def _build_manager_context_message(self) -> ChatMessage:
-        """Build system message with participant context for manager agent.
-
-        This message is prepended to the conversation when querying the manager
-        to provide up-to-date participant information for selection decisions.
- - Returns: - System message with participant names and descriptions - """ - participant_list = "\n".join(f"- {name}: {desc}" for name, desc in self._participants.items()) - context_text = ( - "Available participants:\n" - f"{participant_list}\n\n" - "IMPORTANT: Choose only from these exact participant names (case-sensitive)." + # Broadcast messages to all participants for context + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), ) - return ChatMessage(role=Role.SYSTEM, text=context_text) - - def _parse_manager_selection(self, response: AgentExecutorResponse) -> ManagerSelectionResponse: - """Extract manager selection decision from agent response. - - Attempts to parse structured output from the manager agent using multiple strategies: - 1. response.value (structured output from response_format) - 2. JSON parsing from message text - 3. Fallback error handling - - Args: - response: AgentExecutor response from manager agent - - Returns: - Parsed ManagerSelectionResponse with speaker selection - - Raises: - RuntimeError: If manager response cannot be parsed into valid selection - """ - import json - - # Strategy 1: agent_run_response.value (structured output) - agent_value = response.agent_run_response.value - if agent_value is not None: - if isinstance(agent_value, ManagerSelectionResponse): - return agent_value - if isinstance(agent_value, dict): - return ManagerSelectionResponse.from_dict(cast(dict[str, Any], agent_value)) - if isinstance(agent_value, str): - try: - data = json.loads(agent_value) - return ManagerSelectionResponse.from_dict(data) - except (json.JSONDecodeError, TypeError, KeyError) as e: - raise RuntimeError(f"Manager response.value contains invalid JSON: {e}") from e - - # Strategy 2: Parse from message text - messages = response.agent_run_response.messages or [] - if messages: - last_msg = messages[-1] - text = last_msg.text or "" - try: - return ManagerSelectionResponse.model_validate_json(text) - except (json.JSONDecodeError, TypeError, KeyError): - pass - - # Fallback: Cannot parse manager decision - raise RuntimeError( - "Manager response did not contain valid selection data. " - "Ensure manager agent uses response_format=ManagerSelectionResponse " - "or returns compatible JSON structure." + # Send request to selected participant + await self._send_request_to_participant( + next_speaker, + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), ) + self._increment_round() - async def _handle_manager_response( + @override + async def _handle_response( self, - response: AgentExecutorResponse, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Process manager agent's speaker selection decision. - - Parses the manager's response and either finishes the conversation or routes - to the selected participant. This method implements the core orchestration - logic for agent-based managers. - - Also handles any human input that was injected into the response's full_conversation - by the human input hook interceptor. 
- - Args: - response: AgentExecutor response from manager agent - ctx: Workflow context for routing and output - - Behavior: - - Extracts any human input from the response - - Parses manager selection from response - - If finish=True: yields final message and completes workflow - - If participant selected: routes request to that participant with human input included - - Validates selected participant exists - - Enforces round limits if configured + """Handle a participant response.""" + messages = self._process_participant_response(response) + self._append_messages(messages) - Raises: - ValueError: If manager selects invalid/unknown participant - RuntimeError: If manager response cannot be parsed - """ - # Extract any human input that was injected by the human input hook - trailing_user_messages = self._extract_trailing_user_messages(response) - - selection = self._parse_manager_selection(response) - - if self._pending_finalization: - self._pending_finalization = False - final_message_obj = selection.get_final_message_as_chat_message() - if final_message_obj is None: - final_message_obj = self._create_completion_message( - text="Conversation halted after termination condition was met.", - reason="termination_condition_manager", - ) - final_message_obj = ensure_author(final_message_obj, self._manager_name) - - self._conversation.append(final_message_obj) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message_obj)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - - if selection.finish: - # Manager decided to complete conversation - final_message_obj = selection.get_final_message_as_chat_message() - if final_message_obj is None: - final_message_obj = self._create_completion_message( - text="Conversation completed.", - reason="manager_finish", - ) - final_message_obj = ensure_author(final_message_obj, self._manager_name) - - self._conversation.append(final_message_obj) - self._history.append(_GroupChatTurn(self._manager_name, "manager", final_message_obj)) - self._pending_agent = None - await ctx.yield_output(list(self._conversation)) + if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - # Manager selected next participant - selected = selection.selected_participant - if not selected: - raise ValueError("Manager selection missing selected_participant when finish=False.") - if selected not in self._participants: - raise ValueError(f"Manager selected unknown participant: '{selected}'") - - # Route to selected participant - instruction = selection.instruction or "" - conversation = list(self._conversation) - if instruction: - manager_message = ensure_author( - self._create_completion_message(text=instruction, reason="manager_instruction"), - self._manager_name, - ) - conversation.append(manager_message) - self._conversation.append(manager_message) - self._history.append(_GroupChatTurn(self._manager_name, "manager", manager_message)) - - # Inject any human input that was attached to the manager's response - # This ensures the next participant sees the human's guidance - if trailing_user_messages: - for human_msg in trailing_user_messages: - conversation.append(human_msg) - self._conversation.append(human_msg) - author = human_msg.author_name or "human" - self._history.append(_GroupChatTurn(author, "user", human_msg)) - logger.debug( - f"Injected human input after manager 
selection: " - f"{human_msg.text[:50] if human_msg.text else '(empty)'}..." - ) - - if await self._complete_on_termination(ctx): - return + next_speaker = await self._get_next_speaker() - self._pending_agent = selected - self._increment_round() - - await self._route_to_participant( - participant_name=selected, - conversation=conversation, - ctx=ctx, - instruction=instruction, - task=self._task_message, - metadata=None, + # Broadcast participant messages to all participants for context, except + # the participant that just responded + participant = ctx.get_source_executor_id() + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + participants=[p for p in self._participant_registry.participants if p != participant], ) + # Send request to selected participant + await self._send_request_to_participant( + next_speaker, + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + ) + self._increment_round() - if self._check_round_limit(): - await self._apply_directive( - GroupChatDirective( - finish=True, - final_message=self._create_completion_message( - text="Conversation halted after reaching manager round limit.", - reason="max_rounds reached after manager selection", - ), - ), - ctx, - ) - - @staticmethod - def _extract_agent_message(response: AgentExecutorResponse, participant_name: str) -> ChatMessage: - """Select the final assistant message from an AgentExecutor response.""" - from ._orchestrator_helpers import create_completion_message - - final_message: ChatMessage | None = None - candidate_sequences: tuple[Sequence[ChatMessage] | None, ...] = ( - response.agent_run_response.messages, - response.full_conversation, + async def _get_next_speaker(self) -> str: + """Determine the next speaker using the selection function.""" + group_chat_state = GroupChatState( + current_round=self._round_index, + participants=self._participant_registry.participants, + conversation=self._get_conversation(), ) - for sequence in candidate_sequences: - if not sequence: - continue - for candidate in reversed(sequence): - if candidate.role == Role.ASSISTANT: - final_message = candidate - break - if final_message is not None: - break - - if final_message is None: - final_message = create_completion_message( - text="", - author_name=participant_name, - reason="empty response", - ) - return ensure_author(final_message, participant_name) - @staticmethod - def _extract_trailing_user_messages(response: AgentExecutorResponse) -> list[ChatMessage]: - """Extract any user messages that appear after the last assistant message. + next_speaker = self._selection_func(group_chat_state) + if inspect.isawaitable(next_speaker): + next_speaker = await next_speaker - This is used to capture human input that was injected by the human input hook - interceptor. The hook adds user messages to full_conversation after the agent's - response, so they appear at the end of the sequence. 
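+        # The selection function is user-supplied, so validate its choice against the
+        # registry before routing a request to a participant that does not exist.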
+ if next_speaker not in self._participant_registry.participants: + raise RuntimeError(f"Selection function returned unknown participant '{next_speaker}'.") - Args: - response: AgentExecutor response that may contain trailing user messages + return next_speaker - Returns: - List of user messages that appear after the last assistant message, - or empty list if none found - """ - if not response.full_conversation: - return [] - # Find index of last assistant message - last_assistant_idx = -1 - for i, msg in enumerate(response.full_conversation): - if msg.role == Role.ASSISTANT: - last_assistant_idx = i - - if last_assistant_idx < 0: - return [] +# endregion - # Collect any user messages after the last assistant message - trailing_user: list[ChatMessage] = [] - for msg in response.full_conversation[last_assistant_idx + 1 :]: - if msg.role == Role.USER: - trailing_user.append(msg) +# region Agent-based orchestrator - return trailing_user - async def _handle_task_message( - self, - task_message: ChatMessage, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Initialize orchestrator state and start the manager-directed conversation loop. +class AgentOrchestrationOutput(BaseModel): + """Structured output type for the agent in AgentBasedGroupChatOrchestrator.""" - This internal method is called by all public handlers (str, ChatMessage, list[ChatMessage]) - after normalizing their input. It initializes conversation state, queries the manager - for the first action, and applies the resulting directive. + model_config = { + "extra": "forbid", + # OpenAI strict mode requires all properties to be in required array + "json_schema_extra": {"required": ["terminate", "reason", "next_speaker", "final_message"]}, + } - Args: - task_message: The primary user task message (extracted or provided directly) - ctx: Workflow context for sending messages and yielding output - - Behavior: - - Sets task_message for manager context - - Initializes conversation from pending_initial_conversation if present - - Otherwise starts fresh with just the task message - - Builds turn history with speaker attribution - - Resets pending_agent and round_index - - Queries manager for first action - - Applies directive to start the conversation loop - - State initialization: - - _conversation: Full message list for context - - _history: Turn-by-turn record with speaker names and roles - - _pending_agent: None (no active request) - - _round_index: 0 (first manager query) - - Why pending_initial_conversation exists: - The handle_conversation handler supplies an explicit task (the first message in - the list) but still forwards the entire conversation for context. The full list is - stashed in _pending_initial_conversation to preserve all context when initializing state. 
- """ - self._task_message = task_message - if self._pending_initial_conversation: - initial_conversation = list(self._pending_initial_conversation) - self._pending_initial_conversation = None - self._conversation = initial_conversation - self._history = [ - _GroupChatTurn( - msg.author_name or msg.role.value, - msg.role.value, - msg, - ) - for msg in initial_conversation - ] - else: - self._conversation = [task_message] - self._history = [_GroupChatTurn("user", "user", task_message)] - self._pending_agent = None - self._round_index = 0 - - if await self._complete_on_termination(ctx): - return + # Whether to terminate the conversation + terminate: bool + # An explanation for the decision made + reason: str + # Next speaker to select if not terminating + next_speaker: str | None = Field( + default=None, + description="Name of the next participant to speak (if not terminating)", + ) + # Optional final message to send if terminating + final_message: str | None = Field(default=None, description="Optional final message if terminating") - # Query manager for first speaker selection - if self._is_manager_agent(): - # Agent-based manager: route request through workflow graph - # Prepend system message with participant context - manager_conversation = [self._build_manager_context_message(), *list(self._conversation)] - await self._route_to_participant( - participant_name=self._manager_name, - conversation=manager_conversation, - ctx=ctx, - instruction="", - task=self._task_message, - metadata=None, - ) - else: - # Callable manager: invoke directly - directive = await self._manager(self._build_state()) - await self._apply_directive(directive, ctx) - @handler - async def handle_str( - self, - task: str, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Handler for string input as workflow entry point. +class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator): + """Orchestrator that manages a group chat between multiple participants. - Wraps the string in a USER role ChatMessage and delegates to _handle_task_message. + This group chat orchestrator is driven by an agent that can select the next speaker + intelligently based on the conversation context. - Args: - task: Plain text task description from user - ctx: Workflow context + This orchestrator drives the conversation loop as follows: + 1. Receives initial messages, saves to history, and broadcasts to all participants + 2. Invokes the agent to determine the next speaker based on the most recent state + 3. Sends a request to the selected participant to generate a response + 4. Receives the participant's response, saves to history, and broadcasts to all participants + except the one that just spoke + 5. Repeats steps 2-4 until the termination conditions are met - Usage: - workflow.run("Write a blog post about AI agents") - """ - await self._handle_task_message(ChatMessage(role=Role.USER, text=task), ctx) + Note: The agent will be asked to generate a structured output of type `AgentOrchestrationOutput`, + thus it must be capable of structured output. 
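+
+    Example (a minimal sketch; ``chat_client`` stands in for any client whose models support
+    structured output, and ``registry`` for the ParticipantRegistry wired up by the builder):
+
+    .. code-block:: python
+
+        from agent_framework import AgentBasedGroupChatOrchestrator, ChatAgent
+
+        selector_agent = ChatAgent(
+            name="Coordinator",
+            instructions="Pick the next speaker, or terminate when the task is complete.",
+            chat_client=chat_client,
+        )
+        orchestrator = AgentBasedGroupChatOrchestrator(
+            agent=selector_agent,
+            participant_registry=registry,
+            max_rounds=10,
+        )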
+    """

-    @handler
-    async def handle_chat_message(
+    def __init__(
         self,
-        task_message: ChatMessage,
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
+        agent: ChatAgent,
+        participant_registry: ParticipantRegistry,
+        *,
+        max_rounds: int | None = None,
+        termination_condition: TerminationCondition | None = None,
+        retry_attempts: int | None = None,
+        thread: AgentThread | None = None,
     ) -> None:
-        """Handler for ChatMessage input as workflow entry point.
-
-        Directly delegates to _handle_task_message for state initialization.
+        """Initialize the AgentBasedGroupChatOrchestrator.

         Args:
-            task_message: Structured chat message from user (may include metadata, role, etc.)
-            ctx: Workflow context
-
-        Usage:
-            workflow.run(ChatMessage(role=Role.USER, text="Analyze this data"))
+            agent: Agent that selects the next speaker based on conversation state.
+            participant_registry: Registry of participants in the group chat that tracks executor types
+                (agents vs. executors) and provides resolution utilities.
+            max_rounds: Optional limit on selection rounds to prevent infinite loops.
+            termination_condition: Optional callable that halts the conversation when it returns True.
+            retry_attempts: Optional number of retry attempts for the agent in case of failure.
+            thread: Optional agent thread to use for the orchestrator agent.
         """
-        await self._handle_task_message(task_message, ctx)
-
-    @handler
-    async def handle_conversation(
+        super().__init__(
+            resolve_agent_id(agent),
+            participant_registry,
+            name=agent.name,
+            max_rounds=max_rounds,
+            termination_condition=termination_condition,
+        )
+        self._agent = agent
+        self._retry_attempts = retry_attempts
+        self._thread = thread or agent.get_new_thread()
+        # Cache of messages received since the last agent invocation.
+        # This is distinct from the full conversation history maintained by the base orchestrator.
+        self._cache: list[ChatMessage] = []
+
+    @override
+    def _append_messages(self, messages: Sequence[ChatMessage]) -> None:
+        self._cache.extend(messages)
+        return super()._append_messages(messages)
+
+    @override
+    async def _handle_messages(
         self,
-        conversation: list[ChatMessage],
-        ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]],
+        messages: list[ChatMessage],
+        ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]],
     ) -> None:
-        """Handler for conversation history as workflow entry point.
-
-        Accepts a pre-existing conversation and uses the first message in the list as the task.
-        Preserves the full conversation for state initialization.
+ """Initialize orchestrator state and start the conversation loop.""" + self._append_messages(messages) + # Termination condition will also be applied to the input messages + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): + return - Args: - conversation: List of chat messages (system, user, assistant) - ctx: Workflow context + agent_orchestration_output = await self._invoke_agent() + if await self._check_agent_terminate_and_yield( + agent_orchestration_output, + cast(WorkflowContext[Never, list[ChatMessage]], ctx), + ): + return - Raises: - ValueError: If conversation list is empty - - Behavior: - - Validates conversation is non-empty - - Clones conversation to avoid mutation - - Extracts task message (most recent USER message) - - Stashes full conversation in _pending_initial_conversation - - Delegates to _handle_task_message for initialization - - Usage: - existing_messages = [ - ChatMessage(role=Role.SYSTEM, text="You are an expert"), - ChatMessage(role=Role.USER, text="Help me with this task") - ] - workflow.run(existing_messages) - """ - if not conversation: - raise ValueError("GroupChat workflow requires at least one chat message.") - self._pending_initial_conversation = list(conversation) - task_message = latest_user_message(conversation) - await self._handle_task_message(task_message, ctx) - - @handler - async def handle_agent_response( - self, - response: _GroupChatResponseMessage, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], - ) -> None: - """Handle responses from custom participant executors.""" - await self._ingest_participant_message(response.agent_name, response.message, ctx) + # Broadcast messages to all participants for context + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + ) + # Send request to selected participant + await self._send_request_to_participant( + # If not terminating, next_speaker must be provided thus will not be None + agent_orchestration_output.next_speaker, # type: ignore[arg-type] + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + ) + self._increment_round() - @handler - async def handle_agent_executor_response( + @override + async def _handle_response( self, - response: AgentExecutorResponse, - ctx: WorkflowContext[AgentExecutorRequest | _GroupChatRequestMessage, list[ChatMessage]], + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handle responses from both manager agent and regular participants. - - Routes responses based on whether they come from the manager or a participant: - - Manager responses: parsed for speaker selection decisions - - Participant responses: ingested as conversation messages - - Also handles any human input that was injected into the response's full_conversation - by the human input hook interceptor. 
- """ - participant_name = self._registry.get_participant_name(response.executor_id) - if participant_name is None: - logger.debug( - "Ignoring response from unregistered agent executor '%s'.", - response.executor_id, - ) + """Handle a participant response.""" + messages = self._process_participant_response(response) + self._append_messages(messages) + if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): + return + if await self._check_round_limit_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)): return - # Check if response is from manager agent - if participant_name == self._manager_name and self._is_manager_agent(): - await self._handle_manager_response(response, ctx) - else: - # Regular participant response - message = self._extract_agent_message(response, participant_name) - - # Check for human input injected by human input hook - # Human input appears as user messages at the end of full_conversation - # after the agent's assistant message - trailing_user_messages = self._extract_trailing_user_messages(response) - - await self._ingest_participant_message(participant_name, message, ctx, trailing_user_messages) - - -def _default_orchestrator_factory(wiring: _GroupChatConfig) -> Executor: - """Default factory for creating the GroupChatOrchestratorExecutor instance. - - This is the internal implementation used by GroupChatBuilder to instantiate the - orchestrator. It extracts participant descriptions from the wiring configuration - and passes them to the orchestrator for manager context. - - Args: - wiring: Complete workflow configuration assembled by the builder - - Returns: - Initialized GroupChatOrchestratorExecutor ready to coordinate the conversation - - Behavior: - - Extracts participant names and descriptions for manager context - - Forwards manager instance, manager name, max_rounds, and termination_condition settings - - Allows orchestrator to auto-generate its executor ID - - Supports both callable managers (set_select_speakers_func) and agent-based managers (set_manager) - - Why descriptions are extracted: - The manager needs participant descriptions (not full specs) to make informed - selection decisions. The orchestrator doesn't need participant instances directly - since routing is handled by the workflow graph. + agent_orchestration_output = await self._invoke_agent() + if await self._check_agent_terminate_and_yield( + agent_orchestration_output, + cast(WorkflowContext[Never, list[ChatMessage]], ctx), + ): + return - Raises: - RuntimeError: If neither manager nor manager_participant is configured - """ - if wiring.manager is None and wiring.manager_participant is None: - raise RuntimeError( - "Default orchestrator factory requires a manager to be configured. " - "Call set_manager(...) or set_select_speakers_func(...) before build()." 
+        # Broadcast participant messages to all participants for context, except
+        # the participant that just responded
+        participant = ctx.get_source_executor_id()
+        await self._broadcast_messages_to_participants(
+            messages,
+            cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx),
+            participants=[p for p in self._participant_registry.participants if p != participant],
         )
+        # Send request to selected participant
+        await self._send_request_to_participant(
+            # If not terminating, next_speaker must be provided, so it will not be None
+            agent_orchestration_output.next_speaker,  # type: ignore[arg-type]
+            cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx),
+        )
+        self._increment_round()

-    manager_callable = wiring.manager
-    if manager_callable is None:
-        # Keep orchestrator signature satisfied; agent managers are routed via the workflow graph
-        async def _agent_manager_placeholder(_: GroupChatStateSnapshot) -> GroupChatDirective:  # noqa: RUF029
-            raise RuntimeError(
-                "Manager callable invoked unexpectedly. Agent-based managers should route through the workflow graph."
-            )
-
-        manager_callable = _agent_manager_placeholder
-
-    return GroupChatOrchestratorExecutor(
-        manager=manager_callable,
-        participants={name: spec.description for name, spec in wiring.participants.items()},
-        manager_name=wiring.manager_name,
-        max_rounds=wiring.max_rounds,
-        termination_condition=wiring.termination_condition,
-    )

+    async def _invoke_agent(self) -> AgentOrchestrationOutput:
+        """Invoke the orchestrator agent to determine the next speaker and termination."""
+
+        async def _invoke_agent_helper(conversation: list[ChatMessage]) -> AgentOrchestrationOutput:
+            # Run the agent in non-streaming mode for simplicity
+            agent_response = await self._agent.run(
+                messages=conversation,
+                thread=self._thread,
+                response_format=AgentOrchestrationOutput,
+            )
+            # Parse and validate the structured output
+            agent_orchestration_output = AgentOrchestrationOutput.model_validate_json(agent_response.text)
+
+            if not agent_orchestration_output.terminate and not agent_orchestration_output.next_speaker:
+                raise ValueError("next_speaker must be provided if not terminating the conversation.")
+
+            return agent_orchestration_output
+
+        # Only the messages cached since the last invocation are needed for context;
+        # the thread maintains the full conversation history.
+        current_conversation = self._cache.copy()
+        self._cache.clear()
+        instruction = (
+            "Decide what to do next. Respond with a JSON object of the following format:\n"
+            "{\n"
+            '  "terminate": <true or false>,\n'
+            '  "reason": "<explanation for the decision>",\n'
+            '  "next_speaker": "<participant name, or null if terminating>",\n'
+            '  "final_message": "<optional final message if terminating>"\n'
+            "}\n"
+            "If not terminating, here are the valid participant names (case-sensitive) and their descriptions:\n"
+            + "\n".join([
+                f"{name}: {description}" for name, description in self._participant_registry.participants.items()
+            ])
+        )
+        # Append the instruction as a user message
+        current_conversation.append(ChatMessage(role=Role.USER, text=instruction))
+        retry_attempts = self._retry_attempts
+        while True:
+            try:
+                return await _invoke_agent_helper(current_conversation)
+            except Exception as ex:
+                logger.error(f"Agent orchestration invocation failed: {ex}")
+                if retry_attempts is None or retry_attempts <= 0:
+                    raise
+                retry_attempts -= 1
+                logger.debug(f"Retrying agent orchestration invocation, attempts left: {retry_attempts}")
+                # We don't need the full conversation since the thread should maintain history
+                current_conversation = [
+                    ChatMessage(
+                        role=Role.USER,
+                        text=f"Your previous response could not be parsed due to an error: {ex}. Please try again.",
+                    )
+                ]

-def group_chat_orchestrator(factory: _GroupChatOrchestratorFactory | None = None) -> _GroupChatOrchestratorFactory:
-    """Return a callable orchestrator factory, defaulting to the built-in implementation."""
-    return factory or _default_orchestrator_factory
-
-
-def assemble_group_chat_workflow(
-    *,
-    wiring: _GroupChatConfig,
-    participant_factory: Callable[[GroupChatParticipantSpec, _GroupChatConfig], _GroupChatParticipantPipeline],
-    orchestrator_factory: _GroupChatOrchestratorFactory = _default_orchestrator_factory,
-    interceptors: Sequence[_InterceptorSpec] | None = None,
-    checkpoint_storage: CheckpointStorage | None = None,
-    builder: WorkflowBuilder | None = None,
-    return_builder: bool = False,
-) -> Workflow | tuple[WorkflowBuilder, Executor]:
-    """Build the workflow graph shared by group-chat style orchestrators."""
-    interceptor_specs = interceptors or ()
-
-    orchestrator = wiring.orchestrator or orchestrator_factory(wiring)
-    wiring.orchestrator = orchestrator
-
-    workflow_builder = builder or WorkflowBuilder()
-    start_executor = getattr(workflow_builder, "_start_executor", None)
-    if start_executor is None:
-        workflow_builder = workflow_builder.set_start_executor(orchestrator)
-
-    # Wire manager as participant if agent-based manager is configured
-    if wiring.manager_participant is not None:
-        manager_spec = GroupChatParticipantSpec(
-            name=wiring.manager_name,
-            participant=wiring.manager_participant,
-            description="Coordination manager",
-        )
-        manager_pipeline = list(participant_factory(manager_spec, wiring))
-        if not manager_pipeline:
-            raise ValueError("Participant factory returned empty pipeline for manager.")
-
-        manager_entry = manager_pipeline[0]
-        manager_exit = manager_pipeline[-1]
-
-        # Register manager with orchestrator (with entry and exit IDs for pipeline routing)
-        register_entry = getattr(orchestrator, "register_participant_entry", None)
-        if callable(register_entry):
-            register_entry(
-                wiring.manager_name,
-                entry_id=manager_entry.id,
-                is_agent=not isinstance(wiring.manager_participant, Executor),
-                exit_id=manager_exit.id if manager_exit is not manager_entry else None,
-            )

+    async def _check_agent_terminate_and_yield(
+        self,
+        agent_orchestration_output: AgentOrchestrationOutput,
+        ctx: WorkflowContext[Never, list[ChatMessage]],
+    ) -> bool:
+        """Check if the agent requested termination and yield completion if so.
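+
+        This is distinct from the shared termination condition: here the orchestrator agent
+        itself decided to stop, so its final message (if any) is appended to the conversation
+        before the full history is yielded as output.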
- # Wire manager edges: Orchestrator ↔ Manager - workflow_builder = workflow_builder.add_edge(orchestrator, manager_entry) - for upstream, downstream in itertools.pairwise(manager_pipeline): - workflow_builder = workflow_builder.add_edge(upstream, downstream) - if manager_exit is not orchestrator: - workflow_builder = workflow_builder.add_edge(manager_exit, orchestrator) - - # Wire regular participants - for name, spec in wiring.participants.items(): - pipeline = list(participant_factory(spec, wiring)) - if not pipeline: - raise ValueError( - f"Participant factory returned an empty pipeline for '{name}'. " - "Provide at least one executor per participant." - ) - entry_executor = pipeline[0] - exit_executor = pipeline[-1] - - register_entry = getattr(orchestrator, "register_participant_entry", None) - if callable(register_entry): - # Register both entry and exit IDs so responses can be routed correctly - # when interceptors are prepended to the pipeline - register_entry( - name, - entry_id=entry_executor.id, - is_agent=not isinstance(spec.participant, Executor), - exit_id=exit_executor.id if exit_executor is not entry_executor else None, + Args: + agent_orchestration_output: Output from the orchestrator agent + ctx: Workflow context for yielding output + Returns: + True if termination was requested and output was yielded, False otherwise + """ + if agent_orchestration_output.terminate: + final_message = ( + agent_orchestration_output.final_message or "The conversation has been terminated by the agent." ) + self._append_messages([self._create_completion_message(final_message)]) + await ctx.yield_output(self._full_conversation) + return True - workflow_builder = workflow_builder.add_edge(orchestrator, entry_executor) - for upstream, downstream in itertools.pairwise(pipeline): - workflow_builder = workflow_builder.add_edge(upstream, downstream) - if exit_executor is not orchestrator: - workflow_builder = workflow_builder.add_edge(exit_executor, orchestrator) + return False - for factory, condition in interceptor_specs: - interceptor_executor = factory(wiring) - workflow_builder = workflow_builder.add_edge(orchestrator, interceptor_executor, condition=condition) - workflow_builder = workflow_builder.add_edge(interceptor_executor, orchestrator) + @override + async def on_checkpoint_save(self) -> dict[str, Any]: + """Capture current orchestrator state for checkpointing.""" + state = await super().on_checkpoint_save() + state["cache"] = encode_chat_messages(self._cache) + serialized_thread = await self._thread.serialize() + state["thread"] = serialized_thread - if checkpoint_storage is not None: - workflow_builder = workflow_builder.with_checkpointing(checkpoint_storage) + return state - if return_builder: - return workflow_builder, orchestrator - return workflow_builder.build() + @override + async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: + """Restore executor state from checkpoint.""" + await super().on_checkpoint_restore(state) + self._cache = decode_chat_messages(state.get("cache", [])) + serialized_thread = state.get("thread") + if serialized_thread: + self._thread = await self._agent.deserialize_thread(serialized_thread) # endregion - # region Builder class GroupChatBuilder: - r"""High-level builder for manager-directed group chat workflows with dynamic orchestration. - - GroupChat coordinates multi-agent conversations using a manager that selects which participant - speaks next. 
The manager can be a simple Python function (:py:meth:`GroupChatBuilder.set_select_speakers_func`) - or an agent-based selector via :py:meth:`GroupChatBuilder.set_manager`. These two approaches are - mutually exclusive. - - **Core Workflow:** - 1. Define participants: list of agents (uses their .name) or dict mapping names to agents - 2. Configure speaker selection: :py:meth:`GroupChatBuilder.set_select_speakers_func` OR - :py:meth:`GroupChatBuilder.set_manager` (not both) - 3. Optional: set round limits, checkpointing, termination conditions - 4. Build and run the workflow - - **Speaker Selection Patterns:** - - *Pattern 1: Simple function-based selection (recommended)* + r"""High-level builder for group chat workflows. - .. code-block:: python + GroupChat coordinates multi-agent conversations using an orchestrator that can dynamically + select participants to speak at each turn based on the conversation state. - from agent_framework import GroupChatBuilder, GroupChatStateSnapshot + Routing Pattern: + Agents respond in turns as directed by the orchestrator until termination conditions are met. + This provides a centralized approach to multi-agent collaboration, similar to a star topology. + Participants can be a combination of agents and executors. If they are executors, they + must implement the expected handlers for receiving GroupChat messages and returning responses + (Read our official documentation for details on implementing custom participant executors). - def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: - # state contains: task, participants, conversation, history, round_index - if state["round_index"] >= 5: - return None # Finish - last_speaker = state["history"][-1].speaker if state["history"] else None - if last_speaker == "researcher": - return "writer" - return "researcher" + The orchestrator can be provided directly, or a simple selection function can be defined + to choose the next speaker based on the current state. The builder wires everything together + into a complete workflow graph that can be executed. + Outputs: + The final conversation history as a list of ChatMessage once the group chat completes. + """ - workflow = ( - GroupChatBuilder() - .set_select_speakers_func(select_next_speaker) - .participants([researcher_agent, writer_agent]) # Uses agent.name - .build() - ) - - *Pattern 2: LLM-based selection* - - .. 
code-block:: python
-
-        from agent_framework import ChatAgent
-        from agent_framework.azure import AzureOpenAIChatClient
+    DEFAULT_ORCHESTRATOR_ID: ClassVar[str] = "group_chat_orchestrator"

-        manager_agent = AzureOpenAIChatClient().create_agent(
-            instructions="Coordinate the conversation and pick the next speaker.",
-            name="Coordinator",
-            temperature=0.3,
-            seed=42,
-            max_tokens=500,
-        )
+    def __init__(self) -> None:
+        """Initialize the GroupChatBuilder."""
+        self._participants: dict[str, AgentProtocol | Executor] = {}

-        workflow = (
-            GroupChatBuilder()
-            .set_manager(manager_agent, display_name="Coordinator")
-            .participants([researcher, writer])  # Or use dict: researcher=r, writer=w
-            .with_max_rounds(10)
-            .build()
-        )
+        # Orchestrator related members
+        self._orchestrator: BaseGroupChatOrchestrator | None = None
+        self._selection_func: GroupChatSelectionFunction | None = None
+        self._agent_orchestrator: ChatAgent | None = None
+        self._termination_condition: TerminationCondition | None = None
+        self._max_rounds: int | None = None
+        self._orchestrator_name: str | None = None

-    *Pattern 3: Request info for mid-conversation feedback*
+        # Checkpoint related members
+        self._checkpoint_storage: CheckpointStorage | None = None

-    .. code-block:: python
+        # Request info related members
+        self._request_info_enabled: bool = False
+        self._request_info_filter: set[str] = set()

-        from agent_framework import GroupChatBuilder
+    def with_orchestrator(self, orchestrator: BaseGroupChatOrchestrator) -> "GroupChatBuilder":
+        """Set the orchestrator for this group chat workflow.

-        # Pause before all participants
-        workflow = (
-            GroupChatBuilder()
-            .set_select_speakers_func(select_next_speaker)
-            .participants([researcher, writer])
-            .with_request_info()
-            .build()
-        )
+        A group chat orchestrator is responsible for managing the flow of conversation, keeping
+        all participants in sync and picking the next speaker according to the defined logic
+        until the termination conditions are met.

-        # Pause only before specific participants
-        workflow = (
-            GroupChatBuilder()
-            .set_select_speakers_func(select_next_speaker)
-            .participants([researcher, writer, editor])
-            .with_request_info(agents=[editor])  # Only pause before editor responds
-            .build()
-        )
+        Args:
+            orchestrator: An instance of BaseGroupChatOrchestrator to manage the group chat.

-    **Participant Specification:**
+        Returns:
+            Self for fluent chaining.

-    Two ways to specify participants:
-    - List form: `[agent1, agent2]` - uses `agent.name` attribute for participant names
-    - Dict form: `{name1: agent1, name2: agent2}` - explicit name control
-    - Keyword form: `participants(name1=agent1, name2=agent2)` - explicit name control
+        Raises:
+            ValueError: If an orchestrator has already been set

-    **State Snapshot Structure:**
+        Example:
+            .. code-block:: python

-    The GroupChatStateSnapshot passed to set_select_speakers_func contains:
-    - `task`: ChatMessage - Original user task
-    - `participants`: dict[str, str] - Mapping of participant names to descriptions
-    - `conversation`: tuple[ChatMessage, ...] - Full conversation history
-    - `history`: tuple[GroupChatTurn, ...] 
- Turn-by-turn record with speaker attribution - - `round_index`: int - Number of manager selection rounds so far - - `pending_agent`: str | None - Name of agent currently processing (if any) + from agent_framework import GroupChatBuilder - **Important Constraints:** - - Cannot combine :py:meth:`GroupChatBuilder.set_select_speakers_func` and :py:meth:`GroupChatBuilder.set_manager` - - Participant names must be unique - - When using list form, agents must have a non-empty `name` attribute - """ - def __init__( - self, - *, - _orchestrator_factory: _GroupChatOrchestratorFactory | None = None, - _participant_factory: Callable[[GroupChatParticipantSpec, _GroupChatConfig], _GroupChatParticipantPipeline] - | None = None, - ) -> None: - """Initialize the GroupChatBuilder. - - Args: - _orchestrator_factory: Internal extension point for custom orchestrator implementations. - Used by Magentic. Not part of public API - subject to change. - _participant_factory: Internal extension point for custom participant pipelines. - Used by Magentic. Not part of public API - subject to change. + orchestrator = CustomGroupChatOrchestrator(...) + workflow = GroupChatBuilder().with_orchestrator(orchestrator).participants([agent1, agent2]).build() """ - self._participants: dict[str, AgentProtocol | Executor] = {} - self._participant_metadata: dict[str, Any] | None = None - self._manager: _GroupChatManagerFn | None = None - self._manager_participant: AgentProtocol | Executor | None = None - self._manager_name: str = "manager" - self._checkpoint_storage: CheckpointStorage | None = None - self._max_rounds: int | None = None - self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None - self._interceptors: list[_InterceptorSpec] = [] - self._orchestrator_factory = group_chat_orchestrator(_orchestrator_factory) - self._participant_factory = _participant_factory or _default_participant_factory - self._request_info_enabled: bool = False - self._request_info_filter: set[str] | None = None - - def _set_manager_function( - self, - manager: _GroupChatManagerFn, - display_name: str | None, - ) -> "GroupChatBuilder": - if self._manager is not None or self._manager_participant is not None: + if self._orchestrator is not None: + raise ValueError("An orchestrator has already been configured. Call with_orchestrator(...) at most once.") + if self._agent_orchestrator is not None: raise ValueError( - "GroupChatBuilder already has a manager configured. " - "Call set_select_speakers_func(...) or set_manager(...) at most once." + "An agent orchestrator has already been configured. " + "Call only one of with_orchestrator(...) or with_agent_orchestrator(...)." + ) + if self._selection_func is not None: + raise ValueError( + "A selection function has already been configured. " + "Call only one of with_orchestrator(...) or with_select_speaker_func(...)." ) - resolved_name = display_name or getattr(manager, "name", None) or "manager" - self._manager = manager - self._manager_name = resolved_name - return self - def set_manager( - self, - manager: AgentProtocol | Executor, - *, - display_name: str | None = None, - ) -> "GroupChatBuilder": - """Configure the manager/coordinator agent for group chat orchestration. + self._orchestrator = orchestrator + return self - The manager coordinates participants by selecting who speaks next based on - conversation state and task requirements. The manager is a full workflow - participant with access to all agent infrastructure (tools, context, observability). 
+ def with_agent_orchestrator(self, agent: ChatAgent) -> "GroupChatBuilder": + """Set an agent-based orchestrator for this group chat workflow. - The manager agent must produce structured output compatible with ManagerSelectionResponse - to communicate its speaker selection decisions. Use response_format for reliable parsing. - GroupChatBuilder enforces this when the manager is a ChatAgent and rejects incompatible - response formats. + An agent-based group chat orchestrator uses a ChatAgent to select the next speaker + intelligently based on the conversation context. Args: - manager: Agent or executor responsible for speaker selection and coordination. - Must return ManagerSelectionResponse or compatible dict/JSON structure. - display_name: Optional name for manager messages in conversation history. - If not provided, uses manager.name for AgentProtocol or manager.id for Executor. + agent: An instance of ChatAgent to manage the group chat. Returns: Self for fluent chaining. Raises: - ValueError: If manager is already configured via :py:meth:`GroupChatBuilder.set_select_speakers_func` - TypeError: If manager is not AgentProtocol or Executor instance - - Example: - - .. code-block:: python - - from agent_framework import GroupChatBuilder, ChatAgent - from agent_framework.openai import OpenAIChatClient - - # Coordinator agent - response_format is enforced to ManagerSelectionResponse - coordinator = ChatAgent( - name="Coordinator", - description="Coordinates multi-agent collaboration", - instructions=''' - You coordinate a team conversation. Review the conversation history - and select the next participant to speak. - - When ready to finish, set finish=True and provide a summary in final_message. - ''', - chat_client=OpenAIChatClient(), - ) - - workflow = ( - GroupChatBuilder() - .set_manager(coordinator, display_name="Orchestrator") - .participants([researcher, writer]) - .build() - ) - - Note: - The manager agent's response_format must be ManagerSelectionResponse for structured output. - Custom response formats raise ValueError instead of being overridden. - - The manager can be included in :py:meth:`with_request_info` to pause before the manager - runs, allowing human steering of orchestration decisions. If no filter is specified, - the manager is included automatically. To filter explicitly:: - - .with_request_info(agents=[manager, writer]) # Pause before manager and writer + ValueError: If an orchestrator has already been set """ - if self._manager is not None or self._manager_participant is not None: + if self._agent_orchestrator is not None: raise ValueError( - "GroupChatBuilder already has a manager configured. " - "Call set_select_speakers_func(...) or set_manager(...) at most once." + "Agent orchestrator has already been configured. Call with_agent_orchestrator(...) at most once." ) - - if not isinstance(manager, (AgentProtocol, Executor)): - raise TypeError(f"Manager must be AgentProtocol or Executor instance. 
Got {type(manager).__name__}.") - - # Infer display name from manager if not provided - if display_name is None: - display_name = manager.id if isinstance(manager, Executor) else manager.name or "manager" - - # Enforce ManagerSelectionResponse for ChatAgent managers - if ( - isinstance(manager, ChatAgent) - and manager.default_options.setdefault("response_format", ManagerSelectionResponse) - != ManagerSelectionResponse - ): - configured_format_name = getattr( - manager.default_options.get("response_format"), - "__name__", - str(manager.default_options.get("response_format")), + if self._orchestrator is not None: + raise ValueError( + "An orchestrator has already been configured. " + "Call only one of with_agent_orchestrator(...) or with_orchestrator(...)." ) + if self._selection_func is not None: raise ValueError( - "Manager ChatAgent response_format must be ManagerSelectionResponse. " - f"Received '{configured_format_name}' for manager '{display_name}'." + "A selection function has already been configured. " + "Call only one of with_agent_orchestrator(...) or with_select_speaker_func(...)." ) - self._manager_participant = manager - self._manager_name = display_name + self._agent_orchestrator = agent return self - def set_select_speakers_func( + def with_select_speaker_func( self, - selector: ( - Callable[[GroupChatStateSnapshot], Awaitable[str | None]] | Callable[[GroupChatStateSnapshot], str | None] - ), + selection_func: GroupChatSelectionFunction, *, - display_name: str | None = None, - final_message: ChatMessage | str | Callable[[GroupChatStateSnapshot], Any] | None = None, + orchestrator_name: str | None = None, ) -> "GroupChatBuilder": - """Configure speaker selection using a pure function that examines group chat state. - - This is the primary way to control orchestration flow in a GroupChat. Your selector - function receives an immutable snapshot of the current conversation state and returns - the name of the next participant to speak, or None to finish the conversation. + """Define a custom function to select the next speaker in the group chat. - The selector function can implement any logic including: - - Simple round-robin or rule-based selection - - LLM-based decision making with custom prompts - - Conversation summarization before routing to the next agent - - Custom metadata or context passing - - For advanced scenarios, return a GroupChatDirective instead of a string to include - custom instructions or metadata for the next participant. - - The selector function signature: - def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: - # state contains: task, participants, conversation, history, round_index - # Return participant name to continue, or None to finish - ... + This is a quick way to implement simple orchestration logic without needing a full + GroupChatOrchestrator. The provided function receives the current state of + the group chat and returns the name of the next participant to speak. Args: - selector: Function that takes GroupChatStateSnapshot and returns the next speaker's - name (str) to continue the conversation, or None to finish. May be sync or async. - Can also return GroupChatDirective for advanced control (instruction, metadata). - display_name: Optional name shown in conversation history for orchestrator messages - (defaults to "manager"). - final_message: Optional final message (or factory) emitted when selector returns None - (defaults to "Conversation completed." authored by the manager). 
+            selection_func: Callable that receives the current GroupChatState and returns
+                the name of the next participant to speak, or None to finish.
+            orchestrator_name: Optional display name for the orchestrator in the workflow.
+                If not provided, defaults to `GroupChatBuilder.DEFAULT_ORCHESTRATOR_ID`.

         Returns:
-            Self for fluent chaining.
+            Self for fluent chaining

-        Example (simple):
+        Raises:
+            ValueError: If an orchestrator has already been set

+        Example:

         .. code-block:: python

-                def select_next_speaker(state: GroupChatStateSnapshot) -> str | None:
-                    if state["round_index"] >= 3:
-                        return None  # Finish after 3 rounds
-                    last_speaker = state["history"][-1].speaker if state["history"] else None
-                    if last_speaker == "researcher":
-                        return "writer"
-                    return "researcher"
+                from agent_framework import GroupChatBuilder, GroupChatState
+
+
+                async def round_robin_selector(state: GroupChatState) -> str:
+                    # Simple round-robin selection among participants
+                    return state.participants[state.current_round % len(state.participants)]

                 workflow = (
                     GroupChatBuilder()
-                    .set_select_speakers_func(select_next_speaker)
-                    .participants(researcher=researcher_agent, writer=writer_agent)
+                    .with_select_speaker_func(round_robin_selector, orchestrator_name="Coordinator")
+                    .participants([agent1, agent2])
                     .build()
                 )
-
-        Example (with LLM and custom instructions):
-
-        .. code-block:: python
-
-            from agent_framework import GroupChatDirective
-
-
-            async def llm_based_selector(state: GroupChatStateSnapshot) -> GroupChatDirective | None:
-                if state["round_index"] >= 5:
-                    return GroupChatDirective(finish=True)
-
-                # Use LLM to decide next speaker and summarize conversation
-                conversation_summary = await summarize_with_llm(state["conversation"])
-                next_agent = await pick_agent_with_llm(state["participants"], state["task"])
-
-                # Pass custom instruction to the selected agent
-                return GroupChatDirective(
-                    agent_name=next_agent,
-                    instruction=f"Context summary: {conversation_summary}",
-                )
-
-
-            workflow = GroupChatBuilder().set_select_speakers_func(llm_based_selector).participants(...).build()
-
-        Note:
-            Cannot be combined with :py:meth:`GroupChatBuilder.set_manager`. Choose one orchestration strategy.
         """
-        manager_name = display_name or "manager"
-        adapter = _SpeakerSelectorAdapter(
-            selector,
-            manager_name=manager_name,
-            final_message=final_message,
-        )
-        return self._set_manager_function(adapter, display_name)
+        if self._selection_func is not None:
+            raise ValueError(
+                "A selection function has already been configured. Call with_select_speaker_func(...) at most once."
+            )
+        if self._orchestrator is not None:
+            raise ValueError(
+                "An orchestrator has already been configured. "
+                "Call only one of with_select_speaker_func(...) or with_orchestrator(...)."
+            )
+        if self._agent_orchestrator is not None:
+            raise ValueError(
+                "An agent orchestrator has already been configured. "
+                "Call only one of with_select_speaker_func(...) or with_agent_orchestrator(...)."
+            )

-    def participants(
-        self,
-        participants: Mapping[str, AgentProtocol | Executor] | Sequence[AgentProtocol | Executor] | None = None,
-        /,
-        **named_participants: AgentProtocol | Executor,
-    ) -> "GroupChatBuilder":
+        self._selection_func = selection_func
+        self._orchestrator_name = orchestrator_name
+        return self
+
+    def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "GroupChatBuilder":
        """Define participants for this group chat workflow.

        Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances.
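+        Participant names are resolved automatically: the agent's ``name`` attribute for
+        AgentProtocol instances, or the executor ``id`` for Executor instances.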
- Provide a mapping of name → participant for explicit control, or pass a sequence and - names will be inferred from the agent's name attribute (or executor id). Args: - participants: Optional mapping or sequence of participant definitions - **named_participants: Keyword arguments mapping names to agent/executor instances + participants: Sequence of participant definitions Returns: Self for fluent chaining Raises: - ValueError: If participants are empty, names are duplicated, or names are empty strings + ValueError: If participants are empty, names are duplicated, or already set + TypeError: If any participant is not AgentProtocol or Executor instance - Usage: + Example: .. code-block:: python from agent_framework import GroupChatBuilder workflow = ( - GroupChatBuilder().set_manager(manager_agent).participants([writer_agent, reviewer_agent]).build() + GroupChatBuilder() + .with_select_speaker_func(my_selection_function) + .participants([agent1, agent2, custom_executor]) + .build() ) """ - combined: dict[str, AgentProtocol | Executor] = {} - - def _add(name: str, participant: AgentProtocol | Executor) -> None: - if not name: - raise ValueError("participant names must be non-empty strings") - if name in combined or name in self._participants: - raise ValueError(f"Duplicate participant name '{name}' supplied.") - if name == self._manager_name: - raise ValueError( - f"Participant name '{name}' conflicts with manager name. " - "Manager is automatically registered as a participant." - ) - combined[name] = participant - - if participants: - if isinstance(participants, Mapping): - for name, participant in participants.items(): - _add(name, participant) + if self._participants: + raise ValueError("participants have already been set. Call participants(...) at most once.") + + if not participants: + raise ValueError("participants cannot be empty.") + + # Name of the executor mapped to participant instance + named: dict[str, AgentProtocol | Executor] = {} + for participant in participants: + if isinstance(participant, Executor): + identifier = participant.id + elif isinstance(participant, AgentProtocol): + if not participant.name: + raise ValueError("AgentProtocol participants must have a non-empty name.") + identifier = participant.name else: - for participant in participants: - inferred_name: str - if isinstance(participant, Executor): - inferred_name = participant.id - else: - name_attr = getattr(participant, "name", None) - if not name_attr: - raise ValueError( - "Agent participants supplied via sequence must define a non-empty 'name' attribute." - ) - inferred_name = str(name_attr) - _add(inferred_name, participant) - - for name, participant in named_participants.items(): - _add(name, participant) - - if not combined: - raise ValueError("participants cannot be empty") - - for name, participant in combined.items(): - self._participants[name] = participant - self._participant_metadata = None - return self - - def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "GroupChatBuilder": - """Enable checkpointing for the built workflow using the provided storage. - - Checkpointing allows the workflow to persist state and resume from interruption - points, enabling long-running conversations and failure recovery. - - Args: - checkpoint_storage: Storage implementation for persisting workflow state - - Returns: - Self for fluent chaining + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." 
+ ) - Usage: + if identifier in named: + raise ValueError(f"Duplicate participant name '{identifier}' detected") - .. code-block:: python + named[identifier] = participant - from agent_framework import GroupChatBuilder, MemoryCheckpointStorage + self._participants = named - storage = MemoryCheckpointStorage() - workflow = ( - GroupChatBuilder() - .set_manager(manager_agent) - .participants(agent1=agent1, agent2=agent2) - .with_checkpointing(storage) - .build() - ) - """ - self._checkpoint_storage = checkpoint_storage return self - def with_request_handler( - self, - handler: Callable[[_GroupChatConfig], Executor] | Executor, - *, - condition: EdgeCondition, - ) -> "GroupChatBuilder": - """Register an interceptor factory that creates executors for special requests. + def with_termination_condition(self, termination_condition: TerminationCondition) -> "GroupChatBuilder": + """Set a custom termination condition for the group chat workflow. Args: - handler: Callable that receives the wiring and returns an executor, or a pre-built executor - condition: Filter determining which orchestrator messages the interceptor should process + termination_condition: Callable that receives the conversation history and returns + True to terminate the conversation, False to continue. Returns: Self for fluent chaining - """ - factory: Callable[[_GroupChatConfig], Executor] - if isinstance(handler, Executor): - executor = handler - - def _factory(_: _GroupChatConfig) -> Executor: - return executor - - factory = _factory - else: - factory = handler - - self._interceptors.append((factory, condition)) - return self - - def with_termination_condition( - self, - condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]], - ) -> "GroupChatBuilder": - """Define a custom termination condition for the group chat workflow. - - The condition receives the full conversation (including manager and agent messages) and may be async. - When it returns True, the orchestrator halts the conversation and emits a completion message authored - by the manager. Example: @@ -1756,324 +752,179 @@ def stop_after_two_calls(conversation: list[ChatMessage]) -> bool: specialist_agent = ... workflow = ( GroupChatBuilder() - .set_select_speakers_func(lambda _: "specialist") - .participants(specialist=specialist_agent) + .with_select_speaker_func(my_selection_function) + .participants([agent1, specialist_agent]) .with_termination_condition(stop_after_two_calls) .build() ) """ - self._termination_condition = condition + if self._orchestrator is not None: + logger.warning( + "Orchestrator has already been configured; setting termination condition on builder has no effect." + ) + + self._termination_condition = termination_condition return self def with_max_rounds(self, max_rounds: int | None) -> "GroupChatBuilder": - """Set a maximum number of manager rounds to prevent infinite conversations. + """Set a maximum number of orchestrator rounds to prevent infinite conversations. When the round limit is reached, the workflow automatically completes with a default completion message. Setting to None allows unlimited rounds. Args: - max_rounds: Maximum number of manager selection rounds, or None for unlimited + max_rounds: Maximum number of orchestrator selection rounds, or None for unlimited Returns: Self for fluent chaining - - Usage: - - .. 
code-block:: python - - from agent_framework import GroupChatBuilder - - # Limit to 15 rounds - workflow = ( - GroupChatBuilder() - .set_manager(manager_agent) - .participants(agent1=agent1, agent2=agent2) - .with_max_rounds(15) - .build() - ) - - # Unlimited rounds - workflow = ( - GroupChatBuilder().set_manager(manager_agent).participants(agent1=agent1).with_max_rounds(None).build() - ) """ self._max_rounds = max_rounds return self - def with_request_info( - self, - *, - agents: Sequence[str | AgentProtocol | Executor] | None = None, - ) -> "GroupChatBuilder": - """Enable request info before participants run in the workflow. + def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "GroupChatBuilder": + """Enable checkpointing for the built workflow using the provided storage. - When enabled, the workflow pauses before each participant runs, emitting - a RequestInfoEvent that allows the caller to review the conversation and - optionally inject guidance before the participant responds. The caller provides - input via the standard response_handler/request_info pattern. + Checkpointing allows the workflow to persist state and resume from interruption + points, enabling long-running conversations and failure recovery. Args: - agents: Optional filter - only pause before these specific agents/executors. - Accepts agent names (str), agent instances, or executor instances. - If None (default), pauses before every participant. + checkpoint_storage: Storage implementation for persisting workflow state Returns: - self: The builder instance for fluent chaining. + Self for fluent chaining Example: .. code-block:: python - # Pause before all participants - workflow = ( - GroupChatBuilder() - .set_manager(manager) - .participants([optimist, pragmatist, creative]) - .with_request_info() - .build() - ) + from agent_framework import GroupChatBuilder, MemoryCheckpointStorage - # Pause only before specific participants + storage = MemoryCheckpointStorage() workflow = ( GroupChatBuilder() - .set_manager(manager) - .participants([optimist, pragmatist, creative]) - .with_request_info(agents=[pragmatist]) # Only pause before pragmatist + .with_select_speaker_func(my_selection_function) + .participants([agent1, agent2]) + .with_checkpointing(storage) .build() ) """ - from ._orchestration_request_info import resolve_request_info_filter - - self._request_info_enabled = True - self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) + self._checkpoint_storage = checkpoint_storage return self - def _get_participant_metadata(self) -> dict[str, Any]: - if self._participant_metadata is None: - self._participant_metadata = prepare_participant_metadata( - self._participants, - executor_id_factory=lambda name, participant: ( - participant.id if isinstance(participant, Executor) else f"groupchat_agent:{name}" - ), - description_factory=lambda name, participant: ( - participant.id if isinstance(participant, Executor) else participant.__class__.__name__ - ), - ) - return self._participant_metadata - - def _build_participant_specs(self) -> dict[str, GroupChatParticipantSpec]: - metadata = self._get_participant_metadata() - descriptions: Mapping[str, str] = metadata["descriptions"] - specs: dict[str, GroupChatParticipantSpec] = {} - for name, participant in self._participants.items(): - specs[name] = GroupChatParticipantSpec( - name=name, - participant=participant, - description=descriptions[name], - ) - return specs + def with_request_info(self, *, agents: Sequence[str | 
AgentProtocol] | None = None) -> "GroupChatBuilder":
+        """Enable request info after agent participant responses.

-    def build(self) -> Workflow:
-        """Build and validate the group chat workflow.
+        This enables human-in-the-loop (HIL) scenarios for the group chat orchestration.
+        When enabled, the workflow pauses after each agent participant runs, emitting
+        a RequestInfoEvent that allows the caller to review the conversation and optionally
+        inject guidance for the agent participant to iterate. The caller provides input via
+        the standard response_handler/request_info pattern.

-        Assembles the orchestrator, participants, and their interconnections into
-        a complete workflow graph. The orchestrator delegates speaker selection to
-        the manager, routes requests to the appropriate participants, and collects
-        their responses to continue or complete the conversation.
+        Simulated flow with HIL:
+            Input -> Orchestrator -> [Participant <-> Request Info] -> Orchestrator -> [Participant <-> Request Info] -> ...

-        Returns:
-            Validated Workflow instance ready for execution
+        Note: This is only available for agent participants. Executor participants can incorporate
+        request info handling in their own implementation if desired.

-        Raises:
-            ValueError: If manager or participants are not configured (when using default factory)
+        Args:
+            agents: Optional list of agent names to enable request info for.
+                If None, enables HIL for all agent participants.

-        Wiring pattern:
-        - Orchestrator receives initial input (str, ChatMessage, or list[ChatMessage])
-        - Orchestrator queries manager for next action (participant selection or finish)
-        - If participant selected: request routed directly to participant entry node
-        - Participant pipeline: AgentExecutor for agents or custom executor chains
-        - Participant response flows back to orchestrator
-        - Orchestrator updates state and queries manager again
-        - When manager returns finish directive: orchestrator yields final message and becomes idle
+        Returns:
+            Self for fluent chaining
+        """
+        from ._orchestration_request_info import resolve_request_info_filter

-        Usage:
+        self._request_info_enabled = True
+        self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None)

-        .. code-block:: python
+        return self

-            from agent_framework import GroupChatBuilder
+    def _resolve_orchestrator(self, participants: Sequence[Executor]) -> Executor:
+        """Determine the orchestrator to use for the workflow.

-            # Execute the workflow
-            workflow = GroupChatBuilder().set_manager(manager_agent).participants(agent1=agent1, agent2=agent2).build()
-            async for message in workflow.run("Solve this problem collaboratively"):
-                print(message.text)
+        Args:
+            participants: List of resolved participant executors
         """
-        # Manager is only required when using the default orchestrator factory
-        # Custom factories (e.g., MagenticBuilder) provide their own orchestrator with embedded manager
-        if (
-            self._manager is None
-            and self._manager_participant is None
-            and self._orchestrator_factory == _default_orchestrator_factory
-        ):
+        if self._orchestrator is not None:
+            return self._orchestrator
+
+        if self._agent_orchestrator is not None and self._selection_func is not None:
             raise ValueError(
-                "manager must be configured before build() when using default orchestrator. "
-                "Call set_manager(...) or set_select_speakers_func(...) before build()."
+                "Both agent-based orchestrator and selection function are configured; only one can be used at a time."
) - if not self._participants: - raise ValueError("participants must be configured before build()") - - metadata = self._get_participant_metadata() - participant_specs = self._build_participant_specs() - wiring = _GroupChatConfig( - manager=self._manager, - manager_participant=self._manager_participant, - manager_name=self._manager_name, - participants=participant_specs, - max_rounds=self._max_rounds, - termination_condition=self._termination_condition, - participant_aliases=metadata["aliases"], - participant_executors=metadata["executors"], - ) - # Determine participant factory - wrap if request info is enabled - participant_factory = self._participant_factory - if self._request_info_enabled: - # Create a wrapper factory that adds request info interceptor before each participant - base_factory = participant_factory - agent_filter = self._request_info_filter - - def _factory_with_request_info( - spec: GroupChatParticipantSpec, - config: _GroupChatConfig, - ) -> _GroupChatParticipantPipeline: - pipeline = list(base_factory(spec, config)) - if pipeline: - # Add interceptor executor BEFORE the participant (prepend) - interceptor = RequestInfoInterceptor( - executor_id=f"request_info:{spec.name}", - agent_filter=agent_filter, - ) - pipeline.insert(0, interceptor) - return tuple(pipeline) + if self._selection_func is not None: + return GroupChatOrchestrator( + id=self.DEFAULT_ORCHESTRATOR_ID, + participant_registry=ParticipantRegistry(participants), + selection_func=self._selection_func, + name=self._orchestrator_name, + max_rounds=self._max_rounds, + termination_condition=self._termination_condition, + ) - participant_factory = _factory_with_request_info + if self._agent_orchestrator is not None: + return AgentBasedGroupChatOrchestrator( + agent=self._agent_orchestrator, + participant_registry=ParticipantRegistry(participants), + max_rounds=self._max_rounds, + termination_condition=self._termination_condition, + ) - result = assemble_group_chat_workflow( - wiring=wiring, - participant_factory=participant_factory, - orchestrator_factory=self._orchestrator_factory, - interceptors=self._interceptors, - checkpoint_storage=self._checkpoint_storage, + raise RuntimeError( + "Orchestrator could not be resolved. Please provide one via with_orchestrator(), " + "with_agent_orchestrator(), or with_select_speaker_func()." ) - if not isinstance(result, Workflow): - raise TypeError("Expected Workflow from assemble_group_chat_workflow") - return result - - -# endregion + def _resolve_participants(self) -> list[Executor]: + """Resolve participant instances into Executor objects.""" + executors: list[Executor] = [] + for participant in self._participants.values(): + if isinstance(participant, Executor): + executors.append(participant) + elif isinstance(participant, AgentProtocol): + if self._request_info_enabled and ( + not self._request_info_filter or resolve_agent_id(participant) in self._request_info_filter + ): + # Handle request info enabled agents + executors.append(AgentApprovalExecutor(participant)) + else: + executors.append(AgentExecutor(participant)) + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." + ) -# region Default manager implementation - - -DEFAULT_MANAGER_INSTRUCTIONS = """You are coordinating a team conversation to solve the user's task. -Your role is to orchestrate collaboration between multiple participants by selecting who speaks next. 
-Leverage each participant's unique expertise as described in their descriptions. -Have participants build on each other's contributions - earlier participants gather information, -later ones refine and synthesize. -Only finish the task after multiple relevant participants have contributed their expertise.""" - -DEFAULT_MANAGER_STRUCTURED_OUTPUT_PROMPT = """Return your decision using the following structure: -- next_agent: name of the participant who should act next (use null when finish is true) -- message: instruction for that participant (empty string if not needed) -- finish: boolean indicating if the task is complete -- final_response: when finish is true, provide the final answer to the user""" + return executors + def build(self) -> Workflow: + """Build and validate the group chat workflow. -class ManagerDirectiveModel(BaseModel): - """Pydantic model for structured manager directive output.""" - - next_agent: str | None = Field( - default=None, - description="Name of the participant who should act next (null when finish is true)", - ) - message: str = Field( - default="", - description="Instruction for the selected participant", - ) - finish: bool = Field( - default=False, - description="Whether the task is complete", - ) - final_response: str | None = Field( - default=None, - description="Final answer to the user when finish is true", - ) + Assembles the orchestrator and participants into a complete workflow graph. + The workflow graph consists of bi-directional edges between the orchestrator and each participant, + allowing for message exchanges in both directions. + Returns: + Validated Workflow instance ready for execution + """ + if not self._participants: + raise ValueError("participants must be configured before build()") -class _SpeakerSelectorAdapter: - """Adapter that turns a simple speaker selector into a full manager directive.""" + # Resolve orchestrator and participants to executors + participants: list[Executor] = self._resolve_participants() + orchestrator: Executor = self._resolve_orchestrator(participants) - def __init__( - self, - selector: Callable[[GroupChatStateSnapshot], Awaitable[Any]] | Callable[[GroupChatStateSnapshot], Any], - *, - manager_name: str, - final_message: ChatMessage | str | Callable[[GroupChatStateSnapshot], Any] | None = None, - ) -> None: - self._selector = selector - self._manager_name = manager_name - self._final_message = final_message - self.name = manager_name - - async def __call__(self, state: GroupChatStateSnapshot) -> GroupChatDirective: - result = await _maybe_await(self._selector(state)) - if result is None: - message = await self._resolve_final_message(state) - return GroupChatDirective(finish=True, final_message=message) - - if isinstance(result, Sequence) and not isinstance(result, (str, bytes, bytearray)): - if not result: - message = await self._resolve_final_message(state) - return GroupChatDirective(finish=True, final_message=message) - if len(result) != 1: # type: ignore[arg-type] - raise ValueError("Speaker selector must return a single participant name or None.") - first_item = result[0] # type: ignore[index] - if not isinstance(first_item, str): - raise TypeError("Speaker selector must return a participant name (str) or None.") - result = first_item - - if not isinstance(result, str): - raise TypeError("Speaker selector must return a participant name (str) or None.") - - return GroupChatDirective(agent_name=result) - - async def _resolve_final_message(self, state: GroupChatStateSnapshot) -> ChatMessage: - final_message 
= self._final_message
-        if callable(final_message):
-            value = await _maybe_await(final_message(state))
-        else:
-            value = final_message
-
-        if value is None:
-            message = ChatMessage(
-                role=Role.ASSISTANT,
-                text="Conversation completed.",
-                author_name=self._manager_name,
-            )
-        elif isinstance(value, ChatMessage):
-            message = value
-        else:
-            message = ChatMessage(
-                role=Role.ASSISTANT,
-                text=str(value),
-                author_name=self._manager_name,
-            )
+        # Build workflow graph
+        workflow_builder = WorkflowBuilder().set_start_executor(orchestrator)
+        for participant in participants:
+            # Orchestrator and participant bi-directional edges
+            workflow_builder = workflow_builder.add_edge(orchestrator, participant)
+            workflow_builder = workflow_builder.add_edge(participant, orchestrator)
+        if self._checkpoint_storage is not None:
+            workflow_builder = workflow_builder.with_checkpointing(self._checkpoint_storage)

-        if not message.author_name:
-            patch = message.to_dict()
-            patch["author_name"] = self._manager_name
-            message = ChatMessage.from_dict(patch)
-        return message
+        return workflow_builder.build()


 # endregion
diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py
index a3f803ae87..e2ef233ee4 100644
--- a/python/packages/core/agent_framework/_workflows/_handoff.py
+++ b/python/packages/core/agent_framework/_workflows/_handoff.py
@@ -2,58 +2,53 @@

 """High-level builder for conversational handoff workflows.

-The handoff pattern models a coordinator agent that optionally routes
-control to specialist agents before handing the conversation back to the user.
-The flow is intentionally cyclical by default:
+The handoff pattern models a group of agents that can intelligently route
+control to other agents based on the conversation context.

-    user input -> coordinator -> optional specialist -> request user input -> ...
+The flow is typically:

-An autonomous interaction mode can bypass the user input request and continue routing
-responses back to agents until a handoff occurs or termination criteria are met.
+    user input -> Agent A -> Agent B -> Agent C -> Agent A -> ... -> output
+
+Depending on whether request info is enabled, the flow may include user input (except when an agent hands off):
+
+    user input -> [Agent A -> Request info] -> [Agent B -> Request info] -> [Agent C -> Request info] -> ... -> output
+
+The difference between a group chat workflow and a handoff workflow is that in group chat there is
+always an orchestrator that decides who speaks next, while in handoff the agents themselves decide
+whom to hand off to next by invoking a tool call that names the target agent.
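+
+For example (an illustrative sketch; the agent name is hypothetical), an agent configured with a
+handoff to a "billing" agent transfers control by invoking the synthetic tool generated for it:
+
+    handoff_to_billing(context="Customer is asking about a refund")
+
+The tool call is intercepted by framework middleware and never runs as a real function; it only
+signals which agent receives control next.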
+
+Group Chat: centralized orchestration of multiple agents
+Handoff: decentralized routing by agents themselves

 Key properties:

 - The entire conversation is maintained and reused on every hop
-- The coordinator signals a handoff by invoking a tool call that names the specialist
+- Agents signal handoffs by invoking a tool call that names the target agent
 - In human_in_loop mode (default), the workflow requests user input after each
   agent response that doesn't trigger a handoff
 - In autonomous mode, agents continue responding until they invoke a handoff tool
   or reach a termination condition or turn limit
 """

+import inspect
 import logging
-import re
 import sys
 from collections.abc import Awaitable, Callable, Mapping, Sequence
-from dataclasses import dataclass, field
-from typing import Any, Literal
-
-from agent_framework import (
-    AgentProtocol,
-    AgentRunResponse,
-    AIFunction,
-    ChatMessage,
-    FunctionApprovalRequestContent,
-    FunctionCallContent,
-    FunctionResultContent,
-    Role,
-    ai_function,
-)
-
-from .._agents import ChatAgent
+from dataclasses import dataclass
+from typing import Any, cast
+
+from typing_extensions import Never
+
+from .._agents import AgentProtocol, ChatAgent
 from .._middleware import FunctionInvocationContext, FunctionMiddleware
+from .._threads import AgentThread
+from .._tools import AIFunction, ai_function
+from .._types import AgentRunResponse, ChatMessage, Role
 from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
-from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator
+from ._agent_utils import resolve_agent_id
+from ._base_group_chat_orchestrator import TerminationCondition
 from ._checkpoint import CheckpointStorage
-from ._executor import Executor, handler
-from ._group_chat import (
-    _default_participant_factory,  # type: ignore[reportPrivateUsage]
-    _GroupChatConfig,  # type: ignore[reportPrivateUsage]
-    _GroupChatParticipantPipeline,  # type: ignore[reportPrivateUsage]
-    assemble_group_chat_workflow,
-)
-from ._orchestration_request_info import RequestInfoInterceptor
+from ._events import WorkflowEvent
 from ._orchestrator_helpers import clean_conversation_for_handoff
-from ._participant_utils import GroupChatParticipantSpec, prepare_participant_metadata, sanitize_identifier
 from ._request_info_mixin import response_handler
 from ._workflow import Workflow
 from ._workflow_builder import WorkflowBuilder
@@ -67,144 +62,75 @@

 logger = logging.getLogger(__name__)

-_HANDOFF_TOOL_PATTERN = re.compile(r"(?:handoff|transfer)[_\s-]*to[_\s-]*(?P<target>[\w-]+)", re.IGNORECASE)
-_DEFAULT_AUTONOMOUS_TURN_LIMIT = 50
+# region Handoff events


+class HandoffSentEvent(WorkflowEvent):
+    """Event emitted when an agent hands off the conversation to another agent."""

-def _create_handoff_tool(alias: str, description: str | None = None) -> AIFunction[Any, Any]:
-    """Construct the synthetic handoff tool that signals routing to `alias`."""
-    sanitized = sanitize_identifier(alias)
-    tool_name = f"handoff_to_{sanitized}"
-    doc = description or f"Handoff to the {alias} agent."
-
-    # Note: approval_mode is intentionally NOT set for handoff tools.
-    # Handoff tools are framework-internal signals that trigger routing logic,
-    # not actual function executions. They are automatically intercepted by
-    # _AutoHandoffMiddleware which short-circuits execution and provides synthetic
-    # results, so the function body never actually runs in practice.
- @ai_function(name=tool_name, description=doc) - def _handoff_tool(context: str | None = None) -> str: - """Return a deterministic acknowledgement that encodes the target alias.""" - return f"Handoff to {alias}" - - return _handoff_tool - - -def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: - """Produce a deep copy of the ChatAgent while preserving runtime configuration.""" - options = agent.default_options - middleware = list(agent.middleware or []) - - # Reconstruct the original tools list by combining regular tools with MCP tools. - # ChatAgent.__init__ separates MCP tools into _local_mcp_tools during initialization, - # so we need to recombine them here to pass the complete tools list to the constructor. - # This makes sure MCP tools are preserved when cloning agents for handoff workflows. - tools_from_options = options.get("tools") - all_tools = list(tools_from_options) if tools_from_options else [] - if agent._local_mcp_tools: # type: ignore - all_tools.extend(agent._local_mcp_tools) # type: ignore - - logit_bias = options.get("logit_bias") - metadata = options.get("metadata") - - return ChatAgent( - chat_client=agent.chat_client, - instructions=options.get("instructions"), - id=agent.id, - name=agent.name, - description=agent.description, - chat_message_store_factory=agent.chat_message_store_factory, - context_provider=agent.context_provider, - middleware=middleware, - # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. - allow_multiple_tool_calls=False, - frequency_penalty=options.get("frequency_penalty"), - logit_bias=dict(logit_bias) if logit_bias else None, - max_tokens=options.get("max_tokens"), - metadata=dict(metadata) if metadata else None, - model_id=options.get("model_id"), - presence_penalty=options.get("presence_penalty"), - response_format=options.get("response_format"), - seed=options.get("seed"), - stop=options.get("stop"), - store=options.get("store"), - temperature=options.get("temperature"), - tool_choice=options.get("tool_choice"), # type: ignore[arg-type] - tools=all_tools if all_tools else None, - top_p=options.get("top_p"), - user=options.get("user"), - ) + def __init__(self, source: str, target: str, data: Any | None = None) -> None: + """Initialize handoff sent event. + Args: + source: Identifier of the source agent initiating the handoff + target: Identifier of the target agent receiving the handoff + data: Optional event-specific data + """ + super().__init__(data) + self.source = source + self.target = target -@dataclass -class HandoffUserInputRequest: - """Request message emitted when the workflow needs fresh user input. - Note: The conversation field is intentionally excluded from checkpoint serialization - to prevent duplication. The conversation is preserved in the coordinator's state - and will be reconstructed on restore. See issue #2667. - """ +# endregion - conversation: list[ChatMessage] - awaiting_agent_id: str - prompt: str - source_executor_id: str - def to_dict(self) -> dict[str, Any]: - """Serialize to dict, excluding conversation to prevent checkpoint duplication. +@dataclass +class HandoffConfiguration: + """Configuration for handoff routing between agents. - The conversation is already preserved in the workflow coordinator's state. - Including it here would cause duplicate messages when restoring from checkpoint. 
-        """
-        return {
-            "awaiting_agent_id": self.awaiting_agent_id,
-            "prompt": self.prompt,
-            "source_executor_id": self.source_executor_id,
-        }
+    Attributes:
+        target_id: Identifier of the target agent to hand off to
+        description: Optional human-readable description of the handoff
+    """

-    @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "HandoffUserInputRequest":
-        """Deserialize from dict, initializing conversation as empty.
+    target_id: str
+    description: str | None = None

-        The conversation will be reconstructed from the coordinator's state on restore.
+    def __init__(self, *, target: str | AgentProtocol, description: str | None = None) -> None:
+        """Initialize HandoffConfiguration.
+
+        Args:
+            target: Target agent identifier or AgentProtocol instance
+            description: Optional human-readable description of the handoff
         """
-        return cls(
-            conversation=[],
-            awaiting_agent_id=data["awaiting_agent_id"],
-            prompt=data["prompt"],
-            source_executor_id=data["source_executor_id"],
-        )
+        self.target_id = resolve_agent_id(target) if isinstance(target, AgentProtocol) else target
+        self.description = description

+    def __eq__(self, other: Any) -> bool:
+        """Determine equality based on target_id."""
+        if not isinstance(other, HandoffConfiguration):
+            return False

-@dataclass
-class _ConversationWithUserInput:
-    """Internal message carrying full conversation + new user messages from gateway to coordinator.
+        return self.target_id == other.target_id

-    Attributes:
-        full_conversation: The conversation messages to process.
-        is_post_restore: If True, indicates this message was created after a checkpoint restore.
-            The coordinator should append these messages to its existing conversation rather
-            than replacing it. This prevents duplicate messages (see issue #2667).
-    """
+    def __hash__(self) -> int:
+        """Compute hash based on target_id."""
+        return hash(self.target_id)

-    full_conversation: list[ChatMessage] = field(default_factory=lambda: [])  # type: ignore[misc]
-    is_post_restore: bool = False

+def get_handoff_tool_name(target_id: str) -> str:
+    """Get the standardized handoff tool name for a given target agent ID."""
+    return f"handoff_to_{target_id}"

-@dataclass
-class _ConversationForUserInput:
-    """Internal message from coordinator to gateway specifying which agent will receive the response."""

-    conversation: list[ChatMessage]
-    next_agent_id: str
+HANDOFF_FUNCTION_RESULT_KEY = "handoff_to"


 class _AutoHandoffMiddleware(FunctionMiddleware):
     """Intercept handoff tool invocations and short-circuit execution with synthetic results."""

-    def __init__(self, handoff_targets: Mapping[str, str]) -> None:
+    def __init__(self, handoffs: Sequence[HandoffConfiguration]) -> None:
         """Initialise middleware with the mapping from tool name to specialist id."""
-        self._targets = {name.lower(): target for name, target in handoff_targets.items()}
+        self._handoff_functions = {get_handoff_tool_name(handoff.target_id): handoff.target_id for handoff in handoffs}

     async def process(
         self,
@@ -212,782 +138,500 @@ async def process(
         next: Callable[[FunctionInvocationContext], Awaitable[None]],
     ) -> None:
         """Intercept matching handoff tool calls and inject synthetic results."""
-        name = getattr(context.function, "name", "")
-        normalized = name.lower() if name else ""
-        target = self._targets.get(normalized)
-        if target is None:
+        if context.function.name not in self._handoff_functions:
             await next(context)
             return

         # Short-circuit execution and provide deterministic response payload for the tool call.
-        context.result = {"handoff_to": target}
+        context.result = {HANDOFF_FUNCTION_RESULT_KEY: self._handoff_functions[context.function.name]}
         context.terminate = True


-class _InputToConversation(Executor):
-    """Normalizes initial workflow input into a list[ChatMessage]."""
+@dataclass
+class HandoffAgentUserRequest:
+    """Request issued to the user after an agent run in a handoff workflow.
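+
+    Respond using :py:meth:`HandoffAgentUserRequest.create_response` to continue the conversation
+    with fresh user input, or :py:meth:`HandoffAgentUserRequest.terminate` to end the workflow.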
+
+    Attributes:
+        agent_response: The response generated by the agent at the most recent turn
+    """

-    @handler
-    async def from_str(self, prompt: str, ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        """Convert a raw user prompt into a conversation containing a single user message."""
-        await ctx.send_message([ChatMessage(Role.USER, text=prompt)])
+    agent_response: AgentRunResponse
+
+    @staticmethod
+    def create_response(response: str | list[str] | ChatMessage | list[ChatMessage]) -> list[ChatMessage]:
+        """Normalize a user response into the list of ChatMessage instances expected by the workflow."""
+        messages: list[ChatMessage] = []
+        if isinstance(response, str):
+            messages.append(ChatMessage(role=Role.USER, text=response))
+        elif isinstance(response, ChatMessage):
+            messages.append(response)
+        elif isinstance(response, list):
+            for item in response:
+                if isinstance(item, ChatMessage):
+                    messages.append(item)
+                elif isinstance(item, str):
+                    messages.append(ChatMessage(role=Role.USER, text=item))
+                else:
+                    raise TypeError("List items must be either str or ChatMessage instances")
+        else:
+            raise TypeError("Response must be str, list of str, ChatMessage, or list of ChatMessage")

-    @handler
-    async def from_message(self, message: ChatMessage, ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        """Pass through an existing chat message as the initial conversation."""
-        await ctx.send_message([message])
+        return messages

-    @handler
-    async def from_messages(self, messages: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        """Forward a list of chat messages as the starting conversation history."""
-        await ctx.send_message(list(messages))
+    @staticmethod
+    def terminate() -> list[ChatMessage]:
+        """Create a termination response for the handoff workflow."""
+        return []


-@dataclass
-class _HandoffResolution:
-    """Result of handoff detection containing the target alias and originating call."""
-
-    target: str
-    function_call: FunctionCallContent | None = None
-
-
-def _resolve_handoff_target(agent_response: AgentRunResponse) -> _HandoffResolution | None:
-    """Detect handoff intent from tool call metadata."""
-    for message in agent_response.messages:
-        resolution = _resolution_from_message(message)
-        if resolution:
-            return resolution
-
-    for request in agent_response.user_input_requests:
-        if isinstance(request, FunctionApprovalRequestContent):
-            resolution = _resolution_from_function_call(request.function_call)
-            if resolution:
-                return resolution
-
-    return None
-
-
-def _resolution_from_message(message: ChatMessage) -> _HandoffResolution | None:
-    """Inspect an assistant message for embedded handoff tool metadata."""
-    for content in getattr(message, "contents", ()):
-        if isinstance(content, FunctionApprovalRequestContent):
-            resolution = _resolution_from_function_call(content.function_call)
-            if resolution:
-                return resolution
-        elif isinstance(content, FunctionCallContent):
-            resolution = _resolution_from_function_call(content)
-            if resolution:
-                return resolution
-    return None
-
-
-def _resolution_from_function_call(function_call: FunctionCallContent | None) -> _HandoffResolution | None:
-    """Wrap the target resolved from a function call in a `_HandoffResolution`."""
-    if function_call is None:
-        return None
-    target = _target_from_function_call(function_call)
-    if not target:
-        return None
-    return _HandoffResolution(target=target, function_call=function_call)
-
-
-def _target_from_function_call(function_call: FunctionCallContent) -> str | None:
-    """Extract the handoff target from the tool name
or structured arguments.""" - name_candidate = _target_from_tool_name(function_call.name) - if name_candidate: - return name_candidate - - arguments = function_call.parse_arguments() - if isinstance(arguments, Mapping): - value = arguments.get("handoff_to") - if isinstance(value, str) and value.strip(): - return value.strip() - elif isinstance(arguments, str): - stripped = arguments.strip() - if stripped: - name_candidate = _target_from_tool_name(stripped) - if name_candidate: - return name_candidate - return stripped - - return None - - -def _target_from_tool_name(name: str | None) -> str | None: - """Parse the specialist alias encoded in a handoff tool's name.""" - if not name: - return None - match = _HANDOFF_TOOL_PATTERN.search(name) - if match: - parsed = match.group("target").strip() - if parsed: - return parsed - return None +# In autonomous mode, the agent continues responding until it requests a handoff +# or reaches a turn limit, after which it requests user input to continue. +_AUTONOMOUS_MODE_DEFAULT_PROMPT = "User did not respond. Continue assisting autonomously." +_DEFAULT_AUTONOMOUS_TURN_LIMIT = 50 + +# region Handoff Agent Executor -class _HandoffCoordinator(BaseGroupChatOrchestrator): - """Coordinates agent-to-agent transfers and user turn requests.""" +class HandoffAgentExecutor(AgentExecutor): + """Specialized AgentExecutor that supports handoff tool interception.""" def __init__( self, + agent: AgentProtocol, + handoffs: Sequence[HandoffConfiguration], *, - starting_agent_id: str, - specialist_ids: Mapping[str, str], - input_gateway_id: str, - termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]], - id: str, - handoff_tool_targets: Mapping[str, str] | None = None, - return_to_previous: bool = False, - interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop", - autonomous_turn_limit: int | None = None, + agent_thread: AgentThread | None = None, + is_start_agent: bool = False, + termination_condition: TerminationCondition | None = None, + autonomous_mode: bool = False, + autonomous_mode_prompt: str | None = None, + autonomous_mode_turn_limit: int | None = None, ) -> None: - """Create a coordinator that manages routing between specialists and the user.""" - super().__init__(id) - self._starting_agent_id = starting_agent_id - self._specialist_by_alias = dict(specialist_ids) - self._specialist_ids = set(specialist_ids.values()) - self._input_gateway_id = input_gateway_id - self._termination_condition = termination_condition - self._handoff_tool_targets = {k.lower(): v for k, v in (handoff_tool_targets or {}).items()} - self._return_to_previous = return_to_previous - self._current_agent_id: str | None = None # Track the current agent handling conversation - self._interaction_mode = interaction_mode - self._autonomous_turn_limit = autonomous_turn_limit - self._autonomous_turns = 0 + """Initialize the HandoffAgentExecutor. - def _get_author_name(self) -> str: - """Get the coordinator name for orchestrator-generated messages.""" - return "handoff_coordinator" + Args: + agent: The agent to execute + handoffs: Sequence of handoff configurations defining target agents + agent_thread: Optional AgentThread that manages the agent's execution context + is_start_agent: Whether this agent is the starting agent in the handoff workflow. + There can only be one starting agent in a handoff workflow. 
+            termination_condition: Optional callable that determines when to terminate the workflow
+            autonomous_mode: Whether the agent should continue running autonomously, without
+                requesting user input, after a response that does not trigger a handoff,
+                until the turn limit is reached. This allows the agent to perform long-running
+                tasks (e.g., research, coding, analysis) without prematurely returning
+                control to the coordinator or user.
+            autonomous_mode_prompt: Prompt to provide to the agent when continuing in autonomous mode.
+                This will guide the agent in the absence of user input.
+            autonomous_mode_turn_limit: Maximum number of autonomous turns before requesting user input.
+        """
+        cloned_agent = self._prepare_agent_with_handoffs(agent, handoffs)
+        super().__init__(cloned_agent, agent_thread=agent_thread)

-    def _extract_agent_id_from_source(self, source: str | None) -> str | None:
-        """Extract the original agent ID from the source executor ID.
+        self._handoff_targets = {handoff.target_id for handoff in handoffs}
+        self._termination_condition = termination_condition
+        self._is_start_agent = is_start_agent

-        When a request info interceptor is in the pipeline, the source will be
-        like 'request_info:agent_name'. This method extracts the
-        actual agent ID.
+        # Autonomous mode members
+        self._autonomous_mode = autonomous_mode
+        self._autonomous_mode_prompt = autonomous_mode_prompt or _AUTONOMOUS_MODE_DEFAULT_PROMPT
+        self._autonomous_mode_turn_limit = autonomous_mode_turn_limit or _DEFAULT_AUTONOMOUS_TURN_LIMIT
+        self._autonomous_mode_turns = 0
+
+    def _prepare_agent_with_handoffs(
+        self,
+        agent: AgentProtocol,
+        handoffs: Sequence[HandoffConfiguration],
+    ) -> AgentProtocol:
+        """Prepare an agent by adding handoff tools for the specified target agents.

         Args:
-            source: The source executor ID from the workflow context
+            agent: The agent to prepare
+            handoffs: Sequence of handoff configurations defining target agents

         Returns:
-            The actual agent ID, or the original source if not an interceptor
+            A cloned agent instance with handoff tools and interception middleware added
         """
-        if source is None:
-            return None
-        if source.startswith("request_info:"):
-            return source[len("request_info:") :]
-        # TODO(@moonbox3): Remove legacy prefix support in a separate PR (GA cleanup)
-        if source.startswith("human_review:"):
-            return source[len("human_review:") :]
-        if source.startswith("human_input_interceptor:"):
-            return source[len("human_input_interceptor:") :]
-        return source
-
-    @handler
-    async def handle_agent_response(
-        self,
-        response: AgentExecutorResponse,
-        ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], list[ChatMessage] | _ConversationForUserInput],
-    ) -> None:
-        """Process an agent's response and determine whether to route, request input, or terminate."""
-        raw_source = ctx.get_source_executor_id()
-        source = self._extract_agent_id_from_source(raw_source)
-        is_starting_agent = source == self._starting_agent_id
-
-        # On first turn of a run, conversation is empty
-        # Track new messages only, build authoritative history incrementally
-        conversation_msgs = self._get_conversation()
-        if not conversation_msgs:
-            # First response from starting agent - initialize with authoritative conversation snapshot
-            # Keep the FULL conversation including tool calls (OpenAI SDK default behavior)
-            full_conv = self._conversation_from_response(response)
-            self._conversation = list(full_conv)
-        else:
-            # Subsequent responses - append only new messages from this agent
-            # Keep ALL messages including tool calls to maintain complete history.
- # This includes assistant messages with function calls and tool role messages with results. - new_messages = response.agent_run_response.messages or [] - self._conversation.extend(new_messages) - - self._apply_response_metadata(self._conversation, response.agent_run_response) - - conversation = list(self._conversation) - - # Check for handoff from ANY agent (starting agent or specialist) - target = self._resolve_specialist(response.agent_run_response, conversation) - if target is not None: - # Update current agent when handoff occurs - self._current_agent_id = target - self._autonomous_turns = 0 - logger.info(f"Handoff detected: {source} -> {target}. Routing control to specialist '{target}'.") - - # Clean tool-related content before sending to next agent - cleaned = clean_conversation_for_handoff(conversation) - request = AgentExecutorRequest(messages=cleaned, should_respond=True) - await ctx.send_message(request, target_id=target) - return - - # No handoff detected - response must come from starting agent or known specialist - if not is_starting_agent and source not in self._specialist_ids: - raise RuntimeError(f"HandoffCoordinator received response from unknown executor '{source}'.") - - # Update current agent when they respond without handoff - self._current_agent_id = source - if await self._check_termination(): - # Clean the output conversation for display - cleaned_output = clean_conversation_for_handoff(conversation) - await ctx.yield_output(cleaned_output) - return - - if self._interaction_mode == "autonomous": - self._autonomous_turns += 1 - if self._autonomous_turn_limit is not None and self._autonomous_turns >= self._autonomous_turn_limit: - logger.info( - f"Autonomous turn limit reached ({self._autonomous_turn_limit}). " - "Yielding conversation and stopping." - ) - cleaned_output = clean_conversation_for_handoff(conversation) - await ctx.yield_output(cleaned_output) - return - - # In autonomous mode, agents continue iterating until they invoke a handoff tool - logger.info( - f"Agent '{source}' responded without handoff (turn {self._autonomous_turns}). " - "Continuing autonomous execution." + if not isinstance(agent, ChatAgent): + raise TypeError( + "Handoff can only be applied to ChatAgent. Please ensure the agent is a ChatAgent instance." ) - cleaned = clean_conversation_for_handoff(conversation) - request = AgentExecutorRequest(messages=cleaned, should_respond=True) - await ctx.send_message(request, target_id=source) - return - logger.info( - f"Agent '{source}' responded without handoff. " - f"Requesting user input. Return-to-previous: {self._return_to_previous}" + # Clone the agent to avoid mutating the original + cloned_agent = self._clone_chat_agent(agent) # type: ignore + # Add handoff tools to the cloned agent + self._apply_auto_tools(cloned_agent, handoffs) + # Add middleware to handle handoff tool invocations + middleware = _AutoHandoffMiddleware(handoffs) + existing_middleware = list(cloned_agent.middleware or []) + existing_middleware.append(middleware) + cloned_agent.middleware = existing_middleware + + return cloned_agent + + def _clone_chat_agent(self, agent: ChatAgent) -> ChatAgent: + """Produce a deep copy of the ChatAgent while preserving runtime configuration.""" + options = agent.default_options + middleware = list(agent.middleware or []) + + # Reconstruct the original tools list by combining regular tools with MCP tools. 
+ # ChatAgent.__init__ separates MCP tools into _local_mcp_tools during initialization, + # so we need to recombine them here to pass the complete tools list to the constructor. + # This makes sure MCP tools are preserved when cloning agents for handoff workflows. + tools_from_options = options.get("tools") + all_tools = list(tools_from_options) if tools_from_options else [] + if agent._local_mcp_tools: # type: ignore + all_tools.extend(agent._local_mcp_tools) # type: ignore + + logit_bias = options.get("logit_bias") + metadata = options.get("metadata") + + return ChatAgent( + chat_client=agent.chat_client, + instructions=options.get("instructions"), + id=agent.id, + name=agent.name, + description=agent.description, + chat_message_store_factory=agent.chat_message_store_factory, + context_providers=agent.context_provider, + middleware=middleware, + # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. + allow_multiple_tool_calls=False, + frequency_penalty=options.get("frequency_penalty"), + logit_bias=dict(logit_bias) if logit_bias else None, + max_tokens=options.get("max_tokens"), + metadata=dict(metadata) if metadata else None, + model_id=options.get("model_id"), + presence_penalty=options.get("presence_penalty"), + response_format=options.get("response_format"), + seed=options.get("seed"), + stop=options.get("stop"), + store=options.get("store"), + temperature=options.get("temperature"), + tool_choice=options.get("tool_choice"), # type: ignore[arg-type] + tools=all_tools if all_tools else None, + top_p=options.get("top_p"), + user=options.get("user"), ) - # Clean conversation before sending to gateway for user input request - # This removes tool messages that shouldn't be shown to users - cleaned_for_display = clean_conversation_for_handoff(conversation) - - # The awaiting_agent_id is the agent that just responded and is awaiting user input - # This is the source of the current response (fallback to starting agent if source is unknown) - next_agent_id = source or self._starting_agent_id - - message_to_gateway = _ConversationForUserInput(conversation=cleaned_for_display, next_agent_id=next_agent_id) - await ctx.send_message(message_to_gateway, target_id=self._input_gateway_id) # type: ignore[arg-type] - - @handler - async def handle_user_input( - self, - message: _ConversationWithUserInput, - ctx: WorkflowContext[AgentExecutorRequest, list[ChatMessage]], - ) -> None: - """Receive user input from gateway, update history, and route to agent. + def _apply_auto_tools(self, agent: ChatAgent, targets: Sequence[HandoffConfiguration]) -> None: + """Attach synthetic handoff tools to a chat agent and return the target lookup table. - The message.full_conversation may contain: - - Full conversation history + new user messages (normal flow) - - Only new user messages (post-checkpoint-restore flow, see issue #2667) + Creates handoff tools for each specialist agent that this agent can route to. - The gateway sets message.is_post_restore=True when resuming after a checkpoint - restore. In that case, we append the new messages to the existing conversation - rather than replacing it. 
+ Args:
+ agent: The ChatAgent to add handoff tools to.
+ targets: Sequence of handoff configurations defining target agents.
"""
- incoming = message.full_conversation
-
- if message.is_post_restore and self._conversation:
- # Post-restore: append new user messages to existing conversation
- # The coordinator already has its conversation restored from checkpoint
- self._conversation.extend(incoming)
- else:
- # Normal flow: replace with full conversation
- self._conversation = list(incoming) if incoming else self._conversation
-
- # Reset autonomous turn counter on new user input
- self._autonomous_turns = 0
-
- # Check termination before sending to agent
- if await self._check_termination():
- await ctx.yield_output(list(self._conversation))
- return
+ default_options = agent.default_options
+ existing_tools = list(default_options.get("tools") or [])
+ existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")}

- # Determine routing target based on return-to-previous setting
- target_agent_id = self._starting_agent_id
- if self._return_to_previous and self._current_agent_id:
- # Route back to the current agent that's handling the conversation
- target_agent_id = self._current_agent_id
- logger.info(
- f"Return-to-previous enabled: routing user input to current agent '{target_agent_id}' "
- f"(bypassing coordinator '{self._starting_agent_id}')"
- )
- else:
- logger.info(f"Routing user input to coordinator '{target_agent_id}'")
-
- # Clean conversation before sending to target agent
- # Removes tool-related messages that shouldn't be resent on every turn
- cleaned = clean_conversation_for_handoff(self._conversation)
- request = AgentExecutorRequest(messages=cleaned, should_respond=True)
- await ctx.send_message(request, target_id=target_agent_id)
-
- def _resolve_specialist(self, agent_response: AgentRunResponse, conversation: list[ChatMessage]) -> str | None:
- """Resolve the specialist executor id requested by the agent response, if any."""
- resolution = _resolve_handoff_target(agent_response)
- if not resolution:
- return None
+ new_tools: list[AIFunction[Any, Any]] = []
+ for target in targets:
+ tool = self._create_handoff_tool(target.target_id, target.description)
+ if tool.name in existing_names:
+ raise ValueError(
+ f"Agent '{resolve_agent_id(agent)}' already has a tool named '{tool.name}'. "
+ f"Handoff tool name '{tool.name}' conflicts with an existing tool. "
+ "Please rename the existing tool or modify the target agent ID to avoid conflicts."
+ ) + new_tools.append(tool) - candidate = resolution.target - normalized = candidate.lower() - resolved_id: str | None - if normalized in self._handoff_tool_targets: - resolved_id = self._handoff_tool_targets[normalized] + if new_tools: + default_options["tools"] = existing_tools + new_tools # type: ignore[operator] else: - resolved_id = self._specialist_by_alias.get(candidate) - - if resolved_id: - if resolution.function_call: - self._append_tool_acknowledgement(conversation, resolution.function_call, resolved_id) - return resolved_id - - lowered = candidate.lower() - for alias, exec_id in self._specialist_by_alias.items(): - if alias.lower() == lowered: - if resolution.function_call: - self._append_tool_acknowledgement(conversation, resolution.function_call, exec_id) - return exec_id + default_options["tools"] = existing_tools - logger.warning("Handoff requested unknown specialist '%s'.", candidate) - return None + def _create_handoff_tool(self, target_id: str, description: str | None = None) -> AIFunction[Any, Any]: + """Construct the synthetic handoff tool that signals routing to `target_id`.""" + tool_name = get_handoff_tool_name(target_id) + doc = description or f"Handoff to the {target_id} agent." + # Note: approval_mode is intentionally NOT set for handoff tools. + # Handoff tools are framework-internal signals that trigger routing logic, + # not actual function executions. They are automatically intercepted by + # _AutoHandoffMiddleware which short-circuits execution and provides synthetic + # results, so the function body never actually runs in practice. - def _append_tool_acknowledgement( - self, - conversation: list[ChatMessage], - function_call: FunctionCallContent, - resolved_id: str, - ) -> None: - """Append a synthetic tool result acknowledging the resolved specialist id.""" - call_id = getattr(function_call, "call_id", None) - if not call_id: - return + @ai_function(name=tool_name, description=doc) + def _handoff_tool(context: str | None = None) -> str: + """Return a deterministic acknowledgement that encodes the target alias.""" + return f"Handoff to {target_id}" - result_payload: Any = {"handoff_to": resolved_id} - result_content = FunctionResultContent(call_id=call_id, result=result_payload) - tool_message = ChatMessage( - role=Role.TOOL, - contents=[result_content], - author_name=function_call.name, - ) - # Add tool acknowledgement to both the conversation being sent and the full history - conversation.extend((tool_message,)) - self._append_messages((tool_message,)) - - def _conversation_from_response(self, response: AgentExecutorResponse) -> list[ChatMessage]: - """Return the authoritative conversation snapshot from an executor response.""" - conversation = response.full_conversation - if conversation is None: - raise RuntimeError( - "AgentExecutorResponse.full_conversation missing; AgentExecutor must populate it in handoff workflows." - ) - return list(conversation) + return _handoff_tool @override - def _snapshot_pattern_metadata(self) -> dict[str, Any]: - """Serialize pattern-specific state. + async def _run_agent_and_emit(self, ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse]) -> None: + """Override to support handoff.""" + # When the full conversation is empty, it means this is the first run. + # Broadcast the initial cache to all other agents. Subsequent runs won't + # need this since responses are broadcast after each agent run and user input. 
+ if self._is_start_agent and not self._full_conversation:
+ await self._broadcast_messages(self._cache.copy(), cast(WorkflowContext[AgentExecutorRequest], ctx))
+
+ # Append the cache to the full conversation history
+ self._full_conversation.extend(self._cache)
+
+ # Check termination condition before running the agent
+ if await self._check_terminate_and_yield(cast(WorkflowContext[Never, list[ChatMessage]], ctx)):
+ return
- Includes the current agent for return-to-previous routing.
+ # Run the agent
+ if ctx.is_streaming():
+ # Streaming mode: emit incremental updates
+ response = await self._run_agent_streaming(cast(WorkflowContext, ctx))
+ else:
+ # Non-streaming mode: use run() and emit single event
+ response = await self._run_agent(cast(WorkflowContext, ctx))
- Returns:
- Dict containing current agent if return-to-previous is enabled
- """
- metadata: dict[str, Any] = {}
- if self._return_to_previous:
- metadata["current_agent_id"] = self._current_agent_id
- if self._interaction_mode == "autonomous":
- metadata["autonomous_turns"] = self._autonomous_turns
- return metadata
+ # Clear the cache after running the agent
+ self._cache.clear()
- @override
- def _restore_pattern_metadata(self, metadata: dict[str, Any]) -> None:
- """Restore pattern-specific state.
+ # A None response means the base AgentExecutor has issued a function approval request
+ if response is None:
+ # Agent did not complete (e.g., waiting for user input); do not emit response
+ logger.debug("AgentExecutor %s: Agent did not complete, awaiting user input", self.id)
+ return
- Restores the current agent for return-to-previous routing.
+ # Remove function call related content from the agent response for full conversation history
+ cleaned_response = clean_conversation_for_handoff(response.messages)
+ # Append the cleaned response to the full conversation history. Stripping function
+ # call related content keeps the final output consistent regardless of which agent
+ # yields it.
+ self._full_conversation.extend(cleaned_response)
+ # Broadcast the cleaned response to all other agents
+ await self._broadcast_messages(cleaned_response, cast(WorkflowContext[AgentExecutorRequest], ctx))
+
+ # Check if a handoff was requested
+ if handoff_target := self._is_handoff_requested(response):
+ if handoff_target not in self._handoff_targets:
+ raise ValueError(
+ f"Agent '{resolve_agent_id(self._agent)}' attempted to hand off to unknown "
+ f"target '{handoff_target}'. 
Valid targets are: {', '.join(self._handoff_targets)}" + ) - Args: - metadata: Pattern-specific state dict - """ - if self._return_to_previous and "current_agent_id" in metadata: - self._current_agent_id = metadata["current_agent_id"] - if self._interaction_mode == "autonomous" and "autonomous_turns" in metadata: - self._autonomous_turns = metadata["autonomous_turns"] - - def _apply_response_metadata(self, conversation: list[ChatMessage], agent_response: AgentRunResponse) -> None: - """Merge top-level response metadata into the latest assistant message.""" - if not agent_response.additional_properties: + await cast(WorkflowContext[AgentExecutorRequest], ctx).send_message( + AgentExecutorRequest(messages=[], should_respond=True), target_id=handoff_target + ) + await ctx.add_event(HandoffSentEvent(source=self.id, target=handoff_target)) + self._autonomous_mode_turns = 0 # Reset autonomous mode turn counter on handoff return - # Find the most recent assistant message contributed by this response - for message in reversed(conversation): - if message.role == Role.ASSISTANT: - metadata = agent_response.additional_properties or {} - if not metadata: - return - # Merge metadata without mutating shared dict from agent response - merged = dict(message.additional_properties or {}) - for key, value in metadata.items(): - merged.setdefault(key, value) - message.additional_properties = merged - break - - -class _UserInputGateway(Executor): - """Bridges conversation context with the request & response cycle and re-enters the loop.""" - - def __init__(self, *, starting_agent_id: str, prompt: str | None, id: str) -> None: - """Initialise the gateway that requests user input and forwards responses.""" - super().__init__(id) - self._starting_agent_id = starting_agent_id - self._prompt = prompt or "Provide your next input for the conversation." - - @handler - async def request_input(self, message: _ConversationForUserInput, ctx: WorkflowContext) -> None: - """Emit a `HandoffUserInputRequest` capturing the conversation snapshot.""" - if not message.conversation: - raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.") - request = HandoffUserInputRequest( - conversation=list(message.conversation), - awaiting_agent_id=message.next_agent_id, - prompt=self._prompt, - source_executor_id=self.id, - ) - await ctx.request_info(request, object) - - @handler - async def request_input_legacy(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None: - """Legacy handler for backward compatibility - emit user input request with starting agent.""" - if not conversation: - raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.") - request = HandoffUserInputRequest( - conversation=list(conversation), - awaiting_agent_id=self._starting_agent_id, - prompt=self._prompt, - source_executor_id=self.id, - ) - await ctx.request_info(request, object) + # Handle case where no handoff was requested + if self._autonomous_mode and self._autonomous_mode_turns < self._autonomous_mode_turn_limit: + # In autonomous mode, continue running the agent until a handoff is requested + # or a termination condition is met. + # This allows the agent to perform long-running tasks without returning control + # to the coordinator or user prematurely. 
+ self._cache.extend([ChatMessage(role=Role.USER, text=self._autonomous_mode_prompt)])
+ self._autonomous_mode_turns += 1
+ await self._run_agent_and_emit(ctx)
+ else:
+ # The response is handled via `handle_response`
+ self._autonomous_mode_turns = 0 # Reset autonomous mode turn counter before requesting user input
+ await ctx.request_info(HandoffAgentUserRequest(response), list[ChatMessage])

@response_handler
- async def resume_from_user(
+ async def handle_response(
self,
- original_request: HandoffUserInputRequest,
- response: object,
- ctx: WorkflowContext[_ConversationWithUserInput],
+ original_request: HandoffAgentUserRequest,
+ response: list[ChatMessage],
+ ctx: WorkflowContext[AgentExecutorResponse, AgentRunResponse],
) -> None:
- """Convert user input responses back into chat messages and resume the workflow.
+ """Handle the user response to a request issued after an agent run.

- After checkpoint restore, original_request.conversation will be empty (not serialized
- to prevent duplication - see issue #2667). In this case, we send only the new user
- messages and let the coordinator append them to its already-restored conversation.
- """
- user_messages = _as_user_messages(response)
+ The request only occurs when the agent did not request a handoff and
+ autonomous mode is disabled.

- if original_request.conversation:
- # Normal flow: have conversation history from the original request
- conversation = list(original_request.conversation)
- conversation.extend(user_messages)
- message = _ConversationWithUserInput(full_conversation=conversation, is_post_restore=False)
- else:
- # Post-restore flow: conversation was not serialized, send only new user messages
- # The coordinator will append these to its already-restored conversation
- message = _ConversationWithUserInput(full_conversation=user_messages, is_post_restore=True)
+ Note that this is different from the `handle_user_input_response` method
+ in the base AgentExecutor, which handles function approval responses.

- await ctx.send_message(message, target_id="handoff-coordinator")
+ Args:
+ original_request: The original HandoffAgentUserRequest issued to the user.
+ response: The user's response messages. An empty response indicates
+ termination of the handoff workflow.
+ ctx: The workflow context.
+ """
+ if not response:
+ await cast(WorkflowContext[Never, list[ChatMessage]], ctx).yield_output(self._full_conversation)
+ return

- def _as_user_messages(payload: Any) -> list[ChatMessage]:
- """Normalize arbitrary payloads into user-authored chat messages.
+ # Broadcast the user response to all other agents
+ await self._broadcast_messages(response, cast(WorkflowContext[AgentExecutorRequest], ctx))

- Handles various input formats:
- - ChatMessage instances (converted to USER role if needed)
- - List of ChatMessage instances
- - Mapping with 'text' or 'content' key
- - Any other type (converted to string)
+ # Append the user response messages to the cache
+ self._cache.extend(response)
+ await self._run_agent_and_emit(ctx)

- Returns:
- List of ChatMessage instances with USER role.
- """ - if isinstance(payload, ChatMessage): - if payload.role == Role.USER: - return [payload] - return [ChatMessage(Role.USER, text=payload.text)] - if isinstance(payload, list): - # Check if all items are ChatMessage instances - all_chat_messages = all(isinstance(msg, ChatMessage) for msg in payload) # type: ignore[arg-type] - if all_chat_messages: - messages: list[ChatMessage] = payload # type: ignore[assignment] - return [msg if msg.role == Role.USER else ChatMessage(Role.USER, text=msg.text) for msg in messages] - if isinstance(payload, Mapping): # User supplied structured data - text = payload.get("text") or payload.get("content") # type: ignore[union-attr] - if isinstance(text, str) and text.strip(): - return [ChatMessage(Role.USER, text=text.strip())] - return [ChatMessage(Role.USER, text=str(payload))] # type: ignore[arg-type] - - -def _default_termination_condition(conversation: list[ChatMessage]) -> bool: - """Default termination: stop after 10 user messages.""" - user_message_count = sum(1 for msg in conversation if msg.role == Role.USER) - return user_message_count >= 10 + async def _broadcast_messages( + self, + messages: list[ChatMessage], + ctx: WorkflowContext[AgentExecutorRequest], + ) -> None: + """Broadcast the workflow cache to the agent before running.""" + agent_executor_request = AgentExecutorRequest( + messages=messages, + should_respond=False, # Other agents do not need to respond yet + ) + # Since all agents are connected via fan-out, we can directly send the message + await ctx.send_message(agent_executor_request) + def _is_handoff_requested(self, response: AgentRunResponse) -> str | None: + """Determine if the agent response includes a handoff request. -class HandoffBuilder: - r"""Fluent builder for conversational handoff workflows with coordinator and specialist agents. - - The handoff pattern enables a coordinator agent to route requests to specialist agents. - Interaction mode controls whether the workflow requests user input after each agent response or - completes autonomously once agents finish responding. A termination condition determines when - the workflow should stop requesting input and complete. - - Routing Patterns: - - **Single-Tier (Default):** Only the coordinator can hand off to specialists. By default, after any specialist - responds, control returns to the user for more input. This creates a cyclical flow: - user -> coordinator -> [optional specialist] -> user -> coordinator -> ... - Use `with_interaction_mode("autonomous")` to skip requesting additional user input and yield the - final conversation when an agent responds without delegating. - - **Multi-Tier (Advanced):** Specialists can hand off to other specialists using `.add_handoff()`. - This provides more flexibility for complex workflows but is less controllable than the single-tier - pattern. Users lose real-time visibility into intermediate steps during specialist-to-specialist - handoffs (though the full conversation history including all handoffs is preserved and can be - inspected afterward). - - - Key Features: - - **Automatic handoff detection**: The coordinator invokes a handoff tool whose - arguments (for example `{"handoff_to": "shipping_agent"}`) identify the specialist to receive control. - - **Auto-generated tools**: By default the builder synthesizes `handoff_to_` tools for the coordinator, - so you don't manually define placeholder functions. 
- - **Full conversation history**: The entire conversation (including any - `ChatMessage.additional_properties`) is preserved and passed to each agent. - - **Termination control**: By default, terminates after 10 user messages. Override with - `.with_termination_condition(lambda conv: ...)` for custom logic (e.g., detect "goodbye"). - - **Interaction modes**: Choose `human_in_loop` (default) to prompt users between agent turns, - or `autonomous` to continue routing back to agents without prompting for user input until a - handoff occurs or a termination/turn limit is reached (default autonomous turn limit: 50). - - **Checkpointing**: Optional persistence for resumable workflows. - - Usage (Single-Tier): - - .. code-block:: python - - from agent_framework import HandoffBuilder - from agent_framework.openai import OpenAIChatClient - - chat_client = OpenAIChatClient() - - # Create coordinator and specialist agents - coordinator = chat_client.create_agent( - instructions=( - "You are a frontline support agent. Assess the user's issue and decide " - "whether to hand off to 'refund_agent' or 'shipping_agent'. When delegation is " - "required, call the matching handoff tool (for example `handoff_to_refund_agent`)." - ), - name="coordinator_agent", - ) + If a handoff tool is invoked, the middleware will short-circuit execution + and provide a synthetic result that includes the target agent ID. The message + that contains the function result will be the last message in the response. + """ + if not response.messages: + return None - refund = chat_client.create_agent( - instructions="You handle refund requests. Ask for order details and process refunds.", - name="refund_agent", - ) + last_message = response.messages[-1] + for content in last_message.contents: + if content.type == "function_result": + # Use string comparison instead of isinstance to improve performance + if content.result and isinstance(content.result, dict): + handoff_target = content.result.get(HANDOFF_FUNCTION_RESULT_KEY) # type: ignore + if isinstance(handoff_target, str): + return handoff_target + else: + continue - shipping = chat_client.create_agent( - instructions="You resolve shipping issues. Track packages and update delivery status.", - name="shipping_agent", - ) + return None - # Build the handoff workflow - default single-tier routing - workflow = ( - HandoffBuilder( - name="customer_support", - participants=[coordinator, refund, shipping], - ) - .set_coordinator(coordinator) - .build() - ) + async def _check_terminate_and_yield(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool: + """Check termination conditions and yield completion if met. - # Run the workflow - events = await workflow.run_stream("My package hasn't arrived yet") - async for event in events: - if isinstance(event, RequestInfoEvent): - # Request user input - user_response = input("You: ") - await workflow.send_response(event.data.request_id, user_response) - - **Multi-Tier Routing with .add_handoff():** - - .. 
code-block:: python
-
- # Enable specialist-to-specialist handoffs with fluent API
- workflow = (
- HandoffBuilder(participants=[coordinator, replacement, delivery, billing])
- .set_coordinator(coordinator)
- .add_handoff(coordinator, [replacement, delivery, billing]) # Coordinator routes to all
- .add_handoff(replacement, [delivery, billing]) # Replacement delegates to delivery/billing
- .add_handoff(delivery, billing) # Delivery escalates to billing
- .build()
- )
+ Args:
+ ctx: Workflow context for yielding output

- # Flow: User → Coordinator → Replacement → Delivery → Back to User
- # (Replacement hands off to Delivery without returning to user)
+ Returns:
+ True if termination condition met and output yielded, False otherwise
+ """
+ if self._termination_condition is None:
+ return False

- **Use Participant Factories for State Isolation:**
+ terminated = self._termination_condition(self._full_conversation)
+ if inspect.isawaitable(terminated):
+ terminated = await terminated

- .. code-block:: python
- # Define factories that produce fresh agent instances per workflow run
- def create_coordinator() -> AgentProtocol:
- return chat_client.create_agent(
- instructions="You are the coordinator agent...",
- name="coordinator_agent",
- )
+ if terminated:
+ await ctx.yield_output(self._full_conversation)
+ return True
+ return False

- def create_specialist() -> AgentProtocol:
- return chat_client.create_agent(
- instructions="You are the specialist agent...",
- name="specialist_agent",
- )
+ @override
+ async def on_checkpoint_save(self) -> dict[str, Any]:
+ """Serialize the executor state for checkpointing."""
+ state = await super().on_checkpoint_save()
+ state["_autonomous_mode_turns"] = self._autonomous_mode_turns
+ return state

+ @override
+ async def on_checkpoint_restore(self, state: dict[str, Any]) -> None:
+ """Restore the executor state from a checkpoint."""
+ await super().on_checkpoint_restore(state)
+ if "_autonomous_mode_turns" in state:
+ self._autonomous_mode_turns = state["_autonomous_mode_turns"]

- workflow = (
- HandoffBuilder(
- participant_factories={
- "coordinator": create_coordinator,
- "specialist": create_specialist,
- }
- )
- .set_coordinator("coordinator")
- .build()
- )

- **Custom Termination Condition:**
+# endregion Handoff Agent Executor

- .. code-block:: python
+# region Handoff workflow builder

- # Terminate when user says goodbye or after 5 exchanges
- workflow = (
- HandoffBuilder(participants=[coordinator, refund, shipping])
- .set_coordinator(coordinator)
- .with_termination_condition(
- lambda conv: (
- sum(1 for msg in conv if msg.role.value == "user") >= 5
- or any("goodbye" in msg.text.lower() for msg in conv[-2:])
- )
- )
- .build()
- )

- **Checkpointing:**
+class HandoffBuilder:
+ r"""Fluent builder for conversational handoff workflows with multiple agents.

- .. code-block:: python
+ The handoff pattern enables a group of agents to route control among themselves.

- from agent_framework import InMemoryCheckpointStorage
+ Routing Pattern:
+ Agents route control among themselves in a decentralized fashion. Handoffs are
+ configured with `.add_handoff()`; if none are specified, every agent can hand off
+ to every other agent by default (forming a mesh topology).

- storage = InMemoryCheckpointStorage()
- workflow = (
- HandoffBuilder(participants=[coordinator, refund, shipping])
- .set_coordinator(coordinator)
- .with_checkpointing(storage)
- .build()
- )
+ Participants must be agents. 
Support for custom executors is not available in handoff workflows.
+
+ Outputs:
+ The final conversation history as a list of ChatMessage once the workflow completes.

- Args:
- name: Optional workflow name for identification and logging.
- participants: List of agents (AgentProtocol) or executors to participate in the handoff.
- The first agent you specify as coordinator becomes the orchestrating agent.
- participant_factories: Mapping of factory names to callables that produce agents or
- executors when invoked. This allows for lazy instantiation
- and state isolation per workflow instance created by this builder.
- description: Optional human-readable description of the workflow.
-
- Raises:
- ValueError: If participants list is empty, contains duplicates, or coordinator not specified.
- TypeError: If participants are not AgentProtocol or Executor instances.
+ Note:
+ Agents in handoff workflows must be ChatAgent instances and support local tool calls.
"""

def __init__(
self,
*,
name: str | None = None,
- participants: Sequence[AgentProtocol | Executor] | None = None,
- participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] | None = None,
+ participants: Sequence[AgentProtocol] | None = None,
+ participant_factories: Mapping[str, Callable[[], AgentProtocol]] | None = None,
description: str | None = None,
) -> None:
r"""Initialize a HandoffBuilder for creating conversational handoff workflows.

The builder starts in an unconfigured state and requires you to call:

1. `.participants([...])` - Register agents
- 2. or `.participant_factories({...})` - Register agent/executor factories
- 3. `.set_coordinator(...)` - Designate which agent receives initial user input
- 4. `.build()` - Construct the final Workflow
+ 2. or `.participant_factories({...})` - Register agent factories
+ 3. `.with_start_agent(...)` - Designate which agent receives the initial input
+ 4. `.build()` - Construct the final Workflow

Optional configuration methods allow you to customize context management,
termination logic, and persistence.

Args:
name: Optional workflow identifier used in logging and debugging.
- If not provided, a default name will be generated.
- participants: Optional list of agents (AgentProtocol) or executors that will
- participate in the handoff workflow. You can also call
- `.participants([...])` later. Each participant must have a
- unique identifier (name for agents, id for executors).
- participant_factories: Optional mapping of factory names to callables that produce agents or
- executors when invoked. This allows for lazy instantiation
- and state isolation per workflow instance created by this builder.
+ If not provided, a default name will be generated.
+ participants: Optional list of agents that will participate in the handoff workflow.
+ You can also call `.participants([...])` later. Each participant must have a
+ unique identifier (`.name` is preferred if set, otherwise `.id` is used).
+ participant_factories: Optional mapping of factory names to callables that produce agents when invoked.
+ This allows for lazy instantiation and state isolation per workflow instance
+ created by this builder.
description: Optional human-readable description explaining the workflow's
- purpose. 
Useful for documentation and observability. """ self._name = name self._description = description - self._executors: dict[str, Executor] = {} - self._aliases: dict[str, str] = {} - self._starting_agent_id: str | None = None - self._checkpoint_storage: CheckpointStorage | None = None - self._request_prompt: str | None = None - # Termination condition - self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] = ( - _default_termination_condition - ) - self._handoff_config: dict[str, list[str]] = {} # Maps agent_id -> [target_agent_ids] - self._return_to_previous: bool = False - self._interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop" - self._autonomous_turn_limit: int | None = _DEFAULT_AUTONOMOUS_TURN_LIMIT - self._request_info_enabled: bool = False - self._request_info_filter: set[str] | None = None - - self._participant_factories: dict[str, Callable[[], AgentProtocol | Executor]] = {} + + # Participant related members + self._participants: dict[str, AgentProtocol] = {} + self._participant_factories: dict[str, Callable[[], AgentProtocol]] = {} + self._start_id: str | None = None if participant_factories: self.participant_factories(participant_factories) if participants: self.participants(participants) - # region Fluent Configuration Methods + # Handoff related members + self._handoff_config: dict[str, set[HandoffConfiguration]] = {} + + # Checkpoint related members + self._checkpoint_storage: CheckpointStorage | None = None + + # Autonomous mode related + self._autonomous_mode: bool = False + self._autonomous_mode_prompts: dict[str, str] = {} + self._autonomous_mode_turn_limits: dict[str, int] = {} + self._autonomous_mode_enabled_agents: list[str] = [] + + # Termination related members + self._termination_condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] | None = None def participant_factories( - self, participant_factories: Mapping[str, Callable[[], AgentProtocol | Executor]] + self, participant_factories: Mapping[str, Callable[[], AgentProtocol]] ) -> "HandoffBuilder": - """Register factories that produce agents or executors for the handoff workflow. + """Register factories that produce agents for the handoff workflow. - Each factory is a callable that returns an AgentProtocol or Executor instance. + Each factory is a callable that returns an AgentProtocol instance. Factories are invoked when building the workflow, allowing for lazy instantiation and state isolation per workflow instance. Args: - participant_factories: Mapping of factory names to callables that return AgentProtocol or Executor - instances. Each produced participant must have a unique identifier (name for - agents, id for executors). + participant_factories: Mapping of factory names to callables that return AgentProtocol + instances. Each produced participant must have a unique identifier + (`.name` is preferred if set, otherwise `.id` is used). Returns: Self for method chaining. @@ -1002,7 +646,7 @@ def participant_factories( from agent_framework import ChatAgent, HandoffBuilder - def create_coordinator() -> ChatAgent: + def create_triage() -> ChatAgent: return ... 
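
A minimal end-to-end sketch of the redesigned fluent surface introduced by these hunks, for orientation while reading on. The agent instructions and chat client are illustrative placeholders; every builder method shown is added or retained by this diff:

.. code-block:: python

    from agent_framework import HandoffBuilder
    from agent_framework.openai import OpenAIChatClient

    client = OpenAIChatClient()
    triage = client.create_agent(instructions="...", name="triage_agent")
    refund = client.create_agent(instructions="...", name="refund_agent")
    billing = client.create_agent(instructions="...", name="billing_agent")

    workflow = (
        HandoffBuilder(participants=[triage, refund, billing])
        # Omitting add_handoff entirely defaults to a full mesh topology
        .add_handoff(triage, [refund, billing])
        .add_handoff(refund, [billing])
        .with_start_agent(triage)
        .build()
    )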
@@ -1015,17 +659,17 @@ def create_billing_agent() -> ChatAgent:

factories = {
- "coordinator": create_coordinator,
+ "triage": create_triage,
"refund": create_refund_agent,
"billing": create_billing_agent,
}

+ # Handoffs will be created automatically unless specified otherwise
+ # The default creates a mesh topology where all agents can hand off to all others
builder = HandoffBuilder().participant_factories(factories)
- # Use the factory IDs to create handoffs and set the coordinator
- builder.add_handoff("coordinator", ["refund", "billing"])
- builder.set_coordinator("coordinator")
+ builder.with_start_agent("triage")
"""
- if self._executors:
+ if self._participants:
raise ValueError(
"Cannot mix .participants([...]) and .participant_factories() in the same builder instance."
)
@@ -1039,17 +683,12 @@ def create_billing_agent() -> ChatAgent:
self._participant_factories = dict(participant_factories)
return self

- def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "HandoffBuilder":
- """Register the agents or executors that will participate in the handoff workflow.
-
- Each participant must have a unique identifier (name for agents, id for executors).
- The workflow will automatically create an alias map so agents can be referenced by
- their name, id, or executor id when routing.
+ def participants(self, participants: Sequence[AgentProtocol]) -> "HandoffBuilder":
+ """Register the agents that will participate in the handoff workflow.

Args:
- participants: Sequence of AgentProtocol or Executor instances. Each must have
- a unique identifier. For agents, the name attribute is used as the
- primary identifier and must match handoff target strings.
+ participants: Sequence of AgentProtocol instances. Each must have a unique identifier.
+ (`.name` is preferred if set, otherwise `.id` is used).

Returns:
Self for method chaining.
@@ -1057,7 +696,7 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han
Raises:
ValueError: If participants is empty, contains duplicates, or
`.participants(...)` or `.participant_factories(...)` has already been called.
- TypeError: If participants are not AgentProtocol or Executor instances.
+ TypeError: If participants are not AgentProtocol instances.

Example:

@@ -1067,175 +706,81 @@ def participants(self, participants: Sequence[AgentProtocol | Executor]) -> "Han
from agent_framework.openai import OpenAIChatClient

client = OpenAIChatClient()
- coordinator = client.create_agent(instructions="...", name="coordinator")
+ triage = client.create_agent(instructions="...", name="triage_agent")
refund = client.create_agent(instructions="...", name="refund_agent")
billing = client.create_agent(instructions="...", name="billing_agent")

- builder = HandoffBuilder().participants([coordinator, refund, billing])
- # Now you can call .set_coordinator() to designate the entry point
-
- Note:
- This method resets any previously configured coordinator, so you must call
- `.set_coordinator(...)` again after changing participants.
+ builder = HandoffBuilder().participants([triage, refund, billing])
+ builder.with_start_agent(triage)
"""
if self._participant_factories:
raise ValueError(
"Cannot mix .participants([...]) and .participant_factories() in the same builder instance."
) - if self._executors: + if self._participants: raise ValueError("participants have already been assigned") if not participants: raise ValueError("participants cannot be empty") - named: dict[str, AgentProtocol | Executor] = {} + named: dict[str, AgentProtocol] = {} for participant in participants: - if isinstance(participant, Executor): - identifier = participant.id - elif isinstance(participant, AgentProtocol): - identifier = participant.name or participant.id + if isinstance(participant, AgentProtocol): + resolved_id = self._resolve_to_id(participant) else: raise TypeError( f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." ) - if identifier in named: - raise ValueError(f"Duplicate participant name '{identifier}' detected") - named[identifier] = participant + if resolved_id in named: + raise ValueError(f"Duplicate participant name '{resolved_id}' detected") + named[resolved_id] = participant - metadata = prepare_participant_metadata( - named, - description_factory=lambda name, participant: getattr(participant, "description", None) or name, - ) - - wrapped = metadata["executors"] - self._executors = {executor.id: executor for executor in wrapped.values()} - self._aliases = metadata["aliases"] - self._starting_agent_id = None - - return self - - def set_coordinator(self, agent: str | AgentProtocol | Executor) -> "HandoffBuilder": - r"""Designate which agent receives initial user input and orchestrates specialist routing. - - The coordinator agent is responsible for analyzing user requests and deciding whether to: - 1. Handle the request directly and respond to the user, OR - 2. Hand off to a specialist agent by including handoff metadata in the response - - After a specialist responds, the workflow automatically returns control to the user - (default) creating a cyclical flow: user -> coordinator -> [specialist] -> user -> ... - Configure `with_interaction_mode("autonomous")` to continue with the responding agent - without requesting another user turn until a handoff occurs or a termination/turn limit is met. - - Args: - agent: The agent to use as the coordinator. Can be: - - Factory name (str): If using participant factories - - AgentProtocol instance: The actual agent object - - Executor instance: A custom executor wrapping an agent - - Returns: - Self for method chaining. - - Raises: - ValueError: 1) If `agent` is an AgentProtocol or Executor instance but `.participants(...)` hasn't - been called yet, or if it is not in the participants list. - 2) If `agent` is a factory name (str) but `.participant_factories(...)` hasn't been - called yet, or if it is not in the participant_factories list. - TypeError: If `agent` is not a str, AgentProtocol, or Executor instance. - - Example: - - .. code-block:: python - - # Use factory name with `.participant_factories()` - builder = ( - HandoffBuilder() - .participant_factories({ - "coordinator": create_coordinator, - "refund": create_refund_agent, - "billing": create_billing_agent, - }) - .set_coordinator("coordinator") - ) - - # Or pass the agent object directly - builder = HandoffBuilder().participants([coordinator, refund, billing]).set_coordinator(coordinator) - - Note: - The coordinator determines routing by invoking a handoff tool call whose - arguments identify the target specialist (for example `{\"handoff_to\": \"billing\"}`). - Decorate the tool with `approval_mode="always_require"` to ensure the workflow - intercepts the call before execution and can make the transition. 
- """ - if isinstance(agent, (AgentProtocol, Executor)): - if not self._executors: - raise ValueError( - "Call participants(...) before coordinator(...). If using participant_factories, " - "pass the factory name (str) instead of the agent instance." - ) - resolved = self._resolve_to_id(agent) - if resolved not in self._executors: - raise ValueError(f"coordinator '{resolved}' is not part of the participants list") - self._starting_agent_id = resolved - elif isinstance(agent, str): - if agent not in self._participant_factories: - raise ValueError( - f"coordinator factory name '{agent}' is not part of the participant_factories list. If " - "you are using participant instances, call .participants(...) and pass the agent instance instead." - ) - self._starting_agent_id = agent - else: - raise TypeError( - "coordinator must be a factory name (str), AgentProtocol, or Executor instance. " - f"Got {type(agent).__name__}." - ) + self._participants = named return self def add_handoff( self, - source: str | AgentProtocol | Executor, - targets: str | AgentProtocol | Executor | Sequence[str | AgentProtocol | Executor], + source: str | AgentProtocol, + targets: Sequence[str] | Sequence[AgentProtocol], *, - tool_name: str | None = None, - tool_description: str | None = None, + description: str | None = None, ) -> "HandoffBuilder": """Add handoff routing from a source agent to one or more target agents. - This method enables specialist-to-specialist handoffs by configuring which agents - can hand off to which others. Call this method multiple times to build a complete - routing graph. By default, only the starting agent can hand off to all other participants; - use this method to enable additional routing paths. + This method enables agent-to-agent handoffs by configuring which agents + can hand off to which others. Call this method multiple times to build a + complete routing graph. If no handoffs are specified, all agents can hand off + to all others by default (mesh topology). Args: source: The agent that can initiate the handoff. Can be: - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - - Executor instance: A custom executor wrapping an agent - Cannot mix factory names and instances across source and targets targets: One or more target agents that the source can hand off to. Can be: - Factory name (str): If using participant factories - AgentProtocol instance: The actual agent object - - Executor instance: A custom executor wrapping an agent - - Single target: "billing_agent" or agent_instance + - Single target: ["billing_agent"] or [agent_instance] - Multiple targets: ["billing_agent", "support_agent"] or [agent1, agent2] - Cannot mix factory names and instances across source and targets - tool_name: Optional custom name for the handoff tool. Currently not used in the - implementation - tools are always auto-generated as "handoff_to_". - Reserved for future enhancement. - tool_description: Optional custom description for the handoff tool. Currently not used - in the implementation - descriptions are always auto-generated as - "Handoff to the agent.". Reserved for future enhancement. + description: Optional custom description for the handoff. If not provided, the description + of the target agent(s) will be used. If the target agent has no description, + no description will be set for the handoff tool, which is not recommended. + If multiple targets are provided, description will be shared among all handoff + tools. 
To configure distinct descriptions for multiple targets, call add_handoff() + separately for each target. Returns: Self for method chaining. Raises: ValueError: 1) If source or targets are not in the participants list, or if - participants(...) hasn't been called yet. + participants(...) hasn't been called yet. 2) If source or targets are factory names (str) but participant_factories(...) - hasn't been called yet, or if they are not in the participant_factories list. + hasn't been called yet, or if they are not in the participant_factories list. TypeError: If mixing factory names (str) and AgentProtocol/Executor instances Examples: @@ -1263,10 +808,9 @@ def add_handoff( workflow = ( HandoffBuilder(participants=[triage, replacement, delivery, billing]) - .set_coordinator(triage) .add_handoff(triage, [replacement, delivery, billing]) .add_handoff(replacement, [delivery, billing]) - .add_handoff(delivery, billing) + .add_handoff(delivery, [billing]) .build() ) @@ -1274,9 +818,7 @@ def add_handoff( - Handoff tools are automatically registered for each source agent - If a source agent is configured multiple times via add_handoff, targets are merged """ - if isinstance(source, str) and ( - isinstance(targets, str) or (isinstance(targets, Sequence) and all(isinstance(t, str) for t in targets)) - ): + if isinstance(source, str) and all(isinstance(t, str) for t in targets): # Both source and targets are factory names if not self._participant_factories: raise ValueError("Call participant_factories(...) before add_handoff(...)") @@ -1284,90 +826,120 @@ def add_handoff( if source not in self._participant_factories: raise ValueError(f"Source factory name '{source}' is not in the participant_factories list") - target_list: list[str] = [targets] if isinstance(targets, str) else list(targets) # type: ignore - for target in target_list: + for target in targets: if target not in self._participant_factories: raise ValueError(f"Target factory name '{target}' is not in the participant_factories list") - self._handoff_config[source] = target_list # type: ignore + # Merge with existing handoff configuration for this source + if source in self._handoff_config: + # Add new targets to existing list, avoiding duplicates + for t in targets: + if t in self._handoff_config[source]: + logger.warning(f"Handoff from '{source}' to '{t}' is already configured; overwriting.") + self._handoff_config[source].add(HandoffConfiguration(target=t, description=description)) + else: + self._handoff_config[source] = set() + for t in targets: + self._handoff_config[source].add(HandoffConfiguration(target=t, description=description)) return self - if isinstance(source, (AgentProtocol, Executor)) and ( - isinstance(targets, (AgentProtocol, Executor)) - or all(isinstance(t, (AgentProtocol, Executor)) for t in targets) - ): + if isinstance(source, (AgentProtocol)) and all(isinstance(t, AgentProtocol) for t in targets): # Both source and targets are instances - if not self._executors: + if not self._participants: raise ValueError("Call participants(...) 
before add_handoff(...)") # Resolve source agent ID source_id = self._resolve_to_id(source) - if source_id not in self._executors: + if source_id not in self._participants: raise ValueError(f"Source agent '{source}' is not in the participants list") - # Normalize targets to list - target_list: list[AgentProtocol | Executor] = ( # type: ignore[no-redef] - [targets] if isinstance(targets, (AgentProtocol, Executor)) else list(targets) - ) # type: ignore - # Resolve all target IDs target_ids: list[str] = [] - for target in target_list: + for target in targets: target_id = self._resolve_to_id(target) - if target_id not in self._executors: + if target_id not in self._participants: raise ValueError(f"Target agent '{target}' is not in the participants list") target_ids.append(target_id) # Merge with existing handoff configuration for this source if source_id in self._handoff_config: # Add new targets to existing list, avoiding duplicates - existing = self._handoff_config[source_id] - for target_id in target_ids: - if target_id not in existing: - existing.append(target_id) + for t in target_ids: + if t in self._handoff_config[source_id]: + logger.warning(f"Handoff from '{source_id}' to '{t}' is already configured; overwriting.") + self._handoff_config[source_id].add(HandoffConfiguration(target=t, description=description)) else: - self._handoff_config[source_id] = target_ids + self._handoff_config[source_id] = set() + for t in target_ids: + self._handoff_config[source_id].add(HandoffConfiguration(target=t, description=description)) return self raise TypeError( - "Cannot mix factory names (str) and AgentProtocol/Executor instances " - "across source and targets in add_handoff()" + "Cannot mix factory names (str) and AgentProtocol instances across source and targets in add_handoff()" ) - def request_prompt(self, prompt: str | None) -> "HandoffBuilder": - """Set a custom prompt message displayed when requesting user input. + def with_start_agent(self, agent: str | AgentProtocol) -> "HandoffBuilder": + """Set the agent that will initiate the handoff workflow. - By default, the workflow uses a generic prompt: "Provide your next input for the - conversation." Use this method to customize the message shown to users when the - workflow needs their response. + If not specified, the first registered participant will be used as the starting agent. Args: - prompt: Custom prompt text to display, or None to use the default prompt. - + agent: The agent that will start the workflow. Can be: + - Factory name (str): If using participant factories + - AgentProtocol instance: The actual agent object Returns: Self for method chaining. + """ + if isinstance(agent, str): + if self._participant_factories: + if agent not in self._participant_factories: + raise ValueError(f"Start agent factory name '{agent}' is not in the participant_factories list") + else: + raise ValueError("Call participant_factories(...) before with_start_agent(...)") + self._start_id = agent + elif isinstance(agent, AgentProtocol): + resolved_id = self._resolve_to_id(agent) + if self._participants: + if resolved_id not in self._participants: + raise ValueError(f"Start agent '{resolved_id}' is not in the participants list") + else: + raise ValueError("Call participants(...) before with_start_agent(...)") + self._start_id = resolved_id + else: + raise TypeError("Start agent must be a factory name (str) or an AgentProtocol instance") - Example: - - .. 
code-block:: python + return self - workflow = ( - HandoffBuilder(participants=[triage, refund, billing]) - .set_coordinator("triage") - .request_prompt("How can we help you today?") - .build() - ) + def with_autonomous_mode( + self, + *, + agents: Sequence[AgentProtocol] | Sequence[str] | None = None, + prompts: dict[str, str] | None = None, + turn_limits: dict[str, int] | None = None, + ) -> "HandoffBuilder": + """Enable autonomous mode for the handoff workflow. - # For more context-aware prompts, you can access the prompt via - # RequestInfoEvent.data.prompt in your event handling loop + Autonomous mode allows agents to continue responding without user input. + The default behavior when autonomous mode is disabled is to return control to the user + after each agent response that does not trigger a handoff. With autonomous mode enabled, + agents can continue the conversation until they request a handoff or the turn limit is reached. - Note: - The prompt is static and set once during workflow construction. If you need - dynamic prompts based on conversation state, you'll need to handle that in - your application's event processing logic. + Args: + agents: Optional list of agents to enable autonomous mode for. Can be: + - Factory names (str): If using participant factories + - AgentProtocol instances: The actual agent objects + - If not provided, all agents will operate in autonomous mode. + prompts: Optional mapping of agent identifiers/factory names to custom prompts to use when continuing + in autonomous mode. If not provided, a default prompt will be used. + turn_limits: Optional mapping of agent identifiers/factory names to maximum number of autonomous turns + before returning control to the user. If not provided, a default turn limit will be used. """ - self._request_prompt = prompt + self._autonomous_mode = True + self._autonomous_mode_prompts = prompts or {} + self._autonomous_mode_turn_limits = turn_limits or {} + self._autonomous_mode_enabled_agents = [self._resolve_to_id(agent) for agent in agents] if agents else [] + return self def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "HandoffBuilder": @@ -1394,12 +966,7 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "HandoffB from agent_framework import InMemoryCheckpointStorage storage = InMemoryCheckpointStorage() - workflow = ( - HandoffBuilder(participants=[triage, refund, billing]) - .set_coordinator("triage") - .with_checkpointing(storage) - .build() - ) + workflow = HandoffBuilder(participants=[triage, refund, billing]).with_checkpointing(storage).build() # Run workflow with a session ID for resumption async for event in workflow.run_stream("Help me", session_id="user_123"): @@ -1424,16 +991,14 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "HandoffB self._checkpoint_storage = checkpoint_storage return self - def with_termination_condition( - self, condition: Callable[[list[ChatMessage]], bool | Awaitable[bool]] - ) -> "HandoffBuilder": + def with_termination_condition(self, termination_condition: TerminationCondition) -> "HandoffBuilder": """Set a custom termination condition for the handoff workflow. The condition can be either synchronous or asynchronous. Args: - condition: Function that receives the full conversation and returns True - (or awaitable True) if the workflow should terminate (not request further user input). 
+ termination_condition: Function that receives the full conversation and returns True + (or awaitable True) if the workflow should terminate. Returns: Self for chaining. @@ -1456,199 +1021,15 @@ async def check_termination(conv: list[ChatMessage]) -> bool: builder.with_termination_condition(check_termination) """ - self._termination_condition = condition - return self - - def with_interaction_mode( - self, - interaction_mode: Literal["human_in_loop", "autonomous"] = "human_in_loop", - *, - autonomous_turn_limit: int | None = None, - ) -> "HandoffBuilder": - """Choose whether the workflow requests user input or runs autonomously after agent replies. - - In autonomous mode, agents (including specialists) continue iterating on their task - until they explicitly invoke a handoff tool or the turn limit is reached. This allows - specialists to perform long-running autonomous tasks (e.g., research, coding, analysis) - without prematurely returning control to the coordinator or user. - - Args: - interaction_mode: `"human_in_loop"` (default) requests user input after each agent response - that does not trigger a handoff. `"autonomous"` lets agents continue - working until they invoke a handoff tool or the turn limit is reached. - - Keyword Args: - autonomous_turn_limit: Maximum number of agent responses before the workflow yields - when in autonomous mode. Only applicable when interaction_mode - is `"autonomous"`. Default is 50. Set to `None` to disable - the limit (use with caution). Ignored with a warning if provided - when interaction_mode is `"human_in_loop"`. - - Returns: - Self for chaining. - - Example: - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[coordinator, research_agent]) - .set_coordinator(coordinator) - .add_handoff(coordinator, research_agent) - .add_handoff(research_agent, coordinator) - .with_interaction_mode("autonomous", autonomous_turn_limit=20) - .build() - ) - - # Flow: User asks a question - # -> Coordinator routes to Research Agent - # -> Research Agent iterates (researches, analyzes, refines) - # -> Research Agent calls handoff_to_coordinator when done - # -> Coordinator provides final response - """ - if interaction_mode not in ("human_in_loop", "autonomous"): - raise ValueError("interaction_mode must be either 'human_in_loop' or 'autonomous'") - self._interaction_mode = interaction_mode - - if autonomous_turn_limit is not None: - if interaction_mode != "autonomous": - logger.warning( - f"autonomous_turn_limit={autonomous_turn_limit} was provided but interaction_mode is " - f"'{interaction_mode}'; ignoring." - ) - elif autonomous_turn_limit <= 0: - raise ValueError("autonomous_turn_limit must be positive when provided") - else: - self._autonomous_turn_limit = autonomous_turn_limit - - return self - - def enable_return_to_previous(self, enabled: bool = True) -> "HandoffBuilder": - """Enable direct return to the current agent after user input, bypassing the coordinator. - - When enabled, after a specialist responds without requesting another handoff, user input - routes directly back to that same specialist instead of always routing back to the - coordinator agent for re-evaluation. - - This is useful when a specialist needs multiple turns with the user to gather information - or resolve an issue, avoiding unnecessary coordinator involvement while maintaining context. - - Flow Comparison: - - **Default (disabled):** - User -> Coordinator -> Specialist -> User -> Coordinator -> Specialist -> ... 
- - **With return_to_previous (enabled):** - User -> Coordinator -> Specialist -> User -> Specialist -> ... - - Args: - enabled: Whether to enable return-to-previous routing. Default is True. - - Returns: - Self for method chaining. - - Example: - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[triage, technical_support, billing]) - .set_coordinator("triage") - .add_handoff(triage, [technical_support, billing]) - .enable_return_to_previous() # Enable direct return routing - .build() - ) - - # Flow: User asks question - # -> Triage routes to Technical Support - # -> Technical Support asks clarifying question - # -> User provides more info - # -> Routes back to Technical Support (not Triage) - # -> Technical Support continues helping - - Multi-tier handoff example: - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[triage, specialist_a, specialist_b]) - .set_coordinator("triage") - .add_handoff(triage, [specialist_a, specialist_b]) - .add_handoff(specialist_a, specialist_b) - .enable_return_to_previous() - .build() - ) - - # Flow: User asks question - # -> Triage routes to Specialist A - # -> Specialist A hands off to Specialist B - # -> Specialist B asks clarifying question - # -> User provides more info - # -> Routes back to Specialist B (who is currently handling the conversation) - - Note: - This feature routes to whichever agent most recently responded, whether that's - the coordinator or a specialist. The conversation continues with that agent until - they either hand off to another agent or the termination condition is met. - """ - self._return_to_previous = enabled - return self - - def with_request_info( - self, - *, - agents: Sequence[str | AgentProtocol | Executor] | None = None, - ) -> "HandoffBuilder": - """Enable request info before participants run in the workflow. - - When enabled, the workflow pauses before each participant runs, emitting - a RequestInfoEvent that allows the caller to review the conversation and - optionally inject guidance before the participant responds. The caller provides - input via the standard response_handler/request_info pattern. - - Args: - agents: Optional filter - only pause before these specific agents/executors. - Accepts agent names (str), agent instances, or executor instances. - If None (default), pauses before every participant. - - Returns: - self: The builder instance for fluent chaining. - - Example: - - .. code-block:: python - - # Pause before all participants - workflow = ( - HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") - .with_request_info() - .build() - ) - - # Pause only before specialist agents (not coordinator) - workflow = ( - HandoffBuilder(participants=[coordinator, refund, shipping]) - .set_coordinator("coordinator_agent") - .with_request_info(agents=[refund, shipping]) - .build() - ) - """ - from ._orchestration_request_info import resolve_request_info_filter - - self._request_info_enabled = True - self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) + self._termination_condition = termination_condition return self def build(self) -> Workflow: """Construct the final Workflow instance from the configured builder. 
This method validates the configuration and assembles all internal components: - - Input normalization executor - Starting agent executor - - Handoff coordinator - Specialist agent executors - - User input gateway - Request/response handling Returns: @@ -1657,388 +1038,192 @@ def build(self) -> Workflow: Raises: ValueError: If participants or coordinator were not configured, or if required configuration is invalid. - - Example (Minimal): - - .. code-block:: python - - workflow = ( - HandoffBuilder(participants=[coordinator, refund, billing]).set_coordinator("coordinator").build() - ) - - # Run the workflow - async for event in workflow.run_stream("I need help"): - # Handle events... - pass - - Example (Full Configuration): - - .. code-block:: python - - from agent_framework import InMemoryCheckpointStorage - - storage = InMemoryCheckpointStorage() - workflow = ( - HandoffBuilder( - name="support_workflow", - participants=[coordinator, refund, billing], - description="Customer support with specialist routing", - ) - .set_coordinator("coordinator") - .with_termination_condition(lambda conv: len(conv) > 20) - .request_prompt("How can we help?") - .with_checkpointing(storage) - .build() - ) - - Note: - After calling build(), the builder instance should not be reused. Create a - new builder if you need to construct another workflow with different configuration. """ - if not self._executors and not self._participant_factories: + if not self._participants and not self._participant_factories: raise ValueError( "No participants or participant_factories have been configured. " "Call participants(...) or participant_factories(...) first." ) - if self._starting_agent_id is None: - raise ValueError("Must call set_coordinator(...) before building the workflow.") - - # Resolve executors, aliases, and handoff tool targets - # This will instantiate participants if using factories, and validate handoff config - start_executor_id, executors, aliases, handoff_tool_targets = self._resolve_executors_and_handoffs() - - specialists = {exec_id: executor for exec_id, executor in executors.items() if exec_id != start_executor_id} - if not specialists: - logger.warning("Handoff workflow has no specialist agents; the coordinator will loop with the user.") - - descriptions = { - exec_id: getattr(executor, "description", None) or exec_id for exec_id, executor in executors.items() - } - participant_specs = { - exec_id: GroupChatParticipantSpec(name=exec_id, participant=executor, description=descriptions[exec_id]) - for exec_id, executor in executors.items() - } - - input_node = _InputToConversation(id="input-conversation") - user_gateway = _UserInputGateway( - starting_agent_id=start_executor_id, - prompt=self._request_prompt, - id="handoff-user-input", - ) - builder = WorkflowBuilder(name=self._name, description=self._description).set_start_executor(input_node) - - specialist_aliases = { - alias: specialists[exec_id].id for alias, exec_id in aliases.items() if exec_id in specialists - } - - def _handoff_orchestrator_factory(_: _GroupChatConfig) -> Executor: - return _HandoffCoordinator( - starting_agent_id=start_executor_id, - specialist_ids=specialist_aliases, - input_gateway_id=user_gateway.id, - termination_condition=self._termination_condition, - id="handoff-coordinator", - handoff_tool_targets=handoff_tool_targets, - return_to_previous=self._return_to_previous, - interaction_mode=self._interaction_mode, - autonomous_turn_limit=self._autonomous_turn_limit, - ) - - wiring = _GroupChatConfig( - manager=None, - 
manager_participant=None,
-            manager_name=self._starting_agent_id,
-            participants=participant_specs,
-            max_rounds=None,
-            participant_aliases=aliases,
-            participant_executors=executors,
-        )
-
-        # Determine participant factory - wrap with request info interceptor if enabled
-        participant_factory: Callable[[GroupChatParticipantSpec, _GroupChatConfig], _GroupChatParticipantPipeline] = (
-            _default_participant_factory
-        )
-        if self._request_info_enabled:
-            base_factory = _default_participant_factory
-            agent_filter = self._request_info_filter
-
-            def _factory_with_request_info(
-                spec: GroupChatParticipantSpec,
-                config: _GroupChatConfig,
-            ) -> _GroupChatParticipantPipeline:
-                pipeline = list(base_factory(spec, config))
-                if pipeline:
-                    # Add interceptor executor BEFORE the participant (prepend)
-                    interceptor = RequestInfoInterceptor(
-                        executor_id=f"request_info:{spec.name}",
-                        agent_filter=agent_filter,
-                    )
-                    pipeline.insert(0, interceptor)
-                return tuple(pipeline)
-
-            participant_factory = _factory_with_request_info
-
-        result = assemble_group_chat_workflow(
-            wiring=wiring,
-            participant_factory=participant_factory,
-            orchestrator_factory=_handoff_orchestrator_factory,
-            interceptors=(),
-            checkpoint_storage=self._checkpoint_storage,
-            builder=builder,
-            return_builder=True,
-        )
-        if not isinstance(result, tuple):
-            raise TypeError("Expected tuple from assemble_group_chat_workflow with return_builder=True")
-        builder, coordinator = result
-
-        # When request_info is enabled, the input should go through the interceptor first
-        if self._request_info_enabled:
-            # Get the entry executor from the builder's registered executors
-            starting_entry_id = f"request_info:{self._starting_agent_id}"
-            starting_entry_executor = builder._executors.get(starting_entry_id)  # type: ignore
-            if starting_entry_executor:
-                builder = builder.add_edge(input_node, starting_entry_executor)
-            else:
-                # Fallback to direct connection if interceptor not found
-                builder = builder.add_edge(input_node, executors[start_executor_id])
-        else:
-            builder = builder.add_edge(input_node, executors[start_executor_id])
-        builder = builder.add_edge(coordinator, user_gateway)
-        builder = builder.add_edge(user_gateway, coordinator)
+        if self._start_id is None:
+            raise ValueError("Must call with_start_agent(...) before building the workflow.")
+
+        # Resolve agents (either from instances or factories)
+        # The returned map keys are either executor IDs or factory names, which is needed to resolve handoff configs
+        resolved_agents = self._resolve_agents()
+        # Resolve handoff configurations to use agent display names
+        # The returned map keys are executor IDs
+        resolved_handoffs = self._resolve_handoffs(resolved_agents)
+        # Resolve agents into executors
+        executors = self._resolve_executors(resolved_agents, resolved_handoffs)
+
+        # Build the workflow graph
+        start_executor = executors[self._resolve_to_id(resolved_agents[self._start_id])]
+        builder = WorkflowBuilder(
+            name=self._name,
+            description=self._description,
+        ).set_start_executor(start_executor)
+
+        # Add the appropriate edges
+        # In handoff workflows, all executors are connected, forming a fully connected graph.
+        # All agents must stay synchronized, so the active agent needs edges over which to
+        # broadcast updates to every other agent. Handoffs themselves are controlled internally
+        # by the `HandoffAgentExecutor` instances using handoff tools and middleware.
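+        # Illustration (editor's note, not executed): with three agents A, B and C, the loop
+        # below wires the mesh A -> {B, C}, B -> {A, C}, C -> {A, B}, so whichever agent is
+        # active always has an edge over which to broadcast its updates to the other two.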
+ for executor in executors.values(): + targets = [e for e in executors.values() if e.id != executor.id] + # Fan-out requires at least 2 targets. Just in case there are only 2 agents total, + # we add a direct edge if there's only 1 target. + if len(targets) > 1: + builder = builder.add_fan_out_edges(executor, targets) + elif len(targets) == 1: + builder = builder.add_edge(executor, targets[0]) + + # Configure checkpointing if enabled + if self._checkpoint_storage: + builder.with_checkpointing(self._checkpoint_storage) return builder.build() - # endregion Fluent Configuration Methods - # region Internal Helper Methods - def _resolve_executors(self) -> tuple[dict[str, Executor], dict[str, str]]: - """Resolve participant factories into executor instances. + def _resolve_agents(self) -> dict[str, AgentProtocol]: + """Resolve participant factories into agent instances. - If executors were provided directly via participants(...), those are returned as-is. - If participant factories were provided via participant_factories(...), those - are invoked to create executor instances and aliases. + If agent instances were provided directly via participants(...), those are + returned as-is. If participant factories were provided via participant_factories(...), + those are invoked to create the agent instances. Returns: - Tuple of (executors map, aliases map) + Map of executor IDs or factory names to `AgentProtocol` instances """ - if self._executors and self._participant_factories: + if self._participants and self._participant_factories: raise ValueError("Cannot have both executors and participant_factories configured") - if self._executors: - if self._aliases: - # Return existing executors and aliases - return self._executors, self._aliases - raise ValueError("Aliases is empty despite executors being provided") + if self._participants: + return self._participants if self._participant_factories: # Invoke each factory to create participant instances - executor_ids_to_executors: dict[str, AgentProtocol | Executor] = {} - factory_names_to_ids: dict[str, str] = {} + factory_names_to_agents: dict[str, AgentProtocol] = {} for factory_name, factory in self._participant_factories.items(): - instance: Executor | AgentProtocol = factory() - if isinstance(instance, Executor): - identifier = instance.id - elif isinstance(instance, AgentProtocol): - identifier = instance.name or instance.id + instance = factory() + if isinstance(instance, AgentProtocol): + resolved_id = self._resolve_to_id(instance) else: - raise TypeError( - f"Participants must be AgentProtocol or Executor instances. Got {type(instance).__name__}." - ) + raise TypeError(f"Participants must be AgentProtocol instances. 
Got {type(instance).__name__}.")

-                if identifier in executor_ids_to_executors:
-                    raise ValueError(f"Duplicate participant name '{identifier}' detected")
-                executor_ids_to_executors[identifier] = instance
-                factory_names_to_ids[factory_name] = identifier
+                if resolved_id in factory_names_to_agents:
+                    raise ValueError(f"Duplicate participant name '{resolved_id}' detected")

-            # Prepare metadata and wrap instances as needed
-            metadata = prepare_participant_metadata(
-                executor_ids_to_executors,
-                description_factory=lambda name, participant: getattr(participant, "description", None) or name,
-            )
-
-            wrapped = metadata["executors"]
-            # Map executors by factory name (not executor.id) because handoff configs reference factory names
-            # This allows users to configure handoffs using the factory names they provided
-            executors = {
-                factory_name: wrapped[executor_id] for factory_name, executor_id in factory_names_to_ids.items()
-            }
-            aliases = metadata["aliases"]
+                # Map executors by factory name (not executor.id) because handoff configs reference factory names
+                # This allows users to configure handoffs using the factory names they provided
+                factory_names_to_agents[factory_name] = instance

-            return executors, aliases
+            return factory_names_to_agents

        raise ValueError("No executors or participant_factories have been configured")

-    def _resolve_handoffs(self, executors: Mapping[str, Executor]) -> tuple[dict[str, Executor], dict[str, str]]:
+    def _resolve_handoffs(self, agents: Mapping[str, AgentProtocol]) -> dict[str, list[HandoffConfiguration]]:
        """Handoffs may be specified using factory names or instances; resolve to executor IDs.

        Args:
-            executors: Map of executor IDs or factory names to Executor instances
+            agents: Map of agent IDs or factory names to `AgentProtocol` instances

        Returns:
-            Tuple of (updated executors map, handoff configuration map)
-            The updated executors map may have modified agents with handoff tools added
-            and maps executor IDs to Executor instances.
-            The handoff configuration map maps executor IDs to lists of target executor IDs.
+            Map of executor IDs to lists of `HandoffConfiguration` instances
        """
-        handoff_tool_targets: dict[str, str] = {}
-        updated_executors = {executor.id: executor for executor in executors.values()}
-        # Determine which agents should have handoff tools
+        # Updated map that uses agents' resolved IDs as keys
+        updated_handoff_configurations: dict[str, list[HandoffConfiguration]] = {}
        if self._handoff_config:
            # Use explicit handoff configuration from add_handoff() calls
-            for source_id, target_ids in self._handoff_config.items():
-                executor = executors.get(source_id)
-                if not executor:
+            for source_id, handoff_configurations in self._handoff_config.items():
+                source_agent = agents.get(source_id)
+                if not source_agent:
                    raise ValueError(
                        f"Handoff source agent '{source_id}' not found. "
                        "Please make sure source has been added as either a participant or participant_factory."
                    )
-
-                if isinstance(executor, AgentExecutor):
-                    # Build targets map for this source agent
-                    targets_map: dict[str, Executor] = {}
-                    for target_id in target_ids:
-                        target_executor = executors.get(target_id)
-                        if not target_executor:
-                            raise ValueError(
-                                f"Handoff target agent '{target_id}' not found. "
-                                "Please make sure target has been added as either a participant or participant_factory."
-                            )
-                        targets_map[target_executor.id] = target_executor
-
-                    # Register handoff tools for this agent
-                    updated_executor, tool_targets = self._prepare_agent_with_handoffs(executor, targets_map)
-                    updated_executors[updated_executor.id] = updated_executor
-                    handoff_tool_targets.update(tool_targets)
+                for handoff_config in handoff_configurations:
+                    target_agent = agents.get(handoff_config.target_id)
+                    if not target_agent:
+                        raise ValueError(
+                            f"Handoff target agent '{handoff_config.target_id}' not found for source '{source_id}'. "
+                            "Please make sure target has been added as either a participant or participant_factory."
+                        )
+
+                    updated_handoff_configurations.setdefault(self._resolve_to_id(source_agent), []).append(
+                        HandoffConfiguration(
+                            target=self._resolve_to_id(target_agent),
+                            description=handoff_config.description or target_agent.description,
+                        )
+                    )
        else:
-            if self._starting_agent_id is None or self._starting_agent_id not in executors:
-                raise RuntimeError("Failed to resolve default handoff configuration due to missing starting agent.")
-
-            # Default behavior: only coordinator gets handoff tools to all specialists
-            starting_executor = executors[self._starting_agent_id]
-            specialists = {
-                executor.id: executor for executor in executors.values() if executor.id != starting_executor.id
-            }
-
-            if isinstance(starting_executor, AgentExecutor) and specialists:
-                starting_executor, tool_targets = self._prepare_agent_with_handoffs(starting_executor, specialists)
-                updated_executors[starting_executor.id] = starting_executor
-                handoff_tool_targets.update(tool_targets)  # Update references after potential agent modifications
+            # Use default handoff configuration: all agents can hand off to all others (mesh topology)
+            for source_id, source_agent in agents.items():
+                for target_id, target_agent in agents.items():
+                    if source_id == target_id:
+                        continue  # Skip self-handoff
+                    updated_handoff_configurations.setdefault(self._resolve_to_id(source_agent), []).append(
+                        HandoffConfiguration(
+                            target=self._resolve_to_id(target_agent),
+                            description=target_agent.description,
+                        )
+                    )

-        return updated_executors, handoff_tool_targets
+        return updated_handoff_configurations

-    def _resolve_executors_and_handoffs(self) -> tuple[str, dict[str, Executor], dict[str, str], dict[str, str]]:
-        """Resolve participant factories into executor instances and handoff configurations.
+    def _resolve_executors(
+        self,
+        agents: dict[str, AgentProtocol],
+        handoffs: dict[str, list[HandoffConfiguration]],
+    ) -> dict[str, HandoffAgentExecutor]:
+        """Resolve agents into HandoffAgentExecutors.

-        If executors were provided directly via participants(...), those are returned as-is.
-        If participant factories were provided via participant_factories(...), those
-        are invoked to create executor instances and aliases.
+        Args:
+            agents: Map of agent IDs or factory names to `AgentProtocol` instances
+            handoffs: Map of executor IDs to lists of `HandoffConfiguration` instances

        Returns:
-            Tuple of (executors map, aliases map, handoff configuration map)
+            Map of executor IDs to `HandoffAgentExecutor` instances
        """
-        # Resolve the participant factories now. This doesn't break the factory pattern
-        # since the Handoff builder still creates new instances per workflow build.
- executors, aliases = self._resolve_executors() - # `self._starting_agent_id` is either a factory name or executor ID at this point, - # resolve to executor ID - if self._starting_agent_id in executors: - start_executor_id = executors[self._starting_agent_id].id - else: - raise RuntimeError("Failed to resolve starting agent ID during build.") + executors: dict[str, HandoffAgentExecutor] = {} - # Resolve handoffs - # This will update the `executors` dict to a map of executor IDs to executors - updated_executors, handoff_tool_targets = self._resolve_handoffs(executors) + for id, agent in agents.items(): + # Note that here `id` may be either factory name or agent resolved ID + resolved_id = self._resolve_to_id(agent) + if resolved_id not in handoffs or not handoffs.get(resolved_id): + logger.warning( + f"No handoff configuration found for agent '{resolved_id}'. " + "This agent will not be able to hand off to any other agents and your workflow may get stuck." + ) - return start_executor_id, updated_executors, aliases, handoff_tool_targets + # Autonomous mode is enabled only for specified agents (or all if none specified) + autonomous_mode = self._autonomous_mode and ( + not self._autonomous_mode_enabled_agents or id in self._autonomous_mode_enabled_agents + ) - def _resolve_to_id(self, candidate: str | AgentProtocol | Executor) -> str: + executors[resolved_id] = HandoffAgentExecutor( + agent=agent, + handoffs=handoffs.get(resolved_id, []), + is_start_agent=(id == self._start_id), + termination_condition=self._termination_condition, + autonomous_mode=autonomous_mode, + autonomous_mode_prompt=self._autonomous_mode_prompts.get(id, None), + autonomous_mode_turn_limit=self._autonomous_mode_turn_limits.get(id, None), + ) + + return executors + + def _resolve_to_id(self, candidate: str | AgentProtocol) -> str: """Resolve a participant reference into a concrete executor identifier.""" - if isinstance(candidate, Executor): - return candidate.id if isinstance(candidate, AgentProtocol): - name: str | None = getattr(candidate, "name", None) - if not name: - raise ValueError("AgentProtocol without a name cannot be resolved to an executor id.") - return self._aliases.get(name, name) + return resolve_agent_id(candidate) if isinstance(candidate, str): - if candidate in self._aliases: - return self._aliases[candidate] return candidate - raise TypeError(f"Invalid starting agent reference: {type(candidate).__name__}") - def _apply_auto_tools(self, agent: ChatAgent, specialists: Mapping[str, Executor]) -> dict[str, str]: - """Attach synthetic handoff tools to a chat agent and return the target lookup table. - - Creates handoff tools for each specialist agent that this agent can route to. - The tool_targets dict maps various name formats (tool name, sanitized name, alias) - to executor IDs to enable flexible handoff target resolution. 
- - Args: - agent: The ChatAgent to add handoff tools to - specialists: Map of executor IDs or factory names to specialist executors this agent can hand off to - - Returns: - Dict mapping tool names (in various formats) to executor IDs for handoff resolution - """ - default_options = agent.default_options - existing_tools = list(default_options.get("tools") or []) - existing_names = {getattr(tool, "name", "") for tool in existing_tools if hasattr(tool, "name")} - - tool_targets: dict[str, str] = {} - new_tools: list[Any] = [] - for executor in specialists.values(): - alias = executor.id - sanitized = sanitize_identifier(alias) - tool = _create_handoff_tool(alias, executor.description if isinstance(executor, AgentExecutor) else None) - if tool.name not in existing_names: - new_tools.append(tool) - # Map multiple name variations to the same executor ID for robust resolution - tool_targets[tool.name.lower()] = executor.id - tool_targets[sanitized] = executor.id - tool_targets[alias.lower()] = executor.id - - if new_tools: - default_options["tools"] = existing_tools + new_tools - else: - default_options["tools"] = existing_tools - - return tool_targets - - def _prepare_agent_with_handoffs( - self, - executor: AgentExecutor, - target_agents: Mapping[str, Executor], - ) -> tuple[AgentExecutor, dict[str, str]]: - """Prepare an agent by adding handoff tools for the specified target agents. + raise TypeError(f"Invalid starting agent reference: {type(candidate).__name__}") - Args: - executor: The agent executor to prepare - target_agents: Map of executor IDs to target executors this agent can hand off to + # endregion Internal Helper Methods - Returns: - Tuple of (updated executor, tool_targets map) - """ - agent = getattr(executor, "_agent", None) - if not isinstance(agent, ChatAgent): - return executor, {} - - cloned_agent = _clone_chat_agent(agent) - tool_targets = self._apply_auto_tools(cloned_agent, target_agents) - if tool_targets: - middleware = _AutoHandoffMiddleware(tool_targets) - existing_middlewares = list(cloned_agent.middleware or []) - existing_middlewares.append(middleware) - cloned_agent.middleware = existing_middlewares - - new_executor = AgentExecutor( - cloned_agent, - agent_thread=getattr(executor, "_agent_thread", None), - output_response=getattr(executor, "_output_response", False), - id=executor.id, - ) - return new_executor, tool_targets - # endregion Internal Helper Methods +# endregion Handoff workflow builder diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 4697daba3c..03171e795c 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -7,40 +7,36 @@ import re import sys from abc import ABC, abstractmethod -from collections.abc import AsyncIterable, Sequence +from collections.abc import Sequence from dataclasses import dataclass, field from enum import Enum -from typing import Any, TypeVar, cast -from uuid import uuid4 +from typing import Any, ClassVar, TypeVar, cast + +from typing_extensions import Never from agent_framework import ( AgentProtocol, AgentRunResponse, - AgentRunResponseUpdate, ChatMessage, - FunctionApprovalRequestContent, - FunctionResultContent, Role, ) -from ._base_group_chat_orchestrator import BaseGroupChatOrchestrator -from ._checkpoint import CheckpointStorage, WorkflowCheckpoint -from ._const import EXECUTOR_STATE_KEY, WORKFLOW_RUN_KWARGS_KEY -from ._events import 
AgentRunUpdateEvent, WorkflowEvent -from ._executor import Executor, handler -from ._group_chat import ( - GroupChatBuilder, - _GroupChatConfig, # type: ignore[reportPrivateUsage] - _GroupChatParticipantPipeline, # type: ignore[reportPrivateUsage] - _GroupChatRequestMessage, # type: ignore[reportPrivateUsage] - _GroupChatResponseMessage, # type: ignore[reportPrivateUsage] - group_chat_orchestrator, +from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse +from ._base_group_chat_orchestrator import ( + BaseGroupChatOrchestrator, + GroupChatParticipantMessage, + GroupChatRequestMessage, + GroupChatResponseMessage, + GroupChatWorkflowContext_T_Out, + ParticipantRegistry, ) -from ._message_utils import normalize_messages_input +from ._checkpoint import CheckpointStorage +from ._events import ExecutorEvent +from ._executor import Executor, handler from ._model_utils import DictConvertible, encode_value -from ._participant_utils import GroupChatParticipantSpec, participant_description from ._request_info_mixin import response_handler -from ._workflow import Workflow, WorkflowRunResult +from ._workflow import Workflow +from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext if sys.version_info >= (3, 11): @@ -103,14 +99,6 @@ def _message_from_payload(payload: Any) -> ChatMessage: raise TypeError("Unable to reconstruct ChatMessage from payload") -# region Magentic event metadata constants - -# Event type identifiers for magentic_event_type in additional_properties -MAGENTIC_EVENT_TYPE_ORCHESTRATOR = "orchestrator_message" -MAGENTIC_EVENT_TYPE_AGENT_DELTA = "agent_delta" - -# endregion Magentic event metadata constants - # region Magentic One Prompts ORCHESTRATOR_TASK_LEDGER_FACTS_PROMPT = """Below I will present you a request. 
@@ -276,204 +264,6 @@ def _new_participant_descriptions() -> dict[str, str]: return {} -def _new_chat_message_list() -> list[ChatMessage]: - """Typed default factory for ChatMessage list to satisfy type checkers.""" - return [] - - -@dataclass -class _MagenticStartMessage(DictConvertible): - """Internal: A message to start a magentic workflow.""" - - messages: list[ChatMessage] = field(default_factory=_new_chat_message_list) - run_kwargs: dict[str, Any] = field(default_factory=dict) - - def __init__( - self, - messages: str | ChatMessage | Sequence[str] | Sequence[ChatMessage] | None = None, - *, - task: ChatMessage | None = None, - run_kwargs: dict[str, Any] | None = None, - ) -> None: - normalized = normalize_messages_input(messages) - if task is not None: - normalized += normalize_messages_input(task) - if not normalized: - raise ValueError("MagenticStartMessage requires at least one message input.") - self.messages: list[ChatMessage] = normalized - self.run_kwargs: dict[str, Any] = run_kwargs or {} - - @property - def task(self) -> ChatMessage: - """Final user message for the task.""" - return self.messages[-1] - - @classmethod - def from_string(cls, task_text: str) -> "_MagenticStartMessage": - """Create a MagenticStartMessage from a simple string.""" - return cls(task_text) - - def to_dict(self) -> dict[str, Any]: - """Create a dict representation of the message.""" - return { - "messages": [message.to_dict() for message in self.messages], - "task": self.task.to_dict(), - } - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticStartMessage": - """Create from a dict.""" - if "messages" in data: - raw_messages = data["messages"] - if not isinstance(raw_messages, Sequence) or isinstance(raw_messages, (str, bytes)): - raise TypeError("MagenticStartMessage 'messages' must be a sequence.") - messages: list[ChatMessage] = [ChatMessage.from_dict(raw) for raw in raw_messages] # type: ignore[arg-type] - return cls(messages) - if "task" in data: - task = ChatMessage.from_dict(data["task"]) - return cls(task) - raise KeyError("Expected 'messages' or 'task' in MagenticStartMessage payload.") - - -@dataclass -class _MagenticRequestMessage(_GroupChatRequestMessage): - """Internal: A request message type for agents in a magentic workflow.""" - - task_context: str = "" - - -class _MagenticResponseMessage(_GroupChatResponseMessage): - """Internal: A response message type. - - When emitted by the orchestrator you can mark it as a broadcast to all agents, - or target a specific agent by name. 
- """ - - def __init__( - self, - body: ChatMessage, - target_agent: str | None = None, # deliver only to this agent if set - broadcast: bool = False, # deliver to all agents if True - ) -> None: - agent_name = body.author_name or "" - super().__init__( - agent_name=agent_name, - message=body, - ) - self.body = body - self.target_agent = target_agent - self.broadcast = broadcast - - def to_dict(self) -> dict[str, Any]: - """Create a dict representation of the message.""" - return {"body": self.body.to_dict(), "target_agent": self.target_agent, "broadcast": self.broadcast} - - @classmethod - def from_dict(cls, value: dict[str, Any]) -> "_MagenticResponseMessage": - """Create from a dict.""" - body = ChatMessage.from_dict(value["body"]) - target_agent = value.get("target_agent") - broadcast = value.get("broadcast", False) - return cls(body=body, target_agent=target_agent, broadcast=broadcast) - - -# region Human Intervention Types - - -class MagenticHumanInterventionKind(str, Enum): - """The kind of human intervention being requested.""" - - PLAN_REVIEW = "plan_review" # Review and approve/revise the initial plan - TOOL_APPROVAL = "tool_approval" # Approve a tool/function call - STALL = "stall" # Workflow has stalled and needs guidance - - -class MagenticHumanInterventionDecision(str, Enum): - """Decision options for human intervention responses.""" - - APPROVE = "approve" # Approve (plan review, tool approval) - REVISE = "revise" # Request revision with feedback (plan review) - REJECT = "reject" # Reject/deny (tool approval) - CONTINUE = "continue" # Continue with current state (stall) - REPLAN = "replan" # Trigger replanning (stall) - GUIDANCE = "guidance" # Provide guidance text (stall, tool approval) - - -@dataclass -class _MagenticHumanInterventionRequest: - """Unified request for human intervention in a Magentic workflow. - - This request is emitted when the workflow needs human input. The `kind` field - indicates what type of intervention is needed, and the relevant fields are - populated based on the kind. 
- - Attributes: - request_id: Unique identifier for correlating responses - kind: The type of intervention needed (plan_review, tool_approval, stall) - - # Plan review fields - task_text: The task description (plan_review) - facts_text: Extracted facts from the task (plan_review) - plan_text: The proposed or current plan (plan_review, stall) - round_index: Number of review rounds so far (plan_review) - - # Tool approval fields - agent_id: The agent requesting input (tool_approval) - prompt: Description of what input is needed (tool_approval) - context: Additional context (tool_approval) - conversation_snapshot: Recent conversation history (tool_approval) - - # Stall intervention fields - stall_count: Number of consecutive stall rounds (stall) - max_stall_count: Threshold that triggered intervention (stall) - stall_reason: Description of why progress stalled (stall) - last_agent: Last active agent (stall) - """ - - request_id: str = field(default_factory=lambda: str(uuid4())) - kind: MagenticHumanInterventionKind = MagenticHumanInterventionKind.PLAN_REVIEW - - # Plan review fields - task_text: str = "" - facts_text: str = "" - plan_text: str = "" - round_index: int = 0 - - # Tool approval fields - agent_id: str = "" - prompt: str = "" - context: str | None = None - conversation_snapshot: list[ChatMessage] = field(default_factory=list) # type: ignore - - # Stall intervention fields - stall_count: int = 0 - max_stall_count: int = 3 - stall_reason: str = "" - last_agent: str = "" - - -@dataclass -class _MagenticHumanInterventionReply: - """Unified reply to a human intervention request. - - The relevant fields depend on the original request kind and the decision made. - - Attributes: - decision: The human's decision (approve, revise, continue, replan, guidance) - edited_plan_text: New plan text if directly editing (plan_review with approve/revise) - comments: Feedback for revision or guidance text (plan_review, stall with guidance) - response_text: Free-form response text (tool_approval) - """ - - decision: MagenticHumanInterventionDecision - edited_plan_text: str | None = None - comments: str | None = None - response_text: str | None = None - - -# endregion Human Intervention Types - - @dataclass class _MagenticTaskLedger(DictConvertible): """Internal: Task ledger for the Standard Magentic manager.""" @@ -493,7 +283,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticTaskLedger": @dataclass -class _MagenticProgressLedgerItem(DictConvertible): +class MagenticProgressLedgerItem(DictConvertible): """Internal: A progress ledger item.""" reason: str @@ -503,7 +293,7 @@ def to_dict(self) -> dict[str, Any]: return {"reason": self.reason, "answer": self.answer} @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem": + def from_dict(cls, data: dict[str, Any]) -> "MagenticProgressLedgerItem": answer_value = data.get("answer") if not isinstance(answer_value, (str, bool)): answer_value = "" # Default to empty string if not str or bool @@ -511,14 +301,14 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedgerItem": @dataclass -class _MagenticProgressLedger(DictConvertible): +class MagenticProgressLedger(DictConvertible): """Internal: A progress ledger for tracking workflow progress.""" - is_request_satisfied: _MagenticProgressLedgerItem - is_in_loop: _MagenticProgressLedgerItem - is_progress_being_made: _MagenticProgressLedgerItem - next_speaker: _MagenticProgressLedgerItem - instruction_or_question: _MagenticProgressLedgerItem + 
is_request_satisfied: MagenticProgressLedgerItem + is_in_loop: MagenticProgressLedgerItem + is_progress_being_made: MagenticProgressLedgerItem + next_speaker: MagenticProgressLedgerItem + instruction_or_question: MagenticProgressLedgerItem def to_dict(self) -> dict[str, Any]: return { @@ -530,13 +320,13 @@ def to_dict(self) -> dict[str, Any]: } @classmethod - def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger": + def from_dict(cls, data: dict[str, Any]) -> "MagenticProgressLedger": return cls( - is_request_satisfied=_MagenticProgressLedgerItem.from_dict(data.get("is_request_satisfied", {})), - is_in_loop=_MagenticProgressLedgerItem.from_dict(data.get("is_in_loop", {})), - is_progress_being_made=_MagenticProgressLedgerItem.from_dict(data.get("is_progress_being_made", {})), - next_speaker=_MagenticProgressLedgerItem.from_dict(data.get("next_speaker", {})), - instruction_or_question=_MagenticProgressLedgerItem.from_dict(data.get("instruction_or_question", {})), + is_request_satisfied=MagenticProgressLedgerItem.from_dict(data.get("is_request_satisfied", {})), + is_in_loop=MagenticProgressLedgerItem.from_dict(data.get("is_in_loop", {})), + is_progress_being_made=MagenticProgressLedgerItem.from_dict(data.get("is_progress_being_made", {})), + next_speaker=MagenticProgressLedgerItem.from_dict(data.get("next_speaker", {})), + instruction_or_question=MagenticProgressLedgerItem.from_dict(data.get("instruction_or_question", {})), ) @@ -544,7 +334,7 @@ def from_dict(cls, data: dict[str, Any]) -> "_MagenticProgressLedger": class MagenticContext(DictConvertible): """Context for the Magentic manager.""" - task: ChatMessage + task: str chat_history: list[ChatMessage] = field(default_factory=_new_chat_history) participant_descriptions: dict[str, str] = field(default_factory=_new_participant_descriptions) round_count: int = 0 @@ -553,7 +343,7 @@ class MagenticContext(DictConvertible): def to_dict(self) -> dict[str, Any]: return { - "task": _message_to_payload(self.task), + "task": self.task, "chat_history": [_message_to_payload(msg) for msg in self.chat_history], "participant_descriptions": dict(self.participant_descriptions), "round_count": self.round_count, @@ -563,14 +353,27 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "MagenticContext": + # Validate required fields + # `task` is required + task = data.get("task") + if task is None or not isinstance(task, str): + raise ValueError("MagenticContext requires a 'task' string field.") + # `chat_history` is required chat_history_payload = data.get("chat_history", []) history: list[ChatMessage] = [] for item in chat_history_payload: history.append(_message_from_payload(item)) + # `participant_descriptions` is required + participant_descriptions = data.get("participant_descriptions") + if not isinstance(participant_descriptions, dict) or not participant_descriptions: + raise ValueError("MagenticContext requires a 'participant_descriptions' dictionary field.") + if not all(isinstance(k, str) and isinstance(v, str) for k, v in participant_descriptions.items()): # type: ignore + raise ValueError("MagenticContext 'participant_descriptions' must be a dict of str to str.") + return cls( - task=_message_from_payload(data.get("task")), + task=task, chat_history=history, - participant_descriptions=dict(data.get("participant_descriptions", {})), + participant_descriptions=participant_descriptions, # type: ignore round_count=data.get("round_count", 0), stall_count=data.get("stall_count", 0), 
reset_count=data.get("reset_count", 0), @@ -597,13 +400,6 @@ def _team_block(participants: dict[str, str]) -> str: return "\n".join(f"- {name}: {desc}" for name, desc in participants.items()) -def _first_assistant(messages: list[ChatMessage]) -> ChatMessage | None: - for msg in reversed(messages): - if msg.role == Role.ASSISTANT: - return msg - return None - - def _extract_json(text: str) -> dict[str, Any]: """Potentially temp helper method. @@ -693,7 +489,7 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: ... @abstractmethod - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: """Create a progress ledger.""" ... @@ -724,6 +520,8 @@ class StandardMagenticManager(MagenticManagerBase): task_ledger: _MagenticTaskLedger | None + MANAGER_NAME: ClassVar[str] = "StandardMagenticManager" + def __init__( self, agent: AgentProtocol, @@ -797,25 +595,21 @@ async def _complete( (temperature, seed, instructions, etc.). """ response: AgentRunResponse = await self._agent.run(messages) - out_messages = response.messages if response else None - if out_messages: - last = out_messages[-1] - return ChatMessage( - role=last.role, - text=last.text, - author_name=last.author_name or MAGENTIC_MANAGER_NAME, - ) - return ChatMessage(role=Role.ASSISTANT, text="No output produced.", author_name=MAGENTIC_MANAGER_NAME) + if not response.messages: + raise RuntimeError("Agent returned no messages in response.") + if len(response.messages) > 1: + logger.warning("Agent returned multiple messages; using the last one.") + + return response.messages[-1] async def plan(self, magentic_context: MagenticContext) -> ChatMessage: """Create facts and plan using the model, then render a combined task ledger as a single assistant message.""" - task_text = magentic_context.task.text team_text = _team_block(magentic_context.participant_descriptions) # Gather facts facts_user = ChatMessage( role=Role.USER, - text=self.task_ledger_facts_prompt.format(task=task_text), + text=self.task_ledger_facts_prompt.format(task=magentic_context.task), ) facts_msg = await self._complete([*magentic_context.chat_history, facts_user]) @@ -834,7 +628,7 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: magentic_context.chat_history.extend([facts_user, facts_msg, plan_user, plan_msg]) combined = self.task_ledger_full_prompt.format( - task=task_text, + task=magentic_context.task, team=team_text, facts=facts_msg.text, plan=plan_msg.text, @@ -846,13 +640,14 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: if self.task_ledger is None: raise RuntimeError("replan() called before plan(); call plan() once before requesting a replan.") - task_text = magentic_context.task.text team_text = _team_block(magentic_context.participant_descriptions) # Update facts facts_update_user = ChatMessage( role=Role.USER, - text=self.task_ledger_facts_update_prompt.format(task=task_text, old_facts=self.task_ledger.facts.text), + text=self.task_ledger_facts_update_prompt.format( + task=magentic_context.task, old_facts=self.task_ledger.facts.text + ), ) updated_facts = await self._complete([*magentic_context.chat_history, facts_update_user]) @@ -876,14 +671,14 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: magentic_context.chat_history.extend([facts_update_user, updated_facts, plan_update_user, updated_plan]) combined = 
self.task_ledger_full_prompt.format( - task=task_text, + task=magentic_context.task, team=team_text, facts=updated_facts.text, plan=updated_plan.text, ) return ChatMessage(role=Role.ASSISTANT, text=combined, author_name=MAGENTIC_MANAGER_NAME) - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: """Use the model to produce a JSON progress ledger based on the conversation so far. Adds lightweight retries with backoff for transient parse issues and avoids selecting a @@ -897,7 +692,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> _Ma team_text = _team_block(magentic_context.participant_descriptions) prompt = self.progress_ledger_prompt.format( - task=magentic_context.task.text, + task=magentic_context.task, team=team_text, names=names_csv, ) @@ -910,7 +705,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> _Ma raw = await self._complete([*magentic_context.chat_history, user_message]) try: ledger_dict = _extract_json(raw.text) - return _coerce_model(_MagenticProgressLedger, ledger_dict) + return _coerce_model(MagenticProgressLedger, ledger_dict) except Exception as ex: last_error = ex attempts += 1 @@ -927,7 +722,7 @@ async def create_progress_ledger(self, magentic_context: MagenticContext) -> _Ma async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: """Ask the model to produce the final answer addressed to the user.""" - prompt = self.final_answer_prompt.format(task=magentic_context.task.text) + prompt = self.final_answer_prompt.format(task=magentic_context.task) user_message = ChatMessage(role=Role.USER, text=prompt) response = await self._complete([*magentic_context.chat_history, user_message]) # Ensure role is assistant @@ -956,627 +751,370 @@ def on_checkpoint_restore(self, state: dict[str, Any]) -> None: # endregion Magentic Manager -# region Magentic Executors +# region Magentic Orchestrator -class MagenticOrchestratorExecutor(BaseGroupChatOrchestrator): - """Magentic orchestrator executor that handles all orchestration logic. +class MagenticResetSignal: + """Signal to indicate that the Magentic workflow should reset. - This executor manages the entire Magentic One workflow including: - - Initial planning and task ledger creation - - Progress tracking and completion detection - - Agent coordination and message routing - - Reset and replanning logic + This signal can be raised within the orchestrator's inner loop to trigger + a reset of the Magentic context, clearing chat history and resetting + stall counts. """ - # Typed attributes (initialized in __init__) - _agent_executors: dict[str, "MagenticAgentExecutor"] - _context: "MagenticContext | None" - _task_ledger: "ChatMessage | None" - _inner_loop_lock: asyncio.Lock - _require_plan_signoff: bool - _plan_review_round: int - _max_plan_review_rounds: int - _terminated: bool - _enable_stall_intervention: bool - - def __init__( - self, - manager: MagenticManagerBase, - participants: dict[str, str], - *, - require_plan_signoff: bool = False, - max_plan_review_rounds: int = 10, - enable_stall_intervention: bool = False, - executor_id: str | None = None, - ) -> None: - """Initializes a new instance of the MagenticOrchestratorExecutor. - - Args: - manager: The Magentic manager instance. - participants: A dictionary of participant IDs to their names. 
- require_plan_signoff: Whether to require plan sign-off from a human. - max_plan_review_rounds: The maximum number of plan review rounds. - enable_stall_intervention: Whether to request human input on stalls instead of auto-replan. - executor_id: An optional executor ID. - """ - super().__init__(executor_id or f"magentic_orchestrator_{uuid4().hex[:8]}") - self._manager = manager - self._participants = participants - self._context = None - self._task_ledger = None - self._require_plan_signoff = require_plan_signoff - self._plan_review_round = 0 - self._max_plan_review_rounds = max_plan_review_rounds - self._enable_stall_intervention = enable_stall_intervention - # Registry of agent executors for internal coordination (e.g., resets) - self._agent_executors = {} - # Terminal state marker to stop further processing after completion/limits - self._terminated = False - # Tracks whether checkpoint state has been applied for this run - - def _get_author_name(self) -> str: - """Get the magentic manager name for orchestrator-generated messages.""" - return MAGENTIC_MANAGER_NAME - - def register_agent_executor(self, name: str, executor: "MagenticAgentExecutor") -> None: - """Register an agent executor for internal control (no messages).""" - self._agent_executors[name] = executor - - async def _emit_orchestrator_message( - self, - ctx: WorkflowContext[Any, list[ChatMessage]], - message: ChatMessage, - kind: str, - ) -> None: - """Emit orchestrator message to the workflow event stream. - - Emits an AgentRunUpdateEvent (for agent wrapper consumers) with metadata indicating - the orchestrator event type. - - Args: - ctx: Workflow context for adding events to the stream - message: Orchestrator message to emit (task, plan, instruction, notice) - kind: Message classification (user_task, task_ledger, instruction, notice) - - Example: - async for event in workflow.run_stream("task"): - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - if props and props.get("magentic_event_type") == "orchestrator_message": - kind = props.get("orchestrator_message_kind", "") - print(f"Orchestrator {kind}: {event.data.text}") - """ - # Emit AgentRunUpdateEvent with metadata - update = AgentRunResponseUpdate( - text=message.text, - role=message.role, - author_name=self._get_author_name(), - additional_properties={ - "magentic_event_type": MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - "orchestrator_message_kind": kind, - "orchestrator_id": self.id, - }, - ) - await ctx.add_event(AgentRunUpdateEvent(executor_id=self.id, data=update)) - - @override - async def on_checkpoint_save(self) -> dict[str, Any]: - """Capture current orchestrator state for checkpointing. - - Uses OrchestrationState for structure but maintains Magentic's complex metadata - at the top level for backward compatibility with existing checkpoints. 
- - Returns: - Dict ready for checkpoint persistence - """ - state: dict[str, Any] = { - "plan_review_round": self._plan_review_round, - "max_plan_review_rounds": self._max_plan_review_rounds, - "require_plan_signoff": self._require_plan_signoff, - "terminated": self._terminated, - } - if self._context is not None: - state["magentic_context"] = self._context.to_dict() - if self._task_ledger is not None: - state["task_ledger"] = _message_to_payload(self._task_ledger) + pass - try: - state["manager_state"] = self._manager.on_checkpoint_save() - except Exception as exc: - logger.warning(f"Failed to save manager state for checkpoint: {exc}\nSkipping...") - return state +class MagenticOrchestratorEventType(str, Enum): + """Types of Magentic orchestrator events.""" - @override - async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - """Restore orchestrator state from checkpoint. + PLAN_CREATED = "plan_created" + REPLANNED = "replanned" + PROGRESS_LEDGER_UPDATED = "progress_ledger_updated" - Maintains backward compatibility with existing Magentic checkpoints - while supporting OrchestrationState structure. - Args: - state: Checkpoint data dict - """ - # Support both old format (direct keys) and new format (wrapped in OrchestrationState) - if "metadata" in state and isinstance(state.get("metadata"), dict): - # New OrchestrationState format - extract metadata - from ._orchestration_state import OrchestrationState +@dataclass +class MagenticOrchestratorEvent(ExecutorEvent): + """Base class for Magentic orchestrator events.""" - orch_state = OrchestrationState.from_dict(state) - state = orch_state.metadata + def __init__( + self, + executor_id: str, + event_type: MagenticOrchestratorEventType, + data: ChatMessage | MagenticProgressLedger, + ) -> None: + super().__init__(executor_id, data) + self.event_type = event_type - ctx_payload = state.get("magentic_context") - if ctx_payload is not None: - try: - if isinstance(ctx_payload, dict): - self._context = MagenticContext.from_dict(ctx_payload) # type: ignore[arg-type] - else: - self._context = None - except Exception as exc: # pragma: no cover - defensive - logger.warning(f"Failed to restore magentic context: {exc}") - self._context = None - ledger_payload = state.get("task_ledger") - if ledger_payload is not None: - try: - self._task_ledger = _message_from_payload(ledger_payload) - except Exception as exc: # pragma: no cover - logger.warning(f"Failed to restore task ledger message: {exc}") - self._task_ledger = None + def __repr__(self) -> str: + return f"{self.__class__.__name__}(executor_id={self.executor_id}, event_type={self.event_type})" - if "plan_review_round" in state: - try: - self._plan_review_round = int(state["plan_review_round"]) - except Exception: # pragma: no cover - logger.debug("Ignoring invalid plan_review_round in checkpoint state") - if "max_plan_review_rounds" in state: - self._max_plan_review_rounds = state.get("max_plan_review_rounds") # type: ignore[assignment] - if "require_plan_signoff" in state: - self._require_plan_signoff = bool(state.get("require_plan_signoff")) - if "terminated" in state: - self._terminated = bool(state.get("terminated")) - manager_state = state.get("manager_state") - if manager_state is not None: - try: - self._manager.on_checkpoint_restore(manager_state) - except Exception as exc: # pragma: no cover - logger.warning(f"Failed to restore manager state: {exc}") +# region Request info related types - self._reconcile_restored_participants() - def _reconcile_restored_participants(self) -> 
None: - """Ensure restored participant roster matches the current workflow graph.""" - if self._context is None: - return +@dataclass +class MagenticPlanReviewResponse: + """Response to a human plan review request. - restored = self._context.participant_descriptions or {} - expected = self._participants + Attributes: + review: List of messages containing feedback and suggested revisions. If empty, + the plan is considered approved. + """ - restored_names = set(restored.keys()) - expected_names = set(expected.keys()) + review: list[ChatMessage] - if restored_names != expected_names: - missing = ", ".join(sorted(expected_names - restored_names)) or "none" - unexpected = ", ".join(sorted(restored_names - expected_names)) or "none" - raise RuntimeError( - "Magentic checkpoint restore failed: participant names do not match the checkpoint. " - "Ensure MagenticBuilder.participants keys remain stable across runs. " - f"Missing names: {missing}; unexpected names: {unexpected}." - ) + @staticmethod + def approve() -> "MagenticPlanReviewResponse": + """Create an approval response.""" + return MagenticPlanReviewResponse(review=[]) - # Refresh descriptions so prompt surfaces always reflect the rebuilt workflow inputs. - for name, description in expected.items(): - restored[name] = description + @staticmethod + def revise(feedback: str | list[str] | ChatMessage | list[ChatMessage]) -> "MagenticPlanReviewResponse": + """Create a revision response with feedback.""" + if isinstance(feedback, str): + feedback = [ChatMessage(role=Role.USER, text=feedback)] + elif isinstance(feedback, ChatMessage): + feedback = [feedback] + elif isinstance(feedback, list): + feedback = [ChatMessage(role=Role.USER, text=item) if isinstance(item, str) else item for item in feedback] - @handler - async def handle_start_message( - self, - message: _MagenticStartMessage, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], - ) -> None: - """Handle the initial start message to begin orchestration.""" - if getattr(self, "_terminated", False): - return - logger.info("Magentic Orchestrator: Received start message") + return MagenticPlanReviewResponse(review=feedback) - # Store run_kwargs in SharedState so agent executors can access them - # Always store (even empty dict) so retrieval is deterministic - await context.set_shared_state(WORKFLOW_RUN_KWARGS_KEY, message.run_kwargs or {}) - self._context = MagenticContext( - task=message.task, - participant_descriptions=self._participants, - ) - if message.messages: - self._context.chat_history.extend(message.messages) +@dataclass +class MagenticPlanReviewRequest: + """Request for human review of a proposed plan. - # Non-streaming callback for the orchestrator receipt of the task - await self._emit_orchestrator_message(context, message.task, ORCH_MSG_KIND_USER_TASK) + Attributes: + plan: The proposed plan message. + current_progress: The current progress ledger, if available. + During the initial plan review, this will be None. In subsequent + reviews after replanning (due to stalls), this will contain the + latest progress ledger that determined no progress had been made + or the workflow was in a loop. + is_stalled: Whether the workflow is currently stalled. 
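+
+    Example (a minimal sketch; ``request`` stands for a received ``MagenticPlanReviewRequest``):
+
+        .. code-block:: python
+
+            # Approve the proposed plan as-is:
+            response = request.approve()
+
+            # Or send feedback so the manager revises the plan:
+            response = request.revise("Add a verification step before the final answer.")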
+    """

+    plan: ChatMessage
+    current_progress: MagenticProgressLedger | None
+    is_stalled: bool

+    def approve(self) -> MagenticPlanReviewResponse:
+        """Create an approval response."""
+        return MagenticPlanReviewResponse.approve()

+    def revise(self, feedback: str | list[str] | ChatMessage | list[ChatMessage]) -> MagenticPlanReviewResponse:
+        """Create a revision response with feedback."""
+        return MagenticPlanReviewResponse.revise(feedback)


+# endregion Request info related types


+class MagenticOrchestrator(BaseGroupChatOrchestrator):
+    """Magentic orchestrator that defines the workflow structure.

+    This orchestrator manages the overall Magentic workflow in the following structure:

+    1. Upon receiving the task (a list of messages), it creates the plan using the manager,
+       then runs the inner loop.
+    2. The inner loop is distributed and its implementation is decentralized. Within the loop,
+       the orchestrator is responsible for:
+       - Creating the progress ledger using the manager.
+       - Checking for task completion.
+       - Detecting stalling or looping and triggering replanning if needed.
+       - Sending requests to participants based on the progress ledger's next speaker.
+       - Issuing requests for human intervention if enabled and needed.
+    3. The inner loop waits for responses from the selected participant, then continues the loop.
+    4. The orchestrator breaks out of the inner loop when the replanning or final answer conditions are met.
+    5. The outer loop handles replanning and reenters the inner loop.
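+
+    A rough sketch of one inner-loop iteration (illustrative only; the actual routing,
+    event emission, and reset handling are elided):
+
+        .. code-block:: python
+
+            ledger = await manager.create_progress_ledger(context.clone(deep=True))
+            if ledger.is_request_satisfied.answer:
+                final_answer = await manager.prepare_final_answer(context.clone(deep=True))
+                # ... yield the final answer and stop
+            elif ledger.is_in_loop.answer or not ledger.is_progress_being_made.answer:
+                context.stall_count += 1
+                # ... replan via the outer loop once the stall threshold is exceeded
+            else:
+                # ... send ledger.instruction_or_question to ledger.next_speaker.answer
+                ...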
+ """ - @handler - async def handle_response_message( + def __init__( self, - message: _MagenticResponseMessage, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + manager: MagenticManagerBase, + participant_registry: ParticipantRegistry, + *, + require_plan_signoff: bool = False, ) -> None: - """Handle responses from agents.""" - if getattr(self, "_terminated", False): - return - - if self._context is None: - raise RuntimeError("Magentic Orchestrator: Received response but not initialized") + """Initialize the Magentic orchestrator. - logger.debug("Magentic Orchestrator: Received response from agent") + Args: + manager: The Magentic manager instance to use for planning and progress tracking. + participant_registry: Registry of participants involved in the workflow. - # Add transfer message if needed - if message.body.role != Role.USER: - transfer_msg = ChatMessage( - role=Role.USER, - text=f"Transferred to {getattr(message.body, 'author_name', 'agent')}", - ) - self._context.chat_history.append(transfer_msg) + Keyword Args: + require_plan_signoff: If True, requires human approval of the initial plan before proceeding. + """ + super().__init__("magentic_orchestrator", participant_registry) + self._manager = manager + self._require_plan_signoff = require_plan_signoff - # Add agent response to context - self._context.chat_history.append(message.body) + # Task related state + self._magentic_context: MagenticContext | None = None + self._task_ledger: ChatMessage | None = None + self._progress_ledger: MagenticProgressLedger | None = None - # Continue with inner loop - await self._run_inner_loop(context) + # Termination related state + self._terminated: bool = False + self._max_rounds = manager.max_round_count - @response_handler - async def handle_human_intervention_response( + @override + async def _handle_messages( self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], + messages: list[ChatMessage], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handle unified human intervention responses. - - Routes the response to the appropriate handler based on the original request kind. - """ - if getattr(self, "_terminated", False): - return - - if self._context is None: - return + """Handle the initial task messages to start the workflow.""" + if self._terminated: + raise RuntimeError( + "This Magentic workflow has already been completed. No further messages can be processed. " + "Use the builder to create a new workflow instance to handle additional tasks." 
+ ) - if original_request.kind == MagenticHumanInterventionKind.PLAN_REVIEW: - await self._handle_plan_review_response(original_request, response, context) - elif original_request.kind == MagenticHumanInterventionKind.STALL: - await self._handle_stall_intervention_response(original_request, response, context) - # TOOL_APPROVAL is handled by MagenticAgentExecutor, not the orchestrator + if not messages: + raise ValueError("Magentic orchestrator requires at least one message to start the workflow.") - async def _handle_plan_review_response( - self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], - ) -> None: - """Handle plan review response.""" - if self._context is None: - return + if len(messages) > 1: + raise ValueError("Magentic only supports a single task message to start the workflow.") - is_approve = response.decision == MagenticHumanInterventionDecision.APPROVE - - if is_approve: - # Close the review loop on approval (no further plan review requests this run) - self._require_plan_signoff = False - # If the user supplied an edited plan, adopt it - if response.edited_plan_text: - # Update the manager's internal ledger and rebuild the combined message - mgr_ledger = getattr(self._manager, "task_ledger", None) - if mgr_ledger is not None: - mgr_ledger.plan.text = response.edited_plan_text - team_text = _team_block(self._participants) - combined = self._manager.task_ledger_full_prompt.format( - task=self._context.task.text, - team=team_text, - facts=(mgr_ledger.facts.text if mgr_ledger else ""), - plan=response.edited_plan_text, - ) - self._task_ledger = ChatMessage( - role=Role.ASSISTANT, - text=combined, - author_name=MAGENTIC_MANAGER_NAME, - ) - # If approved with comments but no edited text, apply comments via replan and proceed - elif response.comments: - self._context.chat_history.append( - ChatMessage(role=Role.USER, text=f"Human plan feedback: {response.comments}") - ) - self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) + if messages[0].text.strip() == "": + raise ValueError("Magentic task message must contain non-empty text.") - # Record the signed-off plan (no broadcast) - if self._task_ledger: - self._context.chat_history.append(self._task_ledger) - await self._emit_orchestrator_message(context, self._task_ledger, ORCH_MSG_KIND_TASK_LEDGER) + self._magentic_context = MagenticContext( + task=messages[0].text, + participant_descriptions=self._participant_registry.participants, + chat_history=list(messages), + ) - # Enter the normal coordination loop - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, + # Initial planning using the manager with real model calls + self._task_ledger = await self._manager.plan(self._magentic_context.clone(deep=True)) + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.PLAN_CREATED, + data=self._task_ledger, ) - await self._run_inner_loop(ctx2) - return + ) - # Otherwise, REVISION round - self._plan_review_round += 1 - if self._plan_review_round > self._max_plan_review_rounds: - logger.warning("Magentic Orchestrator: Max plan review rounds reached. 
Proceeding with current plan.") - self._require_plan_signoff = False - notice = ChatMessage( - role=Role.ASSISTANT, - text=( - "Plan review closed after max rounds. Proceeding with the current plan and will no longer " - "prompt for plan approval." - ), - author_name=MAGENTIC_MANAGER_NAME, - ) - self._context.chat_history.append(notice) - await self._emit_orchestrator_message(context, notice, ORCH_MSG_KIND_NOTICE) - if self._task_ledger: - self._context.chat_history.append(self._task_ledger) - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - ) - await self._run_inner_loop(ctx2) + # If a human must sign off, ask now and return. The response handler will resume. + if self._require_plan_signoff: + await self._send_plan_review_request(cast(WorkflowContext, ctx)) return - # If the user provided an edited plan, adopt it and ask for confirmation - if response.edited_plan_text: - mgr_ledger2 = getattr(self._manager, "task_ledger", None) - if mgr_ledger2 is not None: - mgr_ledger2.plan.text = response.edited_plan_text - team_text = _team_block(self._participants) - combined = self._manager.task_ledger_full_prompt.format( - task=self._context.task.text, - team=team_text, - facts=(mgr_ledger2.facts.text if mgr_ledger2 else ""), - plan=response.edited_plan_text, - ) - self._task_ledger = ChatMessage(role=Role.ASSISTANT, text=combined, author_name=MAGENTIC_MANAGER_NAME) - await self._send_plan_review_request(cast(WorkflowContext, context)) - return + # Add task ledger to conversation history + self._magentic_context.chat_history.append(self._task_ledger) - # Else pass comments into the chat history and replan - if response.comments: - self._context.chat_history.append( - ChatMessage(role=Role.USER, text=f"Human plan feedback: {response.comments}") - ) + logger.debug("Task ledger created.") - self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) - await self._send_plan_review_request(cast(WorkflowContext, context)) + # Start the inner loop + await self._run_inner_loop(ctx) - async def _handle_stall_intervention_response( + @override + async def _handle_response( self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[ - _MagenticResponseMessage | _MagenticRequestMessage | _MagenticHumanInterventionRequest, list[ChatMessage] - ], + response: AgentExecutorResponse | GroupChatResponseMessage, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Handle stall intervention response.""" - if self._context is None: - return + """Handle a response message from a participant.""" + if self._magentic_context is None or self._task_ledger is None: + raise RuntimeError("Context or task ledger not initialized") - ctx = self._context - logger.info( - f"Stall intervention response: decision={response.decision.value}, " - f"stall_count was {original_request.stall_count}" - ) + messages = self._process_participant_response(response) - if response.decision == MagenticHumanInterventionDecision.CONTINUE: - ctx.stall_count = 0 - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - ) - await self._run_inner_loop(ctx2) - return + self._magentic_context.chat_history.extend(messages) - if response.decision == MagenticHumanInterventionDecision.REPLAN: - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - 
) - await self._reset_and_replan(ctx2) - return + # Broadcast participant messages to all participants for context, except + # the participant that just responded + participant = ctx.get_source_executor_id() + await self._broadcast_messages_to_participants( + messages, + cast(WorkflowContext[AgentExecutorRequest | GroupChatParticipantMessage], ctx), + participants=[p for p in self._participant_registry.participants if p != participant], + ) - if response.decision == MagenticHumanInterventionDecision.GUIDANCE: - ctx.stall_count = 0 - guidance = response.comments or response.response_text - if guidance: - guidance_msg = ChatMessage( - role=Role.USER, - text=f"Human guidance to help with stall: {guidance}", - ) - ctx.chat_history.append(guidance_msg) - await self._emit_orchestrator_message(context, guidance_msg, ORCH_MSG_KIND_NOTICE) - ctx2 = cast( - WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - context, - ) - await self._run_inner_loop(ctx2) - return + await self._run_inner_loop(ctx) - async def _run_outer_loop( + @response_handler + async def handle_plan_review_response( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + original_request: MagenticPlanReviewRequest, + response: MagenticPlanReviewResponse, + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: - """Run the outer orchestration loop - planning phase.""" - if self._context is None: - raise RuntimeError("Context not initialized") - - logger.info("Magentic Orchestrator: Outer loop - entering inner loop") + """Handle the human response to the plan review request. + + Logic: + Two code paths can trigger a plan review request to the human: + - Initial plan creation if `require_plan_signoff` is True. + - Potentially during the inner loop if stalling is detected (resetting and replanning). + + The human can either approve the plan or request revisions with comments. + - If approved, proceed to run the outer loop, which simply adds the task ledger + to the conversation and enters the inner loop. + - If revision requested, append the review comments to the chat history, + trigger replanning via the manager, emit a REPLANNED event, then send the + revised plan back for another review until it is approved. + """ + if self._magentic_context is None or self._task_ledger is None: + raise RuntimeError("Context or task ledger not initialized") - # Add task ledger to history if not already there - if self._task_ledger and ( - not self._context.chat_history or self._context.chat_history[-1] != self._task_ledger - ): - self._context.chat_history.append(self._task_ledger) + # Case 1: Approved + if len(response.review) == 0: + logger.debug("Magentic Orchestrator: Plan review approved by human.") + await self._run_outer_loop(ctx) + return + # Case 2: Revision requested + logger.debug("Magentic Orchestrator: Plan review revision requested by human.") + self._magentic_context.chat_history.extend(response.review) + self._task_ledger = await self._manager.replan(self._magentic_context.clone(deep=True)) + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.REPLANNED, + data=self._task_ledger, + ) + ) + # Continue the review process by sending the new plan for review again until approved + # We don't need to check if `_require_plan_signoff` is True here, since we are already + # in the review process. 
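+        # Carry the original is_stalled flag forward so a stall-triggered review stays marked as such.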
+ await self._send_plan_review_request(cast(WorkflowContext, ctx), is_stalled=original_request.is_stalled) - # Optionally surface the updated task ledger via message callback (no broadcast) - if self._task_ledger is not None: - await self._emit_orchestrator_message(context, self._task_ledger, ORCH_MSG_KIND_TASK_LEDGER) + async def _send_plan_review_request(self, ctx: WorkflowContext, is_stalled: bool = False) -> None: + """Send a human intervention request for plan review. - # Start inner loop - await self._run_inner_loop(context) + The response will be handled in the response handler `handle_plan_review_response`. + """ + if self._task_ledger is None: + raise RuntimeError("No task ledger available for plan review request.") + + await ctx.request_info( + MagenticPlanReviewRequest( + plan=self._task_ledger, + current_progress=self._progress_ledger, + is_stalled=is_stalled, + ), + MagenticPlanReviewResponse, + ) async def _run_inner_loop( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: """Run the inner orchestration loop. Coordination phase. Serialized with a lock.""" - if self._context is None or self._task_ledger is None: + if self._magentic_context is None or self._task_ledger is None: raise RuntimeError("Context or task ledger not initialized") - await self._run_inner_loop_helper(context) + await self._run_inner_loop_helper(ctx) async def _run_inner_loop_helper( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: """Run inner loop with exclusive access.""" # Narrow optional context for the remainder of this method - ctx = self._context - if ctx is None: + if self._magentic_context is None: raise RuntimeError("Context not initialized") # Check limits first - within_limits = await self._check_within_limits_or_complete(context) + within_limits = await self._check_within_limits_or_complete( + cast(WorkflowContext[Never, list[ChatMessage]], ctx) + ) if not within_limits: return - ctx.round_count += 1 - logger.info(f"Magentic Orchestrator: Inner loop - round {ctx.round_count}") + self._magentic_context.round_count += 1 + self._increment_round() + logger.debug(f"Magentic Orchestrator: Inner loop - round {self._round_index}") # Create progress ledger using the manager try: - current_progress_ledger = await self._manager.create_progress_ledger(ctx.clone(deep=True)) + self._progress_ledger = await self._manager.create_progress_ledger(self._magentic_context.clone(deep=True)) except Exception as ex: logger.warning(f"Magentic Orchestrator: Progress ledger creation failed, triggering reset: {ex}") - await self._reset_and_replan(context) + await self._reset_and_replan(ctx) return + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.PROGRESS_LEDGER_UPDATED, + data=self._progress_ledger, + ) + ) + logger.debug( - f"Progress evaluation: satisfied={current_progress_ledger.is_request_satisfied.answer}, " - f"next={current_progress_ledger.next_speaker.answer}" + f"Progress evaluation: satisfied={self._progress_ledger.is_request_satisfied.answer}, " + f"next={self._progress_ledger.next_speaker.answer}" ) # Check for task completion - if current_progress_ledger.is_request_satisfied.answer: + if self._progress_ledger.is_request_satisfied.answer: logger.info("Magentic 
Orchestrator: Task completed") - await self._prepare_final_answer(context) + await self._prepare_final_answer(cast(WorkflowContext[Never, list[ChatMessage]], ctx)) return # Check for stalling or looping - if not current_progress_ledger.is_progress_being_made.answer or current_progress_ledger.is_in_loop.answer: - ctx.stall_count += 1 + if not self._progress_ledger.is_progress_being_made.answer or self._progress_ledger.is_in_loop.answer: + self._magentic_context.stall_count += 1 else: - ctx.stall_count = max(0, ctx.stall_count - 1) - - if ctx.stall_count > self._manager.max_stall_count: - logger.info(f"Magentic Orchestrator: Stalling detected after {ctx.stall_count} rounds") - if self._enable_stall_intervention: - # Request human intervention instead of auto-replan - is_progress = current_progress_ledger.is_progress_being_made.answer - is_loop = current_progress_ledger.is_in_loop.answer - stall_reason = "No progress being made" if not is_progress else "" - if is_loop: - loop_msg = "Agents appear to be in a loop" - stall_reason = f"{stall_reason}; {loop_msg}" if stall_reason else loop_msg - next_speaker_val = current_progress_ledger.next_speaker.answer - last_agent = next_speaker_val if isinstance(next_speaker_val, str) else "" - # Get facts and plan from manager's task ledger - mgr_ledger = getattr(self._manager, "task_ledger", None) - facts_text = mgr_ledger.facts.text if mgr_ledger else "" - plan_text = mgr_ledger.plan.text if mgr_ledger else "" - request = _MagenticHumanInterventionRequest( - kind=MagenticHumanInterventionKind.STALL, - stall_count=ctx.stall_count, - max_stall_count=self._manager.max_stall_count, - task_text=ctx.task.text if ctx.task else "", - facts_text=facts_text, - plan_text=plan_text, - last_agent=last_agent, - stall_reason=stall_reason, - ) - await context.request_info(request, _MagenticHumanInterventionReply) - return - # Default behavior: auto-replan - await self._reset_and_replan(context) + self._magentic_context.stall_count = max(0, self._magentic_context.stall_count - 1) + + if self._magentic_context.stall_count > self._manager.max_stall_count: + logger.debug(f"Magentic Orchestrator: Stalling detected after {self._magentic_context.stall_count} rounds") + await self._reset_and_replan(ctx) return # Determine the next speaker and instruction - answer_val = current_progress_ledger.next_speaker.answer - if not isinstance(answer_val, str): + next_speaker = self._progress_ledger.next_speaker.answer + if not isinstance(next_speaker, str): # Fallback to first participant if ledger returns non-string logger.warning("Next speaker answer was not a string; selecting first participant as fallback") - answer_val = next(iter(self._participants.keys())) - next_speaker_value: str = answer_val - instruction = current_progress_ledger.instruction_or_question.answer + next_speaker = next(iter(self._participant_registry.participants.keys())) + instruction = self._progress_ledger.instruction_or_question.answer - if next_speaker_value not in self._participants: - logger.warning(f"Invalid next speaker: {next_speaker_value}") - await self._prepare_final_answer(context) + if next_speaker not in self._participant_registry.participants: + logger.warning(f"Invalid next speaker: {next_speaker}") + await self._prepare_final_answer(cast(WorkflowContext[Never, list[ChatMessage]], ctx)) return # Add instruction to conversation (assistant guidance) @@ -1585,505 +1123,232 @@ async def _run_inner_loop_helper( text=str(instruction), author_name=MAGENTIC_MANAGER_NAME, ) - 
ctx.chat_history.append(instruction_msg) - await self._emit_orchestrator_message(context, instruction_msg, ORCH_MSG_KIND_INSTRUCTION) - - # Determine the selected agent's executor id - target_executor_id = f"agent_{next_speaker_value}" + self._magentic_context.chat_history.append(instruction_msg) # Request specific agent to respond - logger.debug(f"Magentic Orchestrator: Requesting {next_speaker_value} to respond") - await context.send_message( - _MagenticRequestMessage( - agent_name=next_speaker_value, - instruction=str(instruction), - task_context=ctx.task.text, - ), - target_id=target_executor_id, + logger.debug(f"Magentic Orchestrator: Requesting {next_speaker} to respond") + await self._send_request_to_participant( + next_speaker, + cast(WorkflowContext[AgentExecutorRequest | GroupChatRequestMessage], ctx), + additional_instruction=str(instruction), ) async def _reset_and_replan( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: """Reset context and replan.""" - if self._context is None: - return + if self._magentic_context is None: + raise RuntimeError("Context not initialized") - logger.info("Magentic Orchestrator: Resetting and replanning") + logger.debug("Magentic Orchestrator: Resetting and replanning") # Reset context - self._context.reset() + self._magentic_context.reset() - # Replan - self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) - self._context.chat_history.append(self._task_ledger) - await self._emit_orchestrator_message(context, self._task_ledger, ORCH_MSG_KIND_TASK_LEDGER) + # Reset all participant states + await self._reset_participants(cast(WorkflowContext[MagenticResetSignal], ctx)) - # Internally reset all registered agent executors (no handler/messages involved) - for agent in self._agent_executors.values(): - with contextlib.suppress(Exception): - agent.reset() + # Replan + self._task_ledger = await self._manager.replan(self._magentic_context.clone(deep=True)) + await ctx.add_event( + MagenticOrchestratorEvent( + executor_id=self.id, + event_type=MagenticOrchestratorEventType.REPLANNED, + data=self._task_ledger, + ) + ) + # If a human must sign off, ask now and return. The response handler will resume. 
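+        # The response handler will resume the outer loop once the human replies.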
+ if self._require_plan_signoff: + await self._send_plan_review_request(cast(WorkflowContext, ctx), is_stalled=True) + return + + self._magentic_context.chat_history.append(self._task_ledger) # Restart outer loop - await self._run_outer_loop(context) + await self._run_outer_loop(ctx) - async def _prepare_final_answer( + async def _run_outer_loop( self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], + ctx: WorkflowContext[GroupChatWorkflowContext_T_Out, list[ChatMessage]], ) -> None: + """Run the outer orchestration loop - planning phase.""" + if self._magentic_context is None: + raise RuntimeError("Context not initialized") + + logger.debug("Magentic Orchestrator: Outer loop - entering inner loop") + + # Add task ledger to history if not already there + if self._task_ledger and ( + not self._magentic_context.chat_history or self._magentic_context.chat_history[-1] != self._task_ledger + ): + self._magentic_context.chat_history.append(self._task_ledger) + + # Start inner loop + await self._run_inner_loop(ctx) + + async def _prepare_final_answer(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> None: """Prepare the final answer using the manager.""" - if self._context is None: - return + if self._magentic_context is None: + raise RuntimeError("Context not initialized") logger.info("Magentic Orchestrator: Preparing final answer") - final_answer = await self._manager.prepare_final_answer(self._context.clone(deep=True)) + final_answer = await self._manager.prepare_final_answer(self._magentic_context.clone(deep=True)) # Emit a completed event for the workflow - await context.yield_output([final_answer]) + await ctx.yield_output([final_answer]) - async def _check_within_limits_or_complete( - self, - context: WorkflowContext[_MagenticResponseMessage | _MagenticRequestMessage, list[ChatMessage]], - ) -> bool: - """Check if orchestrator is within operational limits.""" - if self._context is None: - return False - ctx = self._context + self._terminated = True - hit_round_limit = self._manager.max_round_count is not None and ctx.round_count >= self._manager.max_round_count - hit_reset_limit = self._manager.max_reset_count is not None and ctx.reset_count >= self._manager.max_reset_count + async def _check_within_limits_or_complete(self, ctx: WorkflowContext[Never, list[ChatMessage]]) -> bool: + """Check if orchestrator is within operational limits. - if hit_round_limit or hit_reset_limit: - limit_type = "round" if hit_round_limit else "reset" - logger.error(f"Magentic Orchestrator: Max {limit_type} count reached") + If limits are exceeded, yield a termination message and mark the workflow as terminated. - # Only emit completion once and then mark terminated - if not self._terminated: - self._terminated = True - # Get partial result - partial_result = _first_assistant(ctx.chat_history) - if partial_result is None: - partial_result = ChatMessage( - role=Role.ASSISTANT, - text=f"Stopped due to {limit_type} limit. No partial result available.", - author_name=MAGENTIC_MANAGER_NAME, - ) - - # Yield the partial result and signal completion - await context.yield_output([partial_result]) - return False + Args: + ctx: The workflow context. - return True + Returns: + True if within limits, False if limits exceeded and workflow is terminated. 
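+
+        A sketch of the expected call pattern (mirrors its use in the inner loop):
+
+        .. code-block:: python
+
+            if not await self._check_within_limits_or_complete(ctx):
+                return  # limits exceeded; the final output has already been yielded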
+ """ + if self._magentic_context is None: + raise RuntimeError("Context not initialized") - async def _send_plan_review_request(self, context: WorkflowContext) -> None: - """Send a human intervention request for plan review.""" - # If plan sign-off is disabled (e.g., ran out of review rounds), do nothing - if not self._require_plan_signoff: - return - ledger = getattr(self._manager, "task_ledger", None) - facts_text = ledger.facts.text if ledger else "" - plan_text = ledger.plan.text if ledger else "" - task_text = self._context.task.text if self._context else "" - - req = _MagenticHumanInterventionRequest( - kind=MagenticHumanInterventionKind.PLAN_REVIEW, - task_text=task_text, - facts_text=facts_text, - plan_text=plan_text, - round_index=self._plan_review_round, + hit_round_limit = self._max_rounds is not None and self._round_index >= self._max_rounds + hit_reset_limit = ( + self._manager.max_reset_count is not None + and self._magentic_context.reset_count >= self._manager.max_reset_count ) - await context.request_info(req, _MagenticHumanInterventionReply) - -# region Magentic Executors + if hit_round_limit or hit_reset_limit: + limit_type = "round" if hit_round_limit else "reset" + logger.error(f"Magentic Orchestrator: Max {limit_type} count reached") + # Yield the full conversation with an indication of termination due to limits + await ctx.yield_output([ + *self._magentic_context.chat_history, + ChatMessage( + role=Role.ASSISTANT, + text=f"Workflow terminated due to reaching maximum {limit_type} count.", + author_name=MAGENTIC_MANAGER_NAME, + ), + ]) + self._terminated = True -class MagenticAgentExecutor(Executor): - """Magentic agent executor that wraps an agent for participation in workflows. + return False - Leverages enhanced AgentExecutor with conversation injection hooks for: - - Receiving task ledger broadcasts - - Responding to specific agent requests - - Resetting agent state when needed - - Surfacing tool approval requests (user_input_requests) as HITL events - """ + return True - def __init__( - self, - agent: AgentProtocol | Executor, - agent_id: str, - ) -> None: - super().__init__(f"agent_{agent_id}") - self._agent = agent - self._agent_id = agent_id - self._chat_history: list[ChatMessage] = [] - self._pending_human_input_request: _MagenticHumanInterventionRequest | None = None - self._pending_tool_request: FunctionApprovalRequestContent | None = None - self._current_request_message: _MagenticRequestMessage | None = None + async def _reset_participants(self, ctx: WorkflowContext[MagenticResetSignal]) -> None: + """Reset all participant executors.""" + # Orchestrator is connected to all participants. Sending the message without specifying + # a target will broadcast to all. + await ctx.send_message(MagenticResetSignal()) @override async def on_checkpoint_save(self) -> dict[str, Any]: - """Capture current executor state for checkpointing. 
+ """Capture current orchestrator state for checkpointing.""" + state = await super().on_checkpoint_save() + state["terminated"] = self._terminated - Returns: - Dict containing serialized chat history - """ - from ._conversation_state import encode_chat_messages + if self._magentic_context is not None: + state["magentic_context"] = self._magentic_context.to_dict() + if self._task_ledger is not None: + state["task_ledger"] = _message_to_payload(self._task_ledger) + if self._progress_ledger is not None: + state["progress_ledger"] = self._progress_ledger.to_dict() - return { - "chat_history": encode_chat_messages(self._chat_history), - } + try: + state["manager_state"] = self._manager.on_checkpoint_save() + except Exception as exc: + logger.warning(f"Failed to save manager state for checkpoint: {exc}\nSkipping...") + + return state @override async def on_checkpoint_restore(self, state: dict[str, Any]) -> None: - """Restore executor state from checkpoint. + """Restore executor state from checkpoint.""" + await super().on_checkpoint_restore(state) + self._terminated = state.get("terminated", False) - Args: - state: Checkpoint data dict - """ - from ._conversation_state import decode_chat_messages - - history_payload = state.get("chat_history") - if history_payload: + magentic_context_data = state.get("magentic_context") + if magentic_context_data is not None: try: - self._chat_history = decode_chat_messages(history_payload) - except Exception as exc: # pragma: no cover - logger.warning(f"Agent {self._agent_id}: Failed to restore chat history: {exc}") - self._chat_history = [] - else: - self._chat_history = [] - - @handler - async def handle_response_message( - self, message: _MagenticResponseMessage, context: WorkflowContext[_MagenticResponseMessage] - ) -> None: - """Handle response message (task ledger broadcast).""" - logger.debug(f"Agent {self._agent_id}: Received response message") - - # Check if this message is intended for this agent - if message.target_agent is not None and message.target_agent != self._agent_id and not message.broadcast: - # Message is targeted to a different agent, ignore it - logger.debug(f"Agent {self._agent_id}: Ignoring message targeted to {message.target_agent}") - return - - # Add transfer message if needed - if message.body.role != Role.USER: - transfer_msg = ChatMessage( - role=Role.USER, - text=f"Transferred to {getattr(message.body, 'author_name', 'agent')}", - ) - self._chat_history.append(transfer_msg) - - # Add message to agent's history - self._chat_history.append(message.body) - - def _get_persona_adoption_role(self) -> Role: - """Determine the best role for persona adoption messages. - - Uses SYSTEM role if the agent supports it, otherwise falls back to USER. - """ - # Only BaseAgent-derived agents are assumed to support SYSTEM messages reliably. 
- from agent_framework import BaseAgent as _AF_AgentBase # local import to avoid cycles - - if isinstance(self._agent, _AF_AgentBase) and hasattr(self._agent, "chat_client"): - return Role.SYSTEM - # For other agent types or when we can't determine support, use USER - return Role.USER + self._magentic_context = MagenticContext.from_dict(magentic_context_data) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore Magentic context from checkpoint data") + self._magentic_context = None - @handler - async def handle_request_message( - self, message: _MagenticRequestMessage, context: WorkflowContext[_MagenticResponseMessage, AgentRunResponse] - ) -> None: - """Handle request to respond.""" - if message.agent_name != self._agent_id: - return + task_ledger_data = state.get("task_ledger") + if task_ledger_data is not None: + try: + self._task_ledger = _message_from_payload(task_ledger_data) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore task ledger from checkpoint data") + self._task_ledger = None - logger.info(f"Agent {self._agent_id}: Received request to respond") + progress_ledger_data = state.get("progress_ledger") + if progress_ledger_data is not None: + try: + self._progress_ledger = MagenticProgressLedger.from_dict(progress_ledger_data) + except Exception: # pragma: no cover - defensive + logger.warning("Failed to restore progress ledger from checkpoint data") + self._progress_ledger = None - # Store the original request message for potential continuation after human input - self._current_request_message = message + manager_state = state.get("manager_state") + if manager_state is not None: + try: + self._manager.on_checkpoint_restore(manager_state) + except Exception as exc: + logger.warning(f"Failed to restore manager state from checkpoint: {exc}\nSkipping...") - # Add persona adoption message with appropriate role - persona_role = self._get_persona_adoption_role() - persona_msg = ChatMessage( - role=persona_role, - text=f"Transferred to {self._agent_id}, adopt the persona immediately.", - ) - self._chat_history.append(persona_msg) - # Add the orchestrator's instruction as a USER message so the agent treats it as the prompt - if message.instruction: - self._chat_history.append(ChatMessage(role=Role.USER, text=message.instruction)) - try: - # If the participant is not an invokable BaseAgent, return a no-op response. 
- from agent_framework import BaseAgent as _AF_AgentBase # local import to avoid cycles +# endregion Magentic Orchestrator - if not isinstance(self._agent, _AF_AgentBase): - response: ChatMessage = ChatMessage( - role=Role.ASSISTANT, - text=f"{self._agent_id} is a workflow executor and cannot be invoked directly.", - author_name=self._agent_id, - ) - self._chat_history.append(response) - await self._emit_agent_message_event(context, response) - await context.send_message(_MagenticResponseMessage(body=response)) - else: - # Invoke the agent - agent_response = await self._invoke_agent(context) - if agent_response is None: - # Agent is waiting for human input - don't send response yet - return - self._chat_history.append(agent_response) - # Send response back to orchestrator - await context.send_message(_MagenticResponseMessage(body=agent_response)) - - except Exception as e: - logger.warning(f"Agent {self._agent_id} invoke failed: {e}") - # Fallback response - response = ChatMessage( - role=Role.ASSISTANT, - text=f"Agent {self._agent_id}: Error processing request - {str(e)[:100]}", - ) - self._chat_history.append(response) - await self._emit_agent_message_event(context, response) - await context.send_message(_MagenticResponseMessage(body=response)) +# region Magentic Agent Executor - def reset(self) -> None: - """Reset the internal chat history of the agent (internal operation).""" - logger.debug(f"Agent {self._agent_id}: Resetting chat history") - self._chat_history.clear() - self._pending_human_input_request = None - self._pending_tool_request = None - self._current_request_message = None - @response_handler - async def handle_tool_approval_response( - self, - original_request: _MagenticHumanInterventionRequest, - response: _MagenticHumanInterventionReply, - context: WorkflowContext[_MagenticResponseMessage, AgentRunResponse], - ) -> None: - """Handle human response for tool approval and continue agent execution. +class MagenticAgentExecutor(AgentExecutor): + """Specialized AgentExecutor for Magentic agent participants.""" - When a human provides input in response to a tool approval request, - this handler processes the response based on the decision type: + def __init__(self, agent: AgentProtocol) -> None: + """Initialize a Magentic Agent Executor. - - APPROVE: Execute the tool call with the provided response text - - REJECT: Do not execute the tool, inform the agent of rejection - - GUIDANCE: Execute the tool call with the guidance text as input + This executor wraps an AgentProtocol instance to be used as a participant + in a Magentic One workflow. 
Args: - original_request: The original human intervention request - response: The human's response containing the decision and any text - context: The workflow context - """ - response_text = response.response_text or response.comments or "" - decision = response.decision - logger.info( - f"Agent {original_request.agent_id}: Received tool approval response " - f"(decision={decision.value}): {response_text[:50] if response_text else ''}" - ) - - # Get the pending tool request to extract call_id - pending_tool_request = self._pending_tool_request - self._pending_human_input_request = None - self._pending_tool_request = None - - # Handle REJECT decision - do not execute the tool call - if decision == MagenticHumanInterventionDecision.REJECT: - rejection_reason = response_text or "Tool call rejected by human" - logger.info(f"Agent {self._agent_id}: Tool call rejected: {rejection_reason}") - - if pending_tool_request is not None: - # Create a FunctionResultContent indicating rejection - function_result = FunctionResultContent( - call_id=pending_tool_request.function_call.call_id, - result=f"Tool call was rejected by human reviewer. Reason: {rejection_reason}", - ) - result_msg = ChatMessage( - role=Role.USER, - contents=[function_result], - ) - self._chat_history.append(result_msg) - else: - # Fallback without pending tool request - rejection_msg = ChatMessage( - role=Role.USER, - text=f"Tool call '{original_request.prompt}' was rejected: {rejection_reason}", - author_name="human", - ) - self._chat_history.append(rejection_msg) - - # Re-invoke the agent so it can adapt to the rejection - agent_response = await self._invoke_agent(context) - if agent_response is None: - return - self._chat_history.append(agent_response) - await context.send_message(_MagenticResponseMessage(body=agent_response)) - return + agent: The agent instance to wrap. - # Handle APPROVE and GUIDANCE decisions - execute the tool call - if pending_tool_request is not None: - # Create a FunctionResultContent with the human's response - function_result = FunctionResultContent( - call_id=pending_tool_request.function_call.call_id, - result=response_text, - ) - # Add the function result as a message to continue the conversation - result_msg = ChatMessage( - role=Role.USER, - contents=[function_result], - ) - self._chat_history.append(result_msg) - - # Re-invoke the agent to continue execution - agent_response = await self._invoke_agent(context) - if agent_response is None: - # Agent is waiting for more human input - return - self._chat_history.append(agent_response) - await context.send_message(_MagenticResponseMessage(body=agent_response)) - else: - # Fallback: no pending tool request, just add as text message - logger.warning( - f"Agent {original_request.agent_id}: No pending tool request found for response, " - "using fallback text handling", - ) - human_response_msg = ChatMessage( - role=Role.USER, - text=f"Human response to '{original_request.prompt}': {response_text}", - author_name="human", - ) - self._chat_history.append(human_response_msg) - - # Create a response message indicating human input was received - agent_response_msg = ChatMessage( - role=Role.ASSISTANT, - text=f"Received human input for: {original_request.prompt}. 
Continuing with the task.", - author_name=original_request.agent_id, - ) - self._chat_history.append(agent_response_msg) - await context.send_message(_MagenticResponseMessage(body=agent_response_msg)) - - async def _emit_agent_delta_event( - self, - ctx: WorkflowContext[Any, Any], - update: AgentRunResponseUpdate, - ) -> None: - # Add metadata to identify this as an agent streaming update - props = update.additional_properties - if props is None: - props = {} - update.additional_properties = props - props["magentic_event_type"] = MAGENTIC_EVENT_TYPE_AGENT_DELTA - props["agent_id"] = self._agent_id - - # Emit AgentRunUpdateEvent with the agent response update - await ctx.add_event(AgentRunUpdateEvent(executor_id=self._agent_id, data=update)) - - async def _emit_agent_message_event( - self, - ctx: WorkflowContext[Any, Any], - message: ChatMessage, - ) -> None: - # Agent message completion is already communicated via streaming updates - # No additional event needed - pass - - async def _invoke_agent( - self, - ctx: WorkflowContext[_MagenticResponseMessage, AgentRunResponse], - ) -> ChatMessage | None: - """Invoke the wrapped agent and return a response. - - This method streams the agent's response updates, collects them into an - AgentRunResponse, and handles any human input requests (tool approvals). - - Note: - If multiple user input requests are present in the agent's response, - only the first one is processed. A warning is logged and subsequent - requests are ignored. This is a current limitation of the single-request - pending state architecture. - - Returns: - ChatMessage with the agent's response, or None if waiting for human input. + Notes: Magentic pattern requires a reset operation upon replanning. This executor + extends the base AgentExecutor to handle resets appropriately. In order to handle + resets, the agent threads and other states are reset when requested by the orchestrator. + And because of this, MagenticAgentExecutor does not support custom threads. """ - logger.debug(f"Agent {self._agent_id}: Running with {len(self._chat_history)} messages") - - run_kwargs: dict[str, Any] = await ctx.get_shared_state(WORKFLOW_RUN_KWARGS_KEY) - - updates: list[AgentRunResponseUpdate] = [] - # The wrapped participant is guaranteed to be an BaseAgent when this is called. - agent = cast("AgentProtocol", self._agent) - async for update in agent.run_stream(messages=self._chat_history, **run_kwargs): # type: ignore[attr-defined] - updates.append(update) - await self._emit_agent_delta_event(ctx, update) - - run_result: AgentRunResponse = AgentRunResponse.from_agent_run_response_updates(updates) - - # Handle human input requests (tool approval) - process one at a time - if run_result.user_input_requests: - if len(run_result.user_input_requests) > 1: - logger.warning( - f"Agent {self._agent_id}: Multiple user input requests received " - f"({len(run_result.user_input_requests)}), processing only the first one" - ) - - user_input_request = run_result.user_input_requests[0] + super().__init__(agent) - # Build a prompt from the request based on its type - prompt: str - context_text: str | None = None + @handler + async def handle_magentic_reset(self, signal: MagenticResetSignal, ctx: WorkflowContext) -> None: + """Handle reset signal from the Magentic orchestrator. 
- if isinstance(user_input_request, FunctionApprovalRequestContent): - fn_call = user_input_request.function_call - prompt = f"Approve function call: {fn_call.name}" - if fn_call.arguments: - context_text = f"Arguments: {fn_call.arguments}" - else: - # Fallback for unknown request types - request_type = type(user_input_request).__name__ - prompt = f"Agent {self._agent_id} requires human input ({request_type})" - logger.warning(f"Agent {self._agent_id}: Unrecognized user input request type: {request_type}") - - # Store the original FunctionApprovalRequestContent for later use - self._pending_tool_request = user_input_request - - # Create and send the human intervention request for tool approval - request = _MagenticHumanInterventionRequest( - kind=MagenticHumanInterventionKind.TOOL_APPROVAL, - agent_id=self._agent_id, - prompt=prompt, - context=context_text, - conversation_snapshot=list(self._chat_history[-5:]), - ) - self._pending_human_input_request = request - await ctx.request_info(request, _MagenticHumanInterventionReply) - return None  # Signal that we're waiting for human input + This method resets the internal state of the agent executor, including + any threads or caches, to prepare for a fresh start after replanning. - messages: list[ChatMessage] | None = None - with contextlib.suppress(Exception): - messages = list(run_result.messages)  # type: ignore[assignment] - if messages and len(messages) > 0: - last: ChatMessage = messages[-1] - author = last.author_name or self._agent_id - role: Role = last.role if last.role else Role.ASSISTANT - text = last.text or "" - msg = ChatMessage(role=role, text=text, author_name=author) - await self._emit_agent_message_event(ctx, msg) - return msg - - msg = ChatMessage( - role=Role.ASSISTANT, - text=f"Agent {self._agent_id}: No output produced", - author_name=self._agent_id, - ) - await self._emit_agent_message_event(ctx, msg) - return msg + Args: + signal: The MagenticResetSignal instance. + ctx: The workflow context. + """ + # Message related + self._cache.clear() + self._full_conversation.clear() + # Request info related + self._pending_agent_requests.clear() + self._pending_responses_to_agent.clear() + # Reset threads + self._agent_thread = self._agent.get_new_thread() -# endregion Magentic Executors +# endregion Magentic Agent Executor # region Magentic Workflow Builder @@ -2108,51 +1373,6 @@ class MagenticBuilder: These emit `MagenticHumanInterventionRequest` events that provide structured decision options (APPROVE, REVISE, CONTINUE, REPLAN, GUIDANCE) appropriate for Magentic's planning-based orchestration. - - Usage: - - .. code-block:: python - - from agent_framework import MagenticBuilder, StandardMagenticManager - from azure.ai.projects.aio import AIProjectClient - - # Create manager with LLM client - project_client = AIProjectClient.from_connection_string(...) - chat_client = project_client.inference.get_chat_completions_client() - - # Build Magentic workflow with agents - workflow = ( - MagenticBuilder() - .participants(researcher=research_agent, writer=writing_agent, coder=coding_agent) - .with_standard_manager(chat_client=chat_client, max_round_count=20, max_stall_count=3) - .with_plan_review(enable=True) - .with_checkpointing(checkpoint_storage) - .build() - ) - - # Execute workflow - async for message in workflow.run("Research and write article about AI agents"): - print(message.text) - - With custom manager: - - .. 
code-block:: python - - # Create custom manager subclass - class MyCustomManager(MagenticManagerBase): - async def plan(self, context: MagenticContext) -> ChatMessage: - # Custom planning logic - ... - - - manager = MyCustomManager() - workflow = MagenticBuilder().participants(agent1=agent1, agent2=agent2).with_standard_manager(manager).build() - - See Also: - - :class:`MagenticManagerBase`: Base class for custom managers - - :class:`StandardMagenticManager`: Default LLM-powered manager - - :class:`MagenticContext`: Context object passed to manager methods - - :class:`MagenticEvent`: Base class for workflow events """ def __init__(self) -> None: @@ -2160,33 +1380,29 @@ def __init__(self) -> None: self._manager: MagenticManagerBase | None = None self._enable_plan_review: bool = False self._checkpoint_storage: CheckpointStorage | None = None - self._enable_stall_intervention: bool = False - def participants(self, **participants: AgentProtocol | Executor) -> Self: - """Add participant agents or executors to the Magentic workflow. + def participants(self, participants: Sequence[AgentProtocol | Executor]) -> Self: + """Define participants for this Magentic workflow. - Participants are the agents that will execute tasks under the manager's direction. - Each participant should have distinct capabilities that complement the team. The - manager will select which participant to invoke based on the current plan and - progress state. + Accepts AgentProtocol instances (auto-wrapped as AgentExecutor) or Executor instances. Args: - **participants: Named agents or executors to add to the workflow. Names should - be descriptive of the agent's role (e.g., researcher=research_agent). - Accepts BaseAgent instances or custom Executor implementations. + participants: Sequence of participant definitions. Returns: Self for method chaining - Usage: + Raises: + ValueError: If participants are empty, names are duplicated, or already set. + TypeError: If any participant is not an AgentProtocol or Executor instance. + + Example: .. code-block:: python workflow = ( MagenticBuilder() - .participants( - researcher=research_agent, writer=writing_agent, coder=coding_agent, reviewer=review_agent - ) + .participants([research_agent, writing_agent, coding_agent, review_agent]) .with_standard_manager(agent=manager_agent) .build() ) @@ -2196,7 +1412,33 @@ def participants(self, **participants: AgentProtocol | Executor) -> Self: - Agent descriptions (if available) are extracted and provided to the manager - - Can be called multiple times to add participants incrementally + - Must be called exactly once; a second call raises ValueError """ - self._participants.update(participants) + if self._participants: + raise ValueError("participants have already been set. Call participants(...) at most once.") + + if not participants: + raise ValueError("participants cannot be empty.") + + # Name of the executor mapped to participant instance + named: dict[str, AgentProtocol | Executor] = {} + for participant in participants: + if isinstance(participant, Executor): + identifier = participant.id + elif isinstance(participant, AgentProtocol): + if not participant.name: + raise ValueError("AgentProtocol participants must have a non-empty name.") + identifier = participant.name + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." 
+ ) + + if identifier in named: + raise ValueError(f"Duplicate participant name '{identifier}' detected") + + named[identifier] = participant + + self._participants = named + return self def with_plan_review(self, enable: bool = True) -> "MagenticBuilder": @@ -2249,67 +1491,6 @@ def with_plan_review(self, enable: bool = True) -> "MagenticBuilder": self._enable_plan_review = enable return self - def with_human_input_on_stall(self, enable: bool = True) -> "MagenticBuilder": - """Enable human intervention when the workflow detects a stall. - - When enabled, instead of automatically replanning when the workflow detects - that agents are not making progress or are stuck in a loop, the workflow will - pause and emit a MagenticStallInterventionRequest event. A human can then - decide to continue, trigger replanning, or provide guidance. - - This is useful for: - - Workflows where automatic replanning may not resolve the issue - - Scenarios requiring human judgment about workflow direction - - Debugging stuck workflows with human insight - - Complex tasks where human guidance can help agents get back on track - - When stall detection triggers (based on max_stall_count), instead of calling - _reset_and_replan automatically, the workflow will: - 1. Emit a MagenticHumanInterventionRequest with kind=STALL - 2. Wait for human response via send_responses_streaming - 3. Take action based on the human's decision (continue, replan, or guidance) - - Args: - enable: Whether to enable stall intervention (default True) - - Returns: - Self for method chaining - - Usage: - - .. code-block:: python - - workflow = ( - MagenticBuilder() - .participants(agent1=agent1) - .with_standard_manager(agent=manager_agent, max_stall_count=3) - .with_human_input_on_stall(enable=True) - .build() - ) - - # During execution, handle human intervention requests - async for event in workflow.run_stream("task"): - if isinstance(event, RequestInfoEvent): - if event.request_type is MagenticHumanInterventionRequest: - request = event.data - if request.kind == MagenticHumanInterventionKind.STALL: - print(f"Workflow stalled: {request.stall_reason}") - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.GUIDANCE, - comments="Focus on completing the current step first", - ) - responses = {event.request_id: reply} - async for ev in workflow.send_responses_streaming(responses): - ... - - See Also: - - :class:`MagenticHumanInterventionRequest`: Unified request type - - :class:`MagenticHumanInterventionDecision`: Decision options - - :meth:`with_standard_manager`: Configure max_stall_count for stall detection - """ - self._enable_stall_intervention = enable - return self - def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "MagenticBuilder": """Enable workflow state persistence using the provided checkpoint storage. @@ -2333,7 +1514,7 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Magentic storage = InMemoryCheckpointStorage() workflow = ( MagenticBuilder() - .participants(agent1=agent1) + .participants([agent1]) .with_standard_manager(agent=manager_agent) .with_checkpointing(storage) .build() @@ -2506,6 +1687,36 @@ async def plan(self, context: MagenticContext) -> ChatMessage: ) return self + def _resolve_orchestrator(self, participants: Sequence[Executor]) -> Executor: + """Determine the orchestrator to use for the workflow. 
+ + Args: + participants: List of resolved participant executors + """ + if self._manager is None: + raise ValueError("No manager configured. Call with_standard_manager(...) before building the orchestrator.") + + return MagenticOrchestrator( + manager=self._manager, + participant_registry=ParticipantRegistry(participants), + require_plan_signoff=self._enable_plan_review, + ) + + def _resolve_participants(self) -> list[Executor]: + """Resolve participant instances into Executor objects.""" + executors: list[Executor] = [] + for participant in self._participants.values(): + if isinstance(participant, Executor): + executors.append(participant) + elif isinstance(participant, AgentProtocol): + executors.append(MagenticAgentExecutor(participant)) + else: + raise TypeError( + f"Participants must be AgentProtocol or Executor instances. Got {type(participant).__name__}." + ) + + return executors + def build(self) -> Workflow: """Build a Magentic workflow with the orchestrator and all agent executors.""" if not self._participants: @@ -2516,305 +1727,19 @@ def build(self) -> Workflow: logger.info(f"Building Magentic workflow with {len(self._participants)} participants") - # Create participant descriptions - participant_descriptions: dict[str, str] = {} - for name, participant in self._participants.items(): - fallback = f"Executor {name}" if isinstance(participant, Executor) else f"Agent {name}" - participant_descriptions[name] = participant_description(participant, fallback) - - # Type narrowing: we already checked self._manager is not None above - manager: MagenticManagerBase = self._manager # type: ignore[assignment] - enable_stall_intervention = self._enable_stall_intervention - - def _orchestrator_factory(wiring: _GroupChatConfig) -> Executor: - return MagenticOrchestratorExecutor( - manager=manager, - participants=participant_descriptions, - require_plan_signoff=self._enable_plan_review, - enable_stall_intervention=enable_stall_intervention, - executor_id="magentic_orchestrator", - ) - - def _participant_factory( - spec: GroupChatParticipantSpec, - wiring: _GroupChatConfig, - ) -> _GroupChatParticipantPipeline: - agent_executor = MagenticAgentExecutor( - spec.participant, - spec.name, - ) - orchestrator = wiring.orchestrator - if isinstance(orchestrator, MagenticOrchestratorExecutor): - orchestrator.register_agent_executor(spec.name, agent_executor) - return (agent_executor,) - - # Magentic provides its own orchestrator via custom factory, so no manager is needed - group_builder = GroupChatBuilder( - _orchestrator_factory=group_chat_orchestrator(_orchestrator_factory), - _participant_factory=_participant_factory, - ).participants(self._participants) + participants: list[Executor] = self._resolve_participants() + orchestrator: Executor = self._resolve_orchestrator(participants) + # Build workflow graph + workflow_builder = WorkflowBuilder().set_start_executor(orchestrator) + for participant in participants: + # Orchestrator and participant bi-directional edges + workflow_builder = workflow_builder.add_edge(orchestrator, participant) + workflow_builder = workflow_builder.add_edge(participant, orchestrator) if self._checkpoint_storage is not None: - group_builder = group_builder.with_checkpointing(self._checkpoint_storage) - - return group_builder.build() - - def start_with_string(self, task: str) -> "MagenticWorkflow": - """Build a Magentic workflow and return a wrapper with convenience methods for string tasks. - - Args: - task: The task description as a string. 
- - Returns: - A MagenticWorkflow wrapper that provides convenience methods for starting with strings. - """ - return MagenticWorkflow(self.build(), task) - - def start_with_message(self, task: ChatMessage) -> "MagenticWorkflow": - """Build a Magentic workflow and return a wrapper with convenience methods for ChatMessage tasks. - - Args: - task: The task as a ChatMessage. - - Returns: - A MagenticWorkflow wrapper that provides convenience methods. - """ - return MagenticWorkflow(self.build(), task.text) - - def start_with(self, task: str | ChatMessage) -> "MagenticWorkflow": - """Build a Magentic workflow and return a wrapper with convenience methods. - - Args: - task: The task description as a string or ChatMessage. + workflow_builder = workflow_builder.with_checkpointing(self._checkpoint_storage) - Returns: - A MagenticWorkflow wrapper that provides convenience methods. - """ - if isinstance(task, str): - return self.start_with_string(task) - return self.start_with_message(task) + return workflow_builder.build() # endregion Magentic Workflow Builder - - -# region Magentic Workflow - - -class MagenticWorkflow: - """A wrapper around the base Workflow that provides convenience methods for Magentic workflows.""" - - def __init__(self, workflow: Workflow, task_text: str | None = None): - self._workflow = workflow - self._task_text = task_text - - @property - def workflow(self) -> Workflow: - """Access the underlying workflow.""" - return self._workflow - - async def run_streaming_with_string(self, task_text: str, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: - """Run the workflow with a task string. - - Args: - task_text: The task description as a string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - These kwargs will be available in @ai_function tools via **kwargs. - - Yields: - WorkflowEvent: The events generated during the workflow execution. - """ - start_message = _MagenticStartMessage.from_string(task_text) - start_message.run_kwargs = kwargs - async for event in self._workflow.run_stream(start_message): - yield event - - async def run_streaming_with_message( - self, task_message: ChatMessage, **kwargs: Any - ) -> AsyncIterable[WorkflowEvent]: - """Run the workflow with a ChatMessage. - - Args: - task_message: The task as a ChatMessage. - **kwargs: Additional keyword arguments to pass through to agent invocations. - These kwargs will be available in @ai_function tools via **kwargs. - - Yields: - WorkflowEvent: The events generated during the workflow execution. - """ - start_message = _MagenticStartMessage(task_message, run_kwargs=kwargs) - async for event in self._workflow.run_stream(start_message): - yield event - - async def run_stream(self, message: Any | None = None, **kwargs: Any) -> AsyncIterable[WorkflowEvent]: - """Run the workflow with either a message object or the preset task string. - - Args: - message: The message to send. If None and task_text was provided during construction, - uses the preset task string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - These kwargs will be available in @ai_function tools via **kwargs. - Example: workflow.run_stream("task", user_id="123", custom_data={...}) - - Yields: - WorkflowEvent: The events generated during the workflow execution. 
- """ - if message is None: - if self._task_text is None: - raise ValueError("No message provided and no preset task text available") - start_message = _MagenticStartMessage.from_string(self._task_text) - elif isinstance(message, str): - start_message = _MagenticStartMessage.from_string(message) - elif isinstance(message, (ChatMessage, list)): - start_message = _MagenticStartMessage(message) # type: ignore[arg-type] - else: - start_message = message - - # Attach kwargs to the start message - if isinstance(start_message, _MagenticStartMessage): - start_message.run_kwargs = kwargs - - async for event in self._workflow.run_stream(start_message): - yield event - - async def _validate_checkpoint_participants( - self, - checkpoint_id: str, - checkpoint_storage: CheckpointStorage | None = None, - ) -> None: - """Ensure participant roster matches the checkpoint before attempting restoration.""" - orchestrator = next( - ( - executor - for executor in self._workflow.executors.values() - if isinstance(executor, MagenticOrchestratorExecutor) - ), - None, - ) - if orchestrator is None: - return - - expected = getattr(orchestrator, "_participants", None) - if not expected: - return - - checkpoint: WorkflowCheckpoint | None = None - if checkpoint_storage is not None: - try: - checkpoint = await checkpoint_storage.load_checkpoint(checkpoint_id) - except Exception: # pragma: no cover - best effort - checkpoint = None - - if checkpoint is None: - runner_context = getattr(self._workflow, "_runner_context", None) - has_checkpointing = getattr(runner_context, "has_checkpointing", None) - load_checkpoint = getattr(runner_context, "load_checkpoint", None) - try: - if callable(has_checkpointing) and has_checkpointing() and callable(load_checkpoint): - loaded_checkpoint = await load_checkpoint(checkpoint_id) # type: ignore[misc] - if loaded_checkpoint is not None: - checkpoint = cast(WorkflowCheckpoint, loaded_checkpoint) - except Exception: # pragma: no cover - best effort - checkpoint = None - - if checkpoint is None: - return - - # At this point, checkpoint is guaranteed to be WorkflowCheckpoint - executor_states = cast(dict[str, Any], checkpoint.shared_state.get(EXECUTOR_STATE_KEY, {})) - orchestrator_id = getattr(orchestrator, "id", "") - orchestrator_state = cast(Any, executor_states.get(orchestrator_id)) - if orchestrator_state is None: - orchestrator_state = cast(Any, executor_states.get("magentic_orchestrator")) - - if not isinstance(orchestrator_state, dict): - return - - orchestrator_state_dict = cast(dict[str, Any], orchestrator_state) - context_payload = cast(Any, orchestrator_state_dict.get("magentic_context")) - if not isinstance(context_payload, dict): - return - - context_dict = cast(dict[str, Any], context_payload) - restored_participants = cast(Any, context_dict.get("participant_descriptions")) - if not isinstance(restored_participants, dict): - return - - participants_dict = cast(dict[str, str], restored_participants) - restored_names: set[str] = set(participants_dict.keys()) - expected_names = set(expected.keys()) - - if restored_names == expected_names: - return - - missing = ", ".join(sorted(expected_names - restored_names)) or "none" - unexpected = ", ".join(sorted(restored_names - expected_names)) or "none" - raise RuntimeError( - "Magentic checkpoint restore failed: participant names do not match the checkpoint. " - "Ensure MagenticBuilder.participants keys remain stable across runs. " - f"Missing names: {missing}; unexpected names: {unexpected}." 
- ) - - async def run_with_string(self, task_text: str, **kwargs: Any) -> WorkflowRunResult: - """Run the workflow with a task string and return all events. - - Args: - task_text: The task description as a string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - - Returns: - WorkflowRunResult: All events generated during the workflow execution. - """ - events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_string(task_text, **kwargs): - events.append(event) - return WorkflowRunResult(events) - - async def run_with_message(self, task_message: ChatMessage, **kwargs: Any) -> WorkflowRunResult: - """Run the workflow with a ChatMessage and return all events. - - Args: - task_message: The task as a ChatMessage. - **kwargs: Additional keyword arguments to pass through to agent invocations. - - Returns: - WorkflowRunResult: All events generated during the workflow execution. - """ - events: list[WorkflowEvent] = [] - async for event in self.run_streaming_with_message(task_message, **kwargs): - events.append(event) - return WorkflowRunResult(events) - - async def run(self, message: Any | None = None, **kwargs: Any) -> WorkflowRunResult: - """Run the workflow and return all events. - - Args: - message: The message to send. If None and task_text was provided during construction, - uses the preset task string. - **kwargs: Additional keyword arguments to pass through to agent invocations. - - Returns: - WorkflowRunResult: All events generated during the workflow execution. - """ - events: list[WorkflowEvent] = [] - async for event in self.run_stream(message, **kwargs): - events.append(event) - return WorkflowRunResult(events) - - def __getattr__(self, name: str) -> Any: - """Delegate unknown attributes to the underlying workflow.""" - return getattr(self._workflow, name) - - -# endregion Magentic Workflow - -# Public aliases for unified human intervention types -MagenticHumanInterventionRequest = _MagenticHumanInterventionRequest -MagenticHumanInterventionReply = _MagenticHumanInterventionReply - -# Backward compatibility aliases (deprecated) -# Old aliases - point to unified types for compatibility -MagenticHumanInputRequest = _MagenticHumanInterventionRequest # type: ignore -MagenticStallInterventionRequest = _MagenticHumanInterventionRequest # type: ignore -MagenticStallInterventionReply = _MagenticHumanInterventionReply # type: ignore -MagenticStallInterventionDecision = MagenticHumanInterventionDecision # type: ignore diff --git a/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py b/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py index 91c9ec799a..dc1e282a12 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py +++ b/python/packages/core/agent_framework/_workflows/_orchestration_request_info.py @@ -1,37 +1,20 @@ # Copyright (c) Microsoft. All rights reserved. -"""Request info support for high-level builder APIs. - -This module provides a mechanism for pausing workflows to request external input -before agent turns in `SequentialBuilder`, `ConcurrentBuilder`, `GroupChatBuilder`, -and `HandoffBuilder`. - -The design follows the standard `request_info` pattern used throughout the -workflow system, keeping the API consistent and predictable. 
-
-Key components:
-- AgentInputRequest: Request type emitted via RequestInfoEvent for pre-agent steering
-- RequestInfoInterceptor: Internal executor that pauses workflow before agent runs
-"""
-
-import logging
-import uuid
-from dataclasses import dataclass, field
-from typing import Any
+from dataclasses import dataclass

 from .._agents import AgentProtocol
 from .._types import ChatMessage, Role
-from ._agent_executor import AgentExecutorRequest
+from ._agent_executor import AgentExecutor, AgentExecutorRequest, AgentExecutorResponse
+from ._agent_utils import resolve_agent_id
 from ._executor import Executor, handler
 from ._request_info_mixin import response_handler
+from ._workflow import Workflow
+from ._workflow_builder import WorkflowBuilder
 from ._workflow_context import WorkflowContext
-
-logger = logging.getLogger(__name__)
+from ._workflow_executor import WorkflowExecutor


-def resolve_request_info_filter(
-    agents: list[str | AgentProtocol | Executor] | None,
-) -> set[str] | None:
+def resolve_request_info_filter(agents: list[str | AgentProtocol] | None) -> set[str]:
    """Resolve a list of agent/executor references to a set of IDs for filtering.

    Args:
@@ -42,288 +25,122 @@ def resolve_request_info_filter(
-        Set of executor/agent IDs to filter on, or None if no filtering.
+        Set of executor/agent IDs to filter on. An empty set means no filtering.
    """
    if agents is None:
-        return None
+        return set()

    result: set[str] = set()
    for agent in agents:
        if isinstance(agent, str):
            result.add(agent)
-        elif isinstance(agent, Executor):
-            result.add(agent.id)
        elif isinstance(agent, AgentProtocol):
-            name = getattr(agent, "name", None)
-            if name:
-                result.add(name)
-            else:
-                logger.warning("AgentProtocol without name cannot be used for request_info filtering")
+            result.add(resolve_agent_id(agent))
        else:
-            logger.warning(f"Unsupported type for request_info filter: {type(agent).__name__}")
+            raise TypeError(f"Unsupported type for request_info filter: {type(agent).__name__}")

-    return result if result else None
+    return result


 @dataclass
-class AgentInputRequest:
-    """Request for human input before an agent runs in high-level builder workflows.
-
-    Emitted via RequestInfoEvent when a workflow pauses before an agent executes.
-    The response is injected into the conversation as a user message to steer
-    the agent's behavior.
-
-    This is the standard request type used by `.with_request_info()` on
-    SequentialBuilder, ConcurrentBuilder, GroupChatBuilder, and HandoffBuilder.
+class AgentRequestInfoResponse:
+    """Response containing additional information requested from users for agents.

    Attributes:
-        target_agent_id: ID of the agent that is about to run
-        conversation: Current conversation history the agent will receive
-        instruction: Optional instruction from the orchestrator (e.g., manager in GroupChat)
-        metadata: Builder-specific context (stores internal state for resume)
+        messages: Additional messages provided by users. If empty,
+            the agent response is approved as-is.
    """

-    target_agent_id: str | None
-    conversation: list[ChatMessage] = field(default_factory=lambda: [])
-    instruction: str | None = None
-    metadata: dict[str, Any] = field(default_factory=lambda: {})
-
-
-# Keep legacy name as alias for backward compatibility
-AgentResponseReviewRequest = AgentInputRequest
+    messages: list[ChatMessage]

+    @staticmethod
+    def from_messages(messages: list[ChatMessage]) -> "AgentRequestInfoResponse":
+        """Create an AgentRequestInfoResponse from a list of ChatMessages. 
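For illustration, a hedged sketch of the tightened filter semantics (my_agent is an assumed named AgentProtocol instance):

    ids = resolve_request_info_filter(["reviewer", my_agent])
    # -> {"reviewer", "MyAgent"} when my_agent.name == "MyAgent";
    #    resolve_agent_id falls back to my_agent.id if the name is unset.
    assert resolve_request_info_filter(None) == set()  # empty set now means "no filter"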
-DEFAULT_REQUEST_INFO_ID = "request_info_interceptor" - + Args: + messages: List of ChatMessage instances provided by users. -class RequestInfoInterceptor(Executor): - """Internal executor that pauses workflow for human input before agent runs. + Returns: + AgentRequestInfoResponse instance. + """ + return AgentRequestInfoResponse(messages=messages) - This executor is inserted into the workflow graph by builders when - `.with_request_info()` is called. It intercepts AgentExecutorRequest messages - BEFORE the agent runs and pauses the workflow via `ctx.request_info()` with - an AgentInputRequest. + @staticmethod + def from_strings(texts: list[str]) -> "AgentRequestInfoResponse": + """Create an AgentRequestInfoResponse from a list of string messages. - When a response is received, the response handler injects the input - as a user message into the conversation and forwards the request to the agent. + Args: + texts: List of text messages provided by users. - The optional `agent_filter` parameter allows limiting which agents trigger the pause. - If the target agent's ID is not in the filter set, the request is forwarded - without pausing. - """ + Returns: + AgentRequestInfoResponse instance. + """ + return AgentRequestInfoResponse(messages=[ChatMessage(role=Role.USER, text=text) for text in texts]) - def __init__( - self, - executor_id: str | None = None, - agent_filter: set[str] | None = None, - ) -> None: - """Initialize the request info interceptor executor. + @staticmethod + def approve() -> "AgentRequestInfoResponse": + """Create an AgentRequestInfoResponse that approves the original agent response. - Args: - executor_id: ID for this executor. If None, generates a unique ID - using the format "request_info_interceptor-". - agent_filter: Optional set of agent/executor IDs to filter on. - If provided, only requests to these agents trigger a pause. - If None (default), all requests trigger a pause. - """ - if executor_id is None: - executor_id = f"{DEFAULT_REQUEST_INFO_ID}-{uuid.uuid4().hex[:8]}" - super().__init__(executor_id) - self._agent_filter = agent_filter - - def _should_pause_for_agent(self, agent_id: str | None) -> bool: - """Check if we should pause for the given agent ID.""" - if self._agent_filter is None: - return True - if agent_id is None: - return False - # Check both the full ID and any name portion after a prefix - # e.g., "groupchat_agent:writer" should match filter "writer" - if agent_id in self._agent_filter: - return True - # Extract name from prefixed IDs like "groupchat_agent:writer" or "request_info:writer" - if ":" in agent_id: - name_part = agent_id.split(":", 1)[1] - if name_part in self._agent_filter: - return True - return False - - def _extract_agent_name_from_executor_id(self) -> str | None: - """Extract the agent name from this interceptor's executor ID. - - The interceptor ID is typically "request_info:", so we - extract the agent name to determine which agent we're intercepting for. + Returns: + AgentRequestInfoResponse instance with no additional messages. """ - if ":" in self.id: - return self.id.split(":", 1)[1] - return None - - @handler - async def intercept_agent_request( - self, - request: AgentExecutorRequest, - ctx: WorkflowContext[AgentExecutorRequest, Any], - ) -> None: - """Intercept request before agent runs and pause for human input. + return AgentRequestInfoResponse(messages=[]) - Pauses the workflow and emits a RequestInfoEvent with the current - conversation for steering. 
If an agent filter is configured and this
-        agent is not in the filter, the request is forwarded without pausing.
-        Args:
-            request: The request about to be sent to the agent
-            ctx: Workflow context for requesting info
-        """
-        # Determine the target agent from our executor ID
-        target_agent = self._extract_agent_name_from_executor_id()
-
-        # Check if we should pause for this agent
-        if not self._should_pause_for_agent(target_agent):
-            logger.debug(f"Skipping request_info pause for agent {target_agent} (not in filter)")
-            await ctx.send_message(request)
-            return
-
-        conversation = list(request.messages or [])
-
-        input_request = AgentInputRequest(
-            target_agent_id=target_agent,
-            conversation=conversation,
-            instruction=None,  # Could be extended to include manager instruction
-            metadata={"_original_request": request, "_input_type": "AgentExecutorRequest"},
-        )
-        await ctx.request_info(input_request, str)

+class AgentRequestInfoExecutor(Executor):
+    """Executor for gathering request info from users to assist agents."""

    @handler
-    async def intercept_conversation(
+    async def request_info(self, agent_response: AgentExecutorResponse, ctx: WorkflowContext) -> None:
+        """Handle the agent's response and gather additional info from users."""
+        await ctx.request_info(agent_response, AgentRequestInfoResponse)
+
+    @response_handler
+    async def handle_request_info_response(
        self,
-        messages: list[ChatMessage],
-        ctx: WorkflowContext[list[ChatMessage], Any],
+        original_request: AgentExecutorResponse,
+        response: AgentRequestInfoResponse,
+        ctx: WorkflowContext[AgentExecutorRequest, AgentExecutorResponse],
    ) -> None:
-        """Intercept conversation before agent runs (used by SequentialBuilder).
+        """Process the additional info provided by users."""
+        if response.messages:
+            # The user provided additional messages; iterate further on the agent response
+            await ctx.send_message(AgentExecutorRequest(messages=response.messages, should_respond=True))
+        else:
+            # No additional info; approve the original agent response
+            await ctx.yield_output(original_request)

-        SequentialBuilder passes list[ChatMessage] directly to agents. This handler
-        intercepts that flow and pauses for human input.
-        Args:
-            messages: The conversation about to be sent to the agent
-            ctx: Workflow context for requesting info
-        """
-        # Determine the target agent from our executor ID
-        target_agent = self._extract_agent_name_from_executor_id()
-
-        # Check if we should pause for this agent
-        if not self._should_pause_for_agent(target_agent):
-            logger.debug(f"Skipping request_info pause for agent {target_agent} (not in filter)")
-            await ctx.send_message(messages)
-            return
-
-        input_request = AgentInputRequest(
-            target_agent_id=target_agent,
-            conversation=list(messages),
-            instruction=None,
-            metadata={"_original_messages": messages, "_input_type": "list[ChatMessage]"},
-        )
-        await ctx.request_info(input_request, str)

+class AgentApprovalExecutor(WorkflowExecutor):
+    """Executor for enabling scenarios requiring agent approval in an orchestration.

-    @handler
-    async def intercept_concurrent_requests(
-        self,
-        requests: list[AgentExecutorRequest],
-        ctx: WorkflowContext[list[AgentExecutorRequest], Any],
-    ) -> None:
-        """Intercept requests before concurrent agents run.
+    This executor wraps a sub-workflow that contains two executors: an agent executor
+    and a request info executor. 
The agent executor generates the responses,
+    while the request info executor gathers user input to further iterate on the
+    agent's output or sends the final response to downstream executors in the orchestration.
+    """

-    This handler is used by ConcurrentBuilder to get human input before
-    all parallel agents execute.
+    def __init__(self, agent: AgentProtocol) -> None:
+        """Initialize the AgentApprovalExecutor.

        Args:
-            requests: List of requests for all concurrent agents
-            ctx: Workflow context for requesting info
+            agent: The agent used to generate responses.
        """
-        # Combine conversations for display
-        combined_conversation: list[ChatMessage] = []
-        if requests:
-            combined_conversation = list(requests[0].messages or [])
-
-        input_request = AgentInputRequest(
-            target_agent_id=None,  # Multiple agents
-            conversation=combined_conversation,
-            instruction=None,
-            metadata={"_original_requests": requests},
+        super().__init__(workflow=self._build_workflow(agent), id=resolve_agent_id(agent), propagate_request=True)
+        self._description = agent.description
+
+    def _build_workflow(self, agent: AgentProtocol) -> Workflow:
+        """Build the internal workflow for the AgentApprovalExecutor."""
+        agent_executor = AgentExecutor(agent)
+        request_info_executor = AgentRequestInfoExecutor(id="agent_request_info_executor")
+
+        return (
+            WorkflowBuilder()
+            # Create a loop between agent executor and request info executor
+            .add_edge(agent_executor, request_info_executor)
+            .add_edge(request_info_executor, agent_executor)
+            .set_start_executor(agent_executor)
+            .build()
        )
-        await ctx.request_info(input_request, str)

-    @response_handler
-    async def handle_input_response(
-        self,
-        original_request: AgentInputRequest,
-        # TODO(@moonbox3): Extend to support other content types
-        response: str,
-        ctx: WorkflowContext[AgentExecutorRequest | list[ChatMessage], Any],
-    ) -> None:
-        """Handle the human input and forward the modified request to the agent.
-
-        Injects the response as a user message into the conversation
-        and forwards the modified request to the agent.
-
-        Args:
-            original_request: The AgentInputRequest that triggered the pause
-            response: The human input text
-            ctx: Workflow context for continuing the workflow
-
-        TODO: Consider having each orchestration implement its own response handler
-        for more specialized behavior. 
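To make the approval loop concrete, a hedged sketch of wrapping an agent and answering when the run pauses (writer_agent and the steering text are assumptions):

    approval = AgentApprovalExecutor(writer_agent)
    # Internal graph: AgentExecutor <-> AgentRequestInfoExecutor, looping until approved.
    # Because the wrapper passes propagate_request=True, the inner ctx.request_info(...)
    # resurfaces as a RequestInfoEvent on the parent run. Callers answer with either:
    AgentRequestInfoResponse.approve()  # yield the agent's response as-is
    AgentRequestInfoResponse.from_strings(["Make the tone more formal."])  # iterate again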
- """ - human_message = ChatMessage(role=Role.USER, text=response) - - # Handle concurrent case (list of AgentExecutorRequest) - original_requests: list[AgentExecutorRequest] | None = original_request.metadata.get("_original_requests") - if original_requests is not None: - updated_requests: list[AgentExecutorRequest] = [] - for orig_req in original_requests: - messages = list(orig_req.messages or []) - messages.append(human_message) - updated_requests.append( - AgentExecutorRequest( - messages=messages, - should_respond=orig_req.should_respond, - ) - ) - - logger.debug( - f"Human input received for concurrent workflow, " - f"continuing with {len(updated_requests)} updated requests" - ) - await ctx.send_message(updated_requests) # type: ignore[arg-type] - return - - # Handle list[ChatMessage] case (SequentialBuilder) - original_messages: list[ChatMessage] | None = original_request.metadata.get("_original_messages") - if original_messages is not None: - messages = list(original_messages) - messages.append(human_message) - - logger.debug( - f"Human input received for agent {original_request.target_agent_id}, " - f"forwarding conversation with steering context" - ) - await ctx.send_message(messages) - return - - # Handle AgentExecutorRequest case (GroupChatBuilder) - orig_request: AgentExecutorRequest | None = original_request.metadata.get("_original_request") - if orig_request is not None: - messages = list(orig_request.messages or []) - messages.append(human_message) - - updated_request = AgentExecutorRequest( - messages=messages, - should_respond=orig_request.should_respond, - ) - - logger.debug( - f"Human input received for agent {original_request.target_agent_id}, " - f"forwarding request with steering context" - ) - await ctx.send_message(updated_request) - return - - logger.error("Input response handler missing original request/messages in metadata") - raise RuntimeError("Missing original request or messages in AgentInputRequest metadata") + @property + def description(self) -> str | None: + """Get a description of the underlying agent.""" + return self._description diff --git a/python/packages/core/agent_framework/_workflows/_orchestration_state.py b/python/packages/core/agent_framework/_workflows/_orchestration_state.py index 26c0068e7a..8210d7d4bb 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestration_state.py +++ b/python/packages/core/agent_framework/_workflows/_orchestration_state.py @@ -47,6 +47,7 @@ class OrchestrationState: conversation: list[ChatMessage] = field(default_factory=_new_chat_message_list) round_index: int = 0 + orchestrator_name: str = "" metadata: dict[str, Any] = field(default_factory=_new_metadata_dict) task: ChatMessage | None = None diff --git a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py index 14fd68fa46..edcffaa530 100644 --- a/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py +++ b/python/packages/core/agent_framework/_workflows/_orchestrator_helpers.py @@ -7,13 +7,9 @@ """ import logging -from typing import TYPE_CHECKING, Any from .._types import ChatMessage, Role -if TYPE_CHECKING: - from ._group_chat import _GroupChatRequestMessage # type: ignore[reportPrivateUsage] - logger = logging.getLogger(__name__) @@ -99,107 +95,3 @@ def create_completion_message( text=message_text, author_name=author_name, ) - - -def prepare_participant_request( - *, - participant_name: str, - conversation: list[ChatMessage], - 
instruction: str | None = None, - task: ChatMessage | None = None, - metadata: dict[str, Any] | None = None, -) -> "_GroupChatRequestMessage": - """Create a standardized participant request message. - - Simple helper to avoid duplicating request construction. - - Args: - participant_name: Name of the target participant - conversation: Conversation history to send - instruction: Optional instruction from manager/orchestrator - task: Optional task context - metadata: Optional metadata dict - - Returns: - GroupChatRequestMessage ready to send - """ - # Import here to avoid circular dependency - from ._group_chat import _GroupChatRequestMessage # type: ignore[reportPrivateUsage] - - return _GroupChatRequestMessage( - agent_name=participant_name, - conversation=list(conversation), - instruction=instruction or "", - task=task, - metadata=metadata, - ) - - -class ParticipantRegistry: - """Simple registry for tracking participant executor IDs and routing info. - - Provides a clean interface for the common pattern of mapping participant names - to executor IDs and tracking which are agents vs custom executors. - - Tracks both entry IDs (where to send requests) and exit IDs (where responses - come from) to support pipeline configurations where these differ. - """ - - def __init__(self) -> None: - self._participant_entry_ids: dict[str, str] = {} - self._agent_executor_ids: dict[str, str] = {} - self._executor_id_to_participant: dict[str, str] = {} - self._non_agent_participants: set[str] = set() - - def register( - self, - name: str, - *, - entry_id: str, - is_agent: bool, - exit_id: str | None = None, - ) -> None: - """Register a participant's routing information. - - Args: - name: Participant name - entry_id: Executor ID for this participant's entry point (where to send) - is_agent: Whether this is an AgentExecutor (True) or custom Executor (False) - exit_id: Executor ID for this participant's exit point (where responses come from). - If None, defaults to entry_id (single-executor pipeline). 
- """ - self._participant_entry_ids[name] = entry_id - actual_exit_id = exit_id if exit_id is not None else entry_id - - if is_agent: - self._agent_executor_ids[name] = entry_id - # Map both entry and exit IDs to participant name for response routing - self._executor_id_to_participant[entry_id] = name - if actual_exit_id != entry_id: - self._executor_id_to_participant[actual_exit_id] = name - else: - self._non_agent_participants.add(name) - - def get_entry_id(self, name: str) -> str | None: - """Get the entry executor ID for a participant name.""" - return self._participant_entry_ids.get(name) - - def get_participant_name(self, executor_id: str) -> str | None: - """Get the participant name for an executor ID (agents only).""" - return self._executor_id_to_participant.get(executor_id) - - def is_agent(self, name: str) -> bool: - """Check if a participant is an agent (vs custom executor).""" - return name in self._agent_executor_ids - - def is_registered(self, name: str) -> bool: - """Check if a participant is registered.""" - return name in self._participant_entry_ids - - def is_participant_registered(self, name: str) -> bool: - """Check if a participant is registered (alias for is_registered for compatibility).""" - return self.is_registered(name) - - def all_participants(self) -> set[str]: - """Get all registered participant names.""" - return set(self._participant_entry_ids.keys()) diff --git a/python/packages/core/agent_framework/_workflows/_participant_utils.py b/python/packages/core/agent_framework/_workflows/_participant_utils.py deleted file mode 100644 index 2d0a259ac8..0000000000 --- a/python/packages/core/agent_framework/_workflows/_participant_utils.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Shared participant helpers for orchestration builders.""" - -import re -from collections.abc import Callable, Iterable, Mapping -from dataclasses import dataclass -from typing import Any - -from .._agents import AgentProtocol -from ._agent_executor import AgentExecutor -from ._executor import Executor - - -@dataclass -class GroupChatParticipantSpec: - """Metadata describing a single participant in group chat orchestrations. - - Used by multiple orchestration patterns (GroupChat, Handoff, Magentic) to describe - participants with consistent structure across different workflow types. - - Attributes: - name: Unique identifier for the participant used by managers for selection - participant: AgentProtocol or Executor instance representing the participant - description: Human-readable description provided to managers for selection context - """ - - name: str - participant: AgentProtocol | Executor - description: str - - -_SANITIZE_PATTERN = re.compile(r"[^0-9a-zA-Z]+") - - -def sanitize_identifier(value: str, *, default: str = "agent") -> str: - """Return a deterministic, lowercase identifier derived from `value`.""" - cleaned = _SANITIZE_PATTERN.sub("_", value).strip("_") - if not cleaned: - cleaned = default - if cleaned[0].isdigit(): - cleaned = f"{default}_{cleaned}" - return cleaned.lower() - - -def wrap_participant(participant: AgentProtocol | Executor, *, executor_id: str | None = None) -> Executor: - """Represent `participant` as an `Executor`.""" - if isinstance(participant, Executor): - return participant - - if not isinstance(participant, AgentProtocol): - raise TypeError( - f"Participants must implement AgentProtocol or be Executor instances. Got {type(participant).__name__}." 
- ) - - executor_id = executor_id or participant.name or participant.id - return AgentExecutor(participant, id=executor_id) - - -def participant_description(participant: AgentProtocol | Executor, fallback: str) -> str: - """Produce a human-readable description for manager context.""" - if isinstance(participant, Executor): - description = getattr(participant, "description", None) - if isinstance(description, str) and description.strip(): - return description.strip() - return fallback - description = getattr(participant, "description", None) - if isinstance(description, str) and description.strip(): - return description.strip() - return fallback - - -def build_alias_map(participant: AgentProtocol | Executor, executor: Executor) -> dict[str, str]: - """Collect canonical and sanitised aliases that should resolve to `executor`.""" - aliases: dict[str, str] = {} - - def _register(values: Iterable[str | None]) -> None: - for value in values: - if not value: - continue - key = str(value) - if key not in aliases: - aliases[key] = executor.id - sanitized = sanitize_identifier(key) - if sanitized not in aliases: - aliases[sanitized] = executor.id - - _register([executor.id]) - - if isinstance(participant, AgentProtocol): - name = getattr(participant, "name", None) - agent_id = getattr(participant, "id", None) - _register([name, agent_id]) - else: - participant_id = getattr(participant, "id", None) - _register([participant_id]) - - return aliases - - -def merge_alias_maps(maps: Iterable[Mapping[str, str]]) -> dict[str, str]: - """Merge alias mappings, preserving the first occurrence of each alias.""" - merged: dict[str, str] = {} - for mapping in maps: - for key, value in mapping.items(): - merged.setdefault(key, value) - return merged - - -def prepare_participant_metadata( - participants: Mapping[str, AgentProtocol | Executor], - *, - executor_id_factory: Callable[[str, AgentProtocol | Executor], str | None] | None = None, - description_factory: Callable[[str, AgentProtocol | Executor], str] | None = None, -) -> dict[str, dict[str, Any]]: - """Return metadata dicts for participants keyed by participant name.""" - executors: dict[str, Executor] = {} - descriptions: dict[str, str] = {} - alias_maps: list[Mapping[str, str]] = [] - - for name, participant in participants.items(): - desired_id = executor_id_factory(name, participant) if executor_id_factory else None - executor = wrap_participant(participant, executor_id=desired_id) - fallback_description = description_factory(name, participant) if description_factory else executor.id - descriptions[name] = participant_description(participant, fallback_description) - executors[name] = executor - alias_maps.append(build_alias_map(participant, executor)) - - aliases = merge_alias_maps(alias_maps) - return { - "executors": executors, - "descriptions": descriptions, - "aliases": aliases, - } diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 00318d7021..db7bada00d 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -13,6 +13,7 @@ from ._const import INTERNAL_SOURCE_ID from ._events import RequestInfoEvent, WorkflowEvent from ._shared_state import SharedState +from ._typing_utils import is_instance_of logger = logging.getLogger(__name__) @@ -44,7 +45,7 @@ class Message: source_span_ids: list[str] | None = None # Publishing span IDs for linking from 
multiple sources # For response messages, the original request data - original_request: Any = None + original_request_info_event: RequestInfoEvent | None = None # Backward compatibility properties @property @@ -66,7 +67,7 @@ def to_dict(self) -> dict[str, Any]: "type": self.type.value, "trace_contexts": self.trace_contexts, "source_span_ids": self.source_span_ids, - "original_request": self.original_request, + "original_request_info_event": encode_checkpoint_value(self.original_request_info_event), } @staticmethod @@ -86,7 +87,7 @@ def from_dict(data: dict[str, Any]) -> "Message": type=MessageType(data.get("type", "standard")), trace_contexts=data.get("trace_contexts"), source_span_ids=data.get("source_span_ids"), - original_request=data.get("original_request"), + original_request_info_event=decode_checkpoint_value(data.get("original_request_info_event")), ) @@ -493,7 +494,7 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No raise ValueError(f"No pending request found for request_id: {request_id}") # Validate response type if specified - if event.response_type and not isinstance(response, event.response_type): + if event.response_type and not is_instance_of(response, event.response_type): raise TypeError( f"Response type mismatch for request_id {request_id}: " f"expected {event.response_type.__name__}, got {type(response).__name__}" @@ -505,7 +506,7 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No source_id=INTERNAL_SOURCE_ID(event.source_executor_id), target_id=event.source_executor_id, type=MessageType.RESPONSE, - original_request=event.data, + original_request_info_event=event, ) await self.send_message(response_msg) diff --git a/python/packages/core/agent_framework/_workflows/_sequential.py b/python/packages/core/agent_framework/_workflows/_sequential.py index 0c394574fa..11c123d153 100644 --- a/python/packages/core/agent_framework/_workflows/_sequential.py +++ b/python/packages/core/agent_framework/_workflows/_sequential.py @@ -47,13 +47,14 @@ AgentExecutor, AgentExecutorResponse, ) +from ._agent_utils import resolve_agent_id from ._checkpoint import CheckpointStorage from ._executor import ( Executor, handler, ) from ._message_utils import normalize_messages_input -from ._orchestration_request_info import RequestInfoInterceptor +from ._orchestration_request_info import AgentApprovalExecutor from ._workflow import Workflow from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext @@ -77,24 +78,33 @@ async def from_messages(self, messages: list[str | ChatMessage], ctx: WorkflowCo await ctx.send_message(normalize_messages_input(messages)) -class _ResponseToConversation(Executor): - """Converts AgentExecutorResponse to list[ChatMessage] conversation for chaining.""" - - @handler - async def convert(self, response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None: - # Always use full_conversation; AgentExecutor guarantees it is populated. - if response.full_conversation is None: # Defensive: indicates a contract violation - raise RuntimeError("AgentExecutorResponse.full_conversation missing. 
AgentExecutor must populate it.")
-        await ctx.send_message(list(response.full_conversation))
-
-
 class _EndWithConversation(Executor):
    """Terminates the workflow by emitting the final conversation context."""

    @handler
-    async def end(self, conversation: list[ChatMessage], ctx: WorkflowContext[Any, list[ChatMessage]]) -> None:
+    async def end_with_messages(
+        self,
+        conversation: list[ChatMessage],
+        ctx: WorkflowContext[Any, list[ChatMessage]],
+    ) -> None:
+        """Handler for ending with a list of ChatMessage.
+
+        This is used when the last participant is a custom executor.
+        """
        await ctx.yield_output(list(conversation))

+    @handler
+    async def end_with_agent_executor_response(
+        self,
+        response: AgentExecutorResponse,
+        ctx: WorkflowContext[Any, list[ChatMessage] | None],
+    ) -> None:
+        """Handle the case where the last participant is an agent.
+
+        The agent is wrapped by AgentExecutor and emits AgentExecutorResponse.
+        """
+        await ctx.yield_output(response.full_conversation)
+

 class SequentialBuilder:
    r"""High-level builder for sequential agent/executor workflows with shared context.
@@ -206,44 +216,65 @@ def with_checkpointing(self, checkpoint_storage: CheckpointStorage) -> "Sequenti
    def with_request_info(
        self,
        *,
-        agents: Sequence[str | AgentProtocol | Executor] | None = None,
+        agents: Sequence[str | AgentProtocol] | None = None,
    ) -> "SequentialBuilder":
-        """Enable request info before agents run in the workflow.
-
-        When enabled, the workflow pauses before each agent runs, emitting
-        a RequestInfoEvent that allows the caller to review the conversation and
-        optionally inject guidance before the agent responds. The caller provides
-        input via the standard response_handler/request_info pattern.
+        """Enable request info after agent participant responses.

-        Args:
-            agents: Optional filter - only pause before these specific agents/executors.
-                Accepts agent names (str), agent instances, or executor instances.
-                If None (default), pauses before every agent.
-
-        Returns:
-            self: The builder instance for fluent chaining.
+        This enables human-in-the-loop (HIL) scenarios for the sequential orchestration.
+        When enabled, the workflow pauses after each agent participant runs, emitting
+        a RequestInfoEvent that allows the caller to review the conversation and optionally
+        inject guidance for the agent participant to iterate. The caller provides input via
+        the standard response_handler/request_info pattern.

-        Example:
+        Example flow with HIL:
+            Input -> [Agent Participant <-> Request Info] -> [Agent Participant <-> Request Info] -> ...

-        .. code-block:: python
+        Note: This is only available for agent participants. Executor participants can incorporate
+        request info handling in their own implementation if desired.

-            # Pause before all agents
-            workflow = SequentialBuilder().participants([a1, a2]).with_request_info().build()
+        Args:
+            agents: Optional list of agent names or agent instances to enable request info for.
+                If None, enables HIL for all agent participants. 
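A hedged end-to-end sketch of driving this loop (drafter and reviewer are assumed agents, the steering text is illustrative, and a real driver would repeat until no requests remain pending):

    workflow = (
        SequentialBuilder()
        .participants([drafter, reviewer])
        .with_request_info(agents=[reviewer])  # HIL only after the reviewer
        .build()
    )

    responses: dict[str, AgentRequestInfoResponse] = {}
    async for event in workflow.run_stream(["Draft a release note."]):
        if isinstance(event, RequestInfoEvent):
            responses[event.request_id] = AgentRequestInfoResponse.from_strings(
                ["Tighten the second paragraph."]
            )
    if responses:
        result = await workflow.send_responses(responses)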
- # Pause only before specific agents - workflow = ( - SequentialBuilder() - .participants([drafter, reviewer, finalizer]) - .with_request_info(agents=[reviewer]) # Only pause before reviewer - .build() - ) + Returns: + Self for fluent chaining """ from ._orchestration_request_info import resolve_request_info_filter self._request_info_enabled = True self._request_info_filter = resolve_request_info_filter(list(agents) if agents else None) + return self + def _resolve_participants(self) -> list[Executor]: + """Resolve participant instances into Executor objects.""" + participants: list[Executor | AgentProtocol] = [] + if self._participant_factories: + # Resolve the participant factories now. This doesn't break the factory pattern + # since the Sequential builder still creates new instances per workflow build. + for factory in self._participant_factories: + p = factory() + participants.append(p) + else: + participants = self._participants + + executors: list[Executor] = [] + for p in participants: + if isinstance(p, Executor): + executors.append(p) + elif isinstance(p, AgentProtocol): + if self._request_info_enabled and ( + not self._request_info_filter or resolve_agent_id(p) in self._request_info_filter + ): + # Handle request info enabled agents + executors.append(AgentApprovalExecutor(p)) + else: + executors.append(AgentExecutor(p)) + else: + raise TypeError(f"Participants must be AgentProtocol or Executor instances. Got {type(p).__name__}.") + + return executors + def build(self) -> Workflow: """Build and validate the sequential workflow. @@ -272,48 +303,17 @@ def build(self) -> Workflow: input_conv = _InputToConversation(id="input-conversation") end = _EndWithConversation(id="end") + # Resolve participants and participant factories to executors + participants: list[Executor] = self._resolve_participants() + builder = WorkflowBuilder() builder.set_start_executor(input_conv) # Start of the chain is the input normalizer prior: Executor | AgentProtocol = input_conv - - participants: list[Executor | AgentProtocol] = [] - if self._participant_factories: - # Resolve the participant factories now. This doesn't break the factory pattern - # since the Sequential builder still creates new instances per workflow build. 
- for factory in self._participant_factories: - p = factory() - participants.append(p) - else: - participants = self._participants - for p in participants: - if isinstance(p, (AgentProtocol, AgentExecutor)): - label = p.id if isinstance(p, AgentExecutor) else p.name - - if self._request_info_enabled: - # Insert request info interceptor BEFORE the agent - interceptor = RequestInfoInterceptor( - executor_id=f"request_info:{label}", - agent_filter=self._request_info_filter, - ) - builder.add_edge(prior, interceptor) - builder.add_edge(interceptor, p) - else: - builder.add_edge(prior, p) - - resp_to_conv = _ResponseToConversation(id=f"to-conversation:{label}") - builder.add_edge(p, resp_to_conv) - prior = resp_to_conv - elif isinstance(p, Executor): - # Custom executor operates on list[ChatMessage] - # If the executor doesn't handle list[ChatMessage] correctly, validation will fail - builder.add_edge(prior, p) - prior = p - else: - raise TypeError(f"Unsupported participant type: {type(p).__name__}") - + builder.add_edge(prior, p) + prior = p # Terminate with the final conversation builder.add_edge(prior, end) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 7b446926fc..fd1e4f4552 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -5,7 +5,6 @@ import hashlib import json import logging -import sys import uuid from collections.abc import AsyncIterable, Awaitable, Callable from typing import Any @@ -34,12 +33,7 @@ from ._runner import Runner from ._runner_context import RunnerContext from ._shared_state import SharedState - -if sys.version_info >= (3, 11): - pass # pragma: no cover -else: - pass # pragma: no cover - +from ._typing_utils import is_instance_of logger = logging.getLogger(__name__) @@ -734,7 +728,7 @@ async def _send_responses_internal(self, responses: dict[str, Any]) -> None: if request_id not in pending_requests: raise ValueError(f"Response provided for unknown request ID: {request_id}") pending_request = pending_requests[request_id] - if not isinstance(response, pending_request.response_type): + if not is_instance_of(response, pending_request.response_type): raise ValueError( f"Response type mismatch for request ID {request_id}: " f"expected {pending_request.response_type}, got {type(response)}" diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index cffeb02aa0..053221712c 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -269,6 +269,7 @@ def __init__( runner_context: RunnerContext, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, + request_id: str | None = None, ): """Initialize the executor context with the given workflow context. @@ -281,6 +282,7 @@ def __init__( runner_context: The runner context that provides methods to send messages and events. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking (not for nesting). + request_id: Optional request ID if this context is for a `handle_response` handler. 
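A hedged sketch of the correlation this enables, using the request_info override shown just below (the executor, the SKU payload, and the ID scheme are invented for illustration):

    class PriceCheck(Executor):
        @handler
        async def ask(self, sku: str, ctx: WorkflowContext) -> None:
            # Supplying a deterministic ID lets the executor correlate the reply.
            await ctx.request_info(sku, float, request_id=f"price:{sku}")

        @response_handler
        async def on_price(self, original_request: str, response: float, ctx: WorkflowContext) -> None:
            # ctx.request_id carries the ID of the request being answered.
            assert ctx.request_id == f"price:{original_request}"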
""" self._executor = executor self._executor_id = executor.id @@ -298,9 +300,21 @@ def __init__( self._trace_contexts = trace_contexts or [] self._source_span_ids = source_span_ids or [] + # request info related + self._request_id: str | None = request_id + if not self._source_executor_ids: raise ValueError("source_executor_ids cannot be empty. At least one source executor ID is required.") + @property + def request_id(self) -> str | None: + """Get the request ID if this context is for a `handle_response` handler. + + Returns: + The request ID string or None if not applicable. + """ + return self._request_id + async def send_message(self, message: T_Out, target_id: str | None = None) -> None: """Send a message to the workflow context. @@ -361,7 +375,7 @@ async def add_event(self, event: WorkflowEvent) -> None: return await self._runner_context.add_event(event) - async def request_info(self, request_data: object, response_type: type) -> None: + async def request_info(self, request_data: object, response_type: type, *, request_id: str | None = None) -> None: """Request information from outside of the workflow. Calling this method will cause the workflow to emit a RequestInfoEvent, carrying the @@ -374,6 +388,8 @@ async def request_info(self, request_data: object, response_type: type) -> None: Args: request_data: The data associated with the information request. response_type: The expected type of the response, used for validation. + request_id: Optional unique identifier for the request. If not provided, + a new UUID will be generated. This allows executors to track requests and responses. """ request_type: type = type(request_data) if not self._executor.is_request_supported(request_type, response_type): @@ -385,7 +401,7 @@ async def request_info(self, request_data: object, response_type: type) -> None: ) request_info_event = RequestInfoEvent( - request_id=str(uuid.uuid4()), + request_id=request_id or str(uuid.uuid4()), source_executor_id=self._executor_id, request_data=request_data, response_type=response_type, diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index dccd76403b..69f24bcf2c 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -18,10 +18,8 @@ WorkflowFailedEvent, WorkflowRunState, ) -from ._executor import ( - Executor, - handler, -) +from ._executor import Executor, handler +from ._request_info_mixin import response_handler from ._runner_context import Message from ._typing_utils import is_instance_of from ._workflow import WorkflowRunResult @@ -265,7 +263,14 @@ async def handle_subworkflow_request( - Concurrent executions are fully isolated and do not interfere with each other """ - def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = False, **kwargs: Any): + def __init__( + self, + workflow: "Workflow", + id: str, + allow_direct_output: bool = False, + propagate_request: bool = False, + **kwargs: Any, + ): """Initialize the WorkflowExecutor. Args: @@ -277,6 +282,11 @@ def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = Fa When this is set to true, the outputs are yielded directly from the WorkflowExecutor to the parent workflow's event stream. + propagate_request: Whether to propagate requests from the sub-workflow to the + parent workflow. 
If set to true, requests from the sub-workflow + will be propagated as the original RequestInfoEvent to the parent + workflow. Otherwise, they will be wrapped in a SubWorkflowRequestMessage, + which should be handled by an executor in the parent workflow. Keyword Args: **kwargs: Additional keyword arguments passed to the parent constructor. @@ -289,6 +299,7 @@ def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = Fa self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext # Map request_id to execution_id for response routing self._request_to_execution: dict[str, str] = {} # request_id -> execution_id + self._propagate_request = propagate_request @property def input_types(self) -> list[type[Any]]: @@ -336,8 +347,15 @@ def can_handle(self, message: Message) -> bool: This prevents the WorkflowExecutor from accepting messages that should go to other executors because the handler `process_workflow` has no type restrictions. """ - # Always handle SubWorkflowResponseMessage if isinstance(message.data, SubWorkflowResponseMessage): + # Always handle SubWorkflowResponseMessage + return True + + if ( + message.original_request_info_event is not None + and message.original_request_info_event.request_id in self._request_to_execution + ): + # Handle propagated responses for known requests return True # For other messages, only handle if the wrapped workflow can accept them as input @@ -388,7 +406,11 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) del self._execution_contexts[execution_id] @handler - async def handle_response(self, response: SubWorkflowResponseMessage, ctx: WorkflowContext[Any]) -> None: + async def handle_message_wrapped_request_response( + self, + response: SubWorkflowResponseMessage, + ctx: WorkflowContext[Any], + ) -> None: """Handle response from parent for a forwarded request. This handler accumulates responses and only resumes the sub-workflow @@ -398,55 +420,34 @@ async def handle_response(self, response: SubWorkflowResponseMessage, ctx: Workf response: The response to a previous request. ctx: The workflow context. """ - # Find the execution context for this request - original_request = response.source_event - execution_id = self._request_to_execution.get(original_request.request_id) - if not execution_id or execution_id not in self._execution_contexts: - logger.warning( - f"WorkflowExecutor {self.id} received response for unknown request_id: {original_request.request_id}. " - "This response will be ignored." 
- ) - return - - execution_context = self._execution_contexts[execution_id] - - # Check if we have this pending request in the execution context - if original_request.request_id not in execution_context.pending_requests: - logger.warning( - f"WorkflowExecutor {self.id} received response for unknown request_id: " - f"{original_request.request_id} in execution {execution_id}, ignoring" - ) - return - - # Remove the request from pending list and request mapping - execution_context.pending_requests.pop(original_request.request_id, None) - self._request_to_execution.pop(original_request.request_id, None) - - # Accumulate the response in this execution's context - execution_context.collected_responses[original_request.request_id] = response.data - - # Check if we have all expected responses for this execution - if len(execution_context.collected_responses) < execution_context.expected_response_count: - logger.debug( - f"WorkflowExecutor {self.id} execution {execution_id} waiting for more responses: " - f"{len(execution_context.collected_responses)}/{execution_context.expected_response_count} received" - ) - return # Wait for more responses + await self._handle_response( + request_id=response.source_event.request_id, + response=response.data, + ctx=ctx, + ) - # Send all collected responses to the sub-workflow - responses_to_send = dict(execution_context.collected_responses) - execution_context.collected_responses.clear() # Clear for next batch + @response_handler + async def handle_propagated_request_response( + self, + original_request: Any, + response: object, + ctx: WorkflowContext[Any], + ) -> None: + """Handle response for a request that was propagated to the parent workflow. - try: - # Resume the sub-workflow with all collected responses - result = await self.workflow.send_responses(responses_to_send) + Args: + original_request: The original RequestInfoEvent. + response: The response data. + ctx: The workflow context. + """ + if ctx.request_id is None: + raise RuntimeError("WorkflowExecutor received a propagated response without a request ID in the context.") - # Process the workflow result using shared logic - await self._process_workflow_result(result, execution_context, ctx) - finally: - # Clean up execution context if it's completed (no pending requests) - if not execution_context.pending_requests: - del self._execution_contexts[execution_id] + await self._handle_response( + request_id=ctx.request_id, + response=response, + ctx=ctx, + ) @override async def on_checkpoint_save(self) -> dict[str, Any]: @@ -552,13 +553,15 @@ async def _process_workflow_result( execution_context.pending_requests[event.request_id] = event # Map request to execution for response routing self._request_to_execution[event.request_id] = execution_context.execution_id - # TODO(@taochen): There should be two ways a sub-workflow can make a request: - # 1. In a workflow where the parent workflow has an executor that may intercept the - # request and handle it directly, a message should be sent. - # 2. In a workflow where the parent workflow does not handle the request, the request - # should be propagated via the `request_info` mechanism to an external source. And - # a @response_handler would be required in the WorkflowExecutor to handle the response. 
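For contrast, a hedged sketch of the two request-routing modes this enables (inner_entry is an assumed executor):

    inner = WorkflowBuilder().set_start_executor(inner_entry).build()

    # Default (wrapped): inner requests reach the parent graph as
    # SubWorkflowRequestMessage, to be handled by a parent executor.
    wrapped = WorkflowExecutor(inner, id="sub")

    # Propagated: inner requests re-emerge as RequestInfoEvents on the parent
    # run and are answered via workflow.send_responses({request_id: value}).
    hil = WorkflowExecutor(inner, id="sub_hil", propagate_request=True)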
- await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) + if self._propagate_request: + # In a workflow where the parent workflow does not handle the request, the request + # should be propagated via the `request_info` mechanism to an external source. And + # a @response_handler would be required in the WorkflowExecutor to handle the response. + await ctx.request_info(event.data, event.response_type, request_id=event.request_id) + else: + # In a workflow where the parent workflow has an executor that may intercept the + # request and handle it directly, a message should be sent. + await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) # Update expected response count for this execution execution_context.expected_response_count = len(request_info_events) @@ -602,3 +605,56 @@ async def _process_workflow_result( ) else: raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}") + + async def _handle_response( + self, + request_id: str, + response: Any, + ctx: WorkflowContext[Any], + ) -> None: + execution_id = self._request_to_execution.get(request_id) + if not execution_id or execution_id not in self._execution_contexts: + logger.warning( + f"WorkflowExecutor {self.id} received response for unknown request_id: {request_id}. " + "This response will be ignored." + ) + return + + execution_context = self._execution_contexts[execution_id] + + # Check if we have this pending request in the execution context + if request_id not in execution_context.pending_requests: + logger.warning( + f"WorkflowExecutor {self.id} received response for unknown request_id: " + f"{request_id} in execution {execution_id}, ignoring" + ) + return + + # Remove the request from pending list and request mapping + execution_context.pending_requests.pop(request_id, None) + self._request_to_execution.pop(request_id, None) + + # Accumulate the response in this execution's context + execution_context.collected_responses[request_id] = response + # Check if we have all expected responses for this execution + if len(execution_context.collected_responses) < execution_context.expected_response_count: + logger.debug( + f"WorkflowExecutor {self.id} execution {execution_id} waiting for more responses: " + f"{len(execution_context.collected_responses)}/{execution_context.expected_response_count} received" + ) + return # Wait for more responses + + # Send all collected responses to the sub-workflow + responses_to_send = dict(execution_context.collected_responses) + execution_context.collected_responses.clear() # Clear for next batch + + try: + # Resume the sub-workflow with all collected responses + result = await self.workflow.send_responses(responses_to_send) + + # Process the workflow result using shared logic + await self._process_workflow_result(result, execution_context, ctx) + finally: + # Clean up execution context if it's completed (no pending requests) + if not execution_context.pending_requests: + del self._execution_contexts[execution_id] diff --git a/python/packages/core/tests/workflow/test_agent_run_event_typing.py b/python/packages/core/tests/workflow/test_agent_run_event_typing.py index a89aa817a3..271a9bbe2f 100644 --- a/python/packages/core/tests/workflow/test_agent_run_event_typing.py +++ b/python/packages/core/tests/workflow/test_agent_run_event_typing.py @@ -17,14 +17,6 @@ def test_agent_run_event_data_type() -> None: assert data.text == "Hello" -def test_agent_run_event_data_none() -> None: - """Verify AgentRunEvent.data can be 
None.""" - event = AgentRunEvent(executor_id="test") - - data: AgentRunResponse | None = event.data - assert data is None - - def test_agent_run_update_event_data_type() -> None: """Verify AgentRunUpdateEvent.data is typed as AgentRunResponseUpdate | None.""" update = AgentRunResponseUpdate() @@ -33,11 +25,3 @@ def test_agent_run_update_event_data_type() -> None: # This assignment should pass type checking without a cast data: AgentRunResponseUpdate | None = event.data assert data is not None - - -def test_agent_run_update_event_data_none() -> None: - """Verify AgentRunUpdateEvent.data can be None.""" - event = AgentRunUpdateEvent(executor_id="test") - - data: AgentRunResponseUpdate | None = event.data - assert data is None diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py new file mode 100644 index 0000000000..7f80658a09 --- /dev/null +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +from collections.abc import AsyncIterable +from typing import Any + +from agent_framework import AgentRunResponse, AgentRunResponseUpdate, AgentThread, ChatMessage +from agent_framework._workflows._agent_utils import resolve_agent_id + + +class MockAgent: + """Mock agent for testing agent utilities.""" + + def __init__(self, agent_id: str, name: str | None = None) -> None: + self._id = agent_id + self._name = name + + @property + def id(self) -> str: + return self._id + + @property + def name(self) -> str | None: + return self._name + + @property + def display_name(self) -> str: + """Returns the display name of the agent.""" + ... + + @property + def description(self) -> str | None: + """Returns the description of the agent.""" + ... + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: ... + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: ... + + def get_new_thread(self, **kwargs: Any) -> AgentThread: + """Creates a new conversation thread for the agent.""" + ... 
+ + +def test_resolve_agent_id_with_name() -> None: + """Test that resolve_agent_id returns name when agent has a name.""" + agent = MockAgent(agent_id="agent-123", name="MyAgent") + result = resolve_agent_id(agent) + assert result == "MyAgent" + + +def test_resolve_agent_id_without_name() -> None: + """Test that resolve_agent_id returns id when agent has no name.""" + agent = MockAgent(agent_id="agent-456", name=None) + result = resolve_agent_id(agent) + assert result == "agent-456" + + +def test_resolve_agent_id_with_empty_name() -> None: + """Test that resolve_agent_id returns id when agent has empty string name.""" + agent = MockAgent(agent_id="agent-789", name="") + result = resolve_agent_id(agent) + assert result == "agent-789" + + +def test_resolve_agent_id_prefers_name_over_id() -> None: + """Test that resolve_agent_id prefers name over id when both are set.""" + agent = MockAgent(agent_id="agent-abc", name="PreferredName") + result = resolve_agent_id(agent) + assert result == "PreferredName" + assert result != "agent-abc" diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 176c3027c8..a812f6dae6 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -12,6 +12,7 @@ WorkflowContext, executor, handler, + response_handler, ) @@ -266,6 +267,247 @@ async def handle(self, response: Response, ctx: WorkflowContext) -> None: assert collector_invoked.data.results == ["HELLO", "HELLO", "HELLO"] +def test_executor_output_types_property(): + """Test that the output_types property correctly identifies message output types.""" + + # Test executor with no output types + class NoOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext) -> None: + pass + + executor = NoOutputExecutor(id="no_output") + assert executor.output_types == [] + + # Test executor with single output type + class SingleOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + executor = SingleOutputExecutor(id="single_output") + assert int in executor.output_types + assert len(executor.output_types) == 1 + + # Test executor with union output types + class UnionOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int | str]) -> None: + pass + + executor = UnionOutputExecutor(id="union_output") + assert int in executor.output_types + assert str in executor.output_types + assert len(executor.output_types) == 2 + + # Test executor with multiple handlers having different output types + class MultiHandlerExecutor(Executor): + @handler + async def handle_string(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + @handler + async def handle_number(self, num: int, ctx: WorkflowContext[bool]) -> None: + pass + + executor = MultiHandlerExecutor(id="multi_handler") + assert int in executor.output_types + assert bool in executor.output_types + assert len(executor.output_types) == 2 + + +def test_executor_workflow_output_types_property(): + """Test that the workflow_output_types property correctly identifies workflow output types.""" + from typing_extensions import Never + + # Test executor with no workflow output types + class NoWorkflowOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + executor = NoWorkflowOutputExecutor(id="no_workflow_output") + assert 
executor.workflow_output_types == [] + + # Test executor with workflow output type (second type parameter) + class WorkflowOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + executor = WorkflowOutputExecutor(id="workflow_output") + assert str in executor.workflow_output_types + assert len(executor.workflow_output_types) == 1 + + # Test executor with union workflow output types + class UnionWorkflowOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: + pass + + executor = UnionWorkflowOutputExecutor(id="union_workflow_output") + assert str in executor.workflow_output_types + assert bool in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + # Test executor with multiple handlers having different workflow output types + class MultiHandlerWorkflowExecutor(Executor): + @handler + async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + @handler + async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: + pass + + executor = MultiHandlerWorkflowExecutor(id="multi_workflow") + assert str in executor.workflow_output_types + assert float in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + # Test executor with Never for message output (only workflow output) + class YieldOnlyExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[Never, str]) -> None: + pass + + executor = YieldOnlyExecutor(id="yield_only") + assert str in executor.workflow_output_types + assert len(executor.workflow_output_types) == 1 + # Should have no message output types + assert executor.output_types == [] + + +def test_executor_output_and_workflow_output_types_combined(): + """Test executor with both message and workflow output types.""" + + class DualOutputExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> None: + pass + + executor = DualOutputExecutor(id="dual") + + # Should have int as message output type + assert int in executor.output_types + assert len(executor.output_types) == 1 + + # Should have str as workflow output type + assert str in executor.workflow_output_types + assert len(executor.workflow_output_types) == 1 + + # They should be distinct + assert int not in executor.workflow_output_types + assert str not in executor.output_types + + +def test_executor_output_types_includes_response_handlers(): + """Test that output_types includes types from response handlers.""" + from agent_framework import response_handler + + class RequestResponseExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + @response_handler + async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None: + pass + + executor = RequestResponseExecutor(id="request_response") + + # Should include output types from both handler and response_handler + assert int in executor.output_types + assert float in executor.output_types + assert len(executor.output_types) == 2 + + +def test_executor_workflow_output_types_includes_response_handlers(): + """Test that workflow_output_types includes types from response handlers.""" + from agent_framework import response_handler + + class RequestResponseWorkflowExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int, str]) -> 
None: + pass + + @response_handler + async def handle_response( + self, original_request: str, response: bool, ctx: WorkflowContext[float, bool] + ) -> None: + pass + + executor = RequestResponseWorkflowExecutor(id="request_response_workflow") + + # Should include workflow output types from both handler and response_handler + assert str in executor.workflow_output_types + assert bool in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + # Verify message output types are separate + assert int in executor.output_types + assert float in executor.output_types + assert len(executor.output_types) == 2 + + +def test_executor_multiple_response_handlers_output_types(): + """Test that multiple response handlers contribute their output types.""" + + class MultiResponseHandlerExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext[int]) -> None: + pass + + @response_handler + async def handle_string_bool_response( + self, original_request: str, response: bool, ctx: WorkflowContext[float] + ) -> None: + pass + + @response_handler + async def handle_int_bool_response( + self, original_request: int, response: bool, ctx: WorkflowContext[bool] + ) -> None: + pass + + executor = MultiResponseHandlerExecutor(id="multi_response") + + # Should include output types from all handlers and response handlers + assert int in executor.output_types + assert float in executor.output_types + assert bool in executor.output_types + assert len(executor.output_types) == 3 + + +def test_executor_response_handler_union_output_types(): + """Test that response handlers with union output types contribute all types.""" + from agent_framework import response_handler + + class UnionResponseHandlerExecutor(Executor): + @handler + async def handle(self, text: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_response( + self, original_request: str, response: bool, ctx: WorkflowContext[int | str | float, bool | int] + ) -> None: + pass + + executor = UnionResponseHandlerExecutor(id="union_response") + + # Should include all output types from the union + assert int in executor.output_types + assert str in executor.output_types + assert float in executor.output_types + assert len(executor.output_types) == 3 + + # Should include all workflow output types from the union + assert bool in executor.workflow_output_types + assert int in executor.workflow_output_types + assert len(executor.workflow_output_types) == 2 + + async def test_executor_invoked_event_data_not_mutated_by_handler(): """Test that ExecutorInvokedEvent.data captures original input, not mutated input.""" diff --git a/python/packages/core/tests/workflow/test_group_chat.py b/python/packages/core/tests/workflow/test_group_chat.py index b66900dad7..b575fdd684 100644 --- a/python/packages/core/tests/workflow/test_group_chat.py +++ b/python/packages/core/tests/workflow/test_group_chat.py @@ -4,48 +4,33 @@ from typing import Any, cast import pytest -from pydantic import BaseModel from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, + AgentExecutorResponse, + AgentRequestInfoResponse, AgentRunResponse, AgentRunResponseUpdate, - AgentRunUpdateEvent, AgentThread, BaseAgent, + BaseGroupChatOrchestrator, + ChatAgent, ChatMessage, - Executor, + ChatResponse, + ChatResponseUpdate, GroupChatBuilder, - GroupChatDirective, - GroupChatStateSnapshot, - MagenticBuilder, + GroupChatState, MagenticContext, MagenticManagerBase, + 
MagenticProgressLedger, + MagenticProgressLedgerItem, + RequestInfoEvent, Role, TextContent, - Workflow, - WorkflowContext, WorkflowOutputEvent, - handler, + WorkflowRunState, + WorkflowStatusEvent, ) from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage -from agent_framework._workflows._group_chat import ( - GroupChatOrchestratorExecutor, - ManagerSelectionResponse, - _default_orchestrator_factory, # type: ignore - _default_participant_factory, # type: ignore - _GroupChatConfig, # type: ignore - _SpeakerSelectorAdapter, # type: ignore - assemble_group_chat_workflow, -) -from agent_framework._workflows._magentic import ( - _MagenticProgressLedger, # type: ignore - _MagenticProgressLedgerItem, # type: ignore - _MagenticStartMessage, # type: ignore -) -from agent_framework._workflows._participant_utils import GroupChatParticipantSpec -from agent_framework._workflows._workflow_builder import WorkflowBuilder class StubAgent(BaseAgent): @@ -78,9 +63,23 @@ async def _stream() -> AsyncIterable[AgentRunResponseUpdate]: return _stream() -class StubManagerAgent(BaseAgent): +class MockChatClient: + """Mock chat client that raises NotImplementedError for all methods.""" + + @property + def additional_properties(self) -> dict[str, Any]: + return {} + + async def get_response(self, messages: Any, **kwargs: Any) -> ChatResponse: + raise NotImplementedError + + def get_streaming_response(self, messages: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]: + raise NotImplementedError + + +class StubManagerAgent(ChatAgent): def __init__(self) -> None: - super().__init__(name="manager_agent", description="Stub manager") + super().__init__(chat_client=MockChatClient(), name="manager_agent", description="Stub manager") self._call_count = 0 async def run( @@ -89,27 +88,40 @@ async def run( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AgentRunResponse: # type: ignore[override] + ) -> AgentRunResponse: if self._call_count == 0: self._call_count += 1 - payload = {"selected_participant": "agent", "finish": False, "final_message": None} + # First call: select the agent (using AgentOrchestrationOutput format) + payload = {"terminate": False, "reason": "Selecting agent", "next_speaker": "agent", "final_message": None} return AgentRunResponse( messages=[ ChatMessage( role=Role.ASSISTANT, - text='{"selected_participant": "agent", "finish": false}', + text=( + '{"terminate": false, "reason": "Selecting agent", ' + '"next_speaker": "agent", "final_message": null}' + ), author_name=self.name, ) ], value=payload, ) - payload = {"selected_participant": None, "finish": True, "final_message": "agent manager final"} + # Second call: terminate + payload = { + "terminate": True, + "reason": "Task complete", + "next_speaker": None, + "final_message": "agent manager final", + } return AgentRunResponse( messages=[ ChatMessage( role=Role.ASSISTANT, - text='{"finish": true, "final_message": "agent manager final"}', + text=( + '{"terminate": true, "reason": "Task complete", ' + '"next_speaker": null, "final_message": "agent manager final"}' + ), author_name=self.name, ) ], @@ -122,13 +134,20 @@ def run_stream( *, thread: AgentThread | None = None, **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: # type: ignore[override] + ) -> AsyncIterable[AgentRunResponseUpdate]: if self._call_count == 0: self._call_count += 1 async def _stream_initial() -> AsyncIterable[AgentRunResponseUpdate]: yield AgentRunResponseUpdate( - contents=[TextContent(text='{"selected_participant": "agent", 
"finish": false}')], + contents=[ + TextContent( + text=( + '{"terminate": false, "reason": "Selecting agent", ' + '"next_speaker": "agent", "final_message": null}' + ) + ) + ], role=Role.ASSISTANT, author_name=self.name, ) @@ -137,7 +156,14 @@ async def _stream_initial() -> AsyncIterable[AgentRunResponseUpdate]: async def _stream_final() -> AsyncIterable[AgentRunResponseUpdate]: yield AgentRunResponseUpdate( - contents=[TextContent(text='{"finish": true, "final_message": "agent manager final"}')], + contents=[ + TextContent( + text=( + '{"terminate": true, "reason": "Task complete", ' + '"next_speaker": null, "final_message": "agent manager final"}' + ) + ) + ], role=Role.ASSISTANT, author_name=self.name, ) @@ -145,21 +171,20 @@ async def _stream_final() -> AsyncIterable[AgentRunResponseUpdate]: return _stream_final() -def make_sequence_selector() -> Callable[[GroupChatStateSnapshot], Any]: +def make_sequence_selector() -> Callable[[GroupChatState], str]: state_counter = {"value": 0} - async def _selector(state: GroupChatStateSnapshot) -> str | None: - participants = list(state["participants"].keys()) + def _selector(state: GroupChatState) -> str: + participants = list(state.participants.keys()) step = state_counter["value"] + state_counter["value"] = step + 1 if step == 0: - state_counter["value"] = step + 1 return participants[0] if step == 1 and len(participants) > 1: - state_counter["value"] = step + 1 return participants[1] - return None + # Return first participant to continue (will be limited by max_rounds in tests) + return participants[0] - _selector.name = "manager" # type: ignore[attr-defined] return _selector @@ -174,46 +199,30 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return await self.plan(magentic_context) - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: participants = list(magentic_context.participant_descriptions.keys()) target = participants[0] if participants else "agent" if self._round == 0: self._round += 1 - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="", answer=False), - is_in_loop=_MagenticProgressLedgerItem(reason="", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="", answer=target), - instruction_or_question=_MagenticProgressLedgerItem(reason="", answer="respond"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="", answer=False), + is_in_loop=MagenticProgressLedgerItem(reason="", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="", answer=target), + instruction_or_question=MagenticProgressLedgerItem(reason="", answer="respond"), ) - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="", answer=True), - is_in_loop=_MagenticProgressLedgerItem(reason="", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="", answer=target), - instruction_or_question=_MagenticProgressLedgerItem(reason="", answer=""), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="", answer=True), 
+ is_in_loop=MagenticProgressLedgerItem(reason="", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="", answer=target), + instruction_or_question=MagenticProgressLedgerItem(reason="", answer=""), ) async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="final", author_name="magentic_manager") -class PassthroughExecutor(Executor): - @handler - async def forward(self, message: Any, ctx: WorkflowContext[Any]) -> None: - await ctx.send_message(message) - - -class CountingWorkflowBuilder(WorkflowBuilder): - def __init__(self) -> None: - super().__init__() - self.start_calls = 0 - - def set_start_executor(self, executor: Any) -> "CountingWorkflowBuilder": - self.start_calls += 1 - return cast("CountingWorkflowBuilder", super().set_start_executor(executor)) - - async def test_group_chat_builder_basic_flow() -> None: selector = make_sequence_selector() alpha = StubAgent("alpha", "ack from alpha") @@ -221,8 +230,9 @@ async def test_group_chat_builder_basic_flow() -> None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha, beta=beta) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha, beta]) + .with_max_rounds(2) # Limit rounds to prevent infinite loop .build() ) @@ -235,44 +245,9 @@ async def test_group_chat_builder_basic_flow() -> None: assert len(outputs) == 1 assert len(outputs[0]) >= 1 - # The final message should be "done" from the manager - assert outputs[0][-1].text == "done" - assert outputs[0][-1].author_name == "manager" - - -async def test_magentic_builder_returns_workflow_and_runs() -> None: - manager = StubMagenticManager() - agent = StubAgent("writer", "first draft") - - workflow = MagenticBuilder().participants(writer=agent).with_standard_manager(manager=manager).build() - - assert isinstance(workflow, Workflow) - - outputs: list[ChatMessage] = [] - orchestrator_event_count = 0 - agent_event_count = 0 - start_message = _MagenticStartMessage.from_string("compose summary") - async for event in workflow.run_stream(start_message): - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - orchestrator_event_count += 1 - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_event_count += 1 - if isinstance(event, WorkflowOutputEvent): - msg = event.data - if isinstance(msg, list): - outputs.append(cast(list[ChatMessage], msg)) - - assert outputs, "Expected a final output message" - conversation = outputs[-1] - assert len(conversation) >= 1 - final = conversation[-1] - assert final.text == "final" - assert final.author_name == "magentic_manager" - assert orchestrator_event_count > 0, "Expected orchestrator events to be emitted" - assert agent_event_count > 0, "Expected agent delta events to be emitted" + # Check that both agents contributed + authors = {msg.author_name for msg in outputs[0] if msg.author_name in ["alpha", "beta"]} + assert len(authors) == 2 async def test_group_chat_as_agent_accepts_conversation() -> None: @@ -282,8 +257,9 @@ async def test_group_chat_as_agent_accepts_conversation() -> None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", 
final_message="done") - .participants(alpha=alpha, beta=beta) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha, beta]) + .with_max_rounds(2) # Limit rounds to prevent infinite loop .build() ) @@ -297,22 +273,6 @@ async def test_group_chat_as_agent_accepts_conversation() -> None: assert response.messages, "Expected agent conversation output" -async def test_magentic_as_agent_accepts_conversation() -> None: - manager = StubMagenticManager() - writer = StubAgent("writer", "draft") - - workflow = MagenticBuilder().participants(writer=writer).with_standard_manager(manager=manager).build() - - agent = workflow.as_agent(name="magentic-agent") - conversation = [ - ChatMessage(role=Role.SYSTEM, text="Guidelines", author_name="system"), - ChatMessage(role=Role.USER, text="Summarize the findings", author_name="requester"), - ] - response = await agent.run(conversation) - - assert isinstance(response, AgentRunResponse) - - # Comprehensive tests for group chat functionality @@ -325,16 +285,16 @@ def test_build_without_manager_raises_error(self) -> None: builder = GroupChatBuilder().participants([agent]) - with pytest.raises(ValueError, match="manager must be configured before build"): + with pytest.raises(RuntimeError, match="Orchestrator could not be resolved"): builder.build() def test_build_without_participants_raises_error(self) -> None: """Test that building without participants raises ValueError.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) with pytest.raises(ValueError, match="participants must be configured before build"): builder.build() @@ -342,21 +302,21 @@ def selector(state: GroupChatStateSnapshot) -> str | None: def test_duplicate_manager_configuration_raises_error(self) -> None: """Test that configuring multiple managers raises ValueError.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) - with pytest.raises(ValueError, match="already has a manager configured"): - builder.set_select_speakers_func(selector) + with pytest.raises(ValueError, match="select_speakers_func has already been configured"): + builder.with_select_speaker_func(selector) def test_empty_participants_raises_error(self) -> None: """Test that empty participants list raises ValueError.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) with pytest.raises(ValueError, match="participants cannot be empty"): builder.participants([]) @@ -366,10 +326,10 @@ def test_duplicate_participant_names_raises_error(self) -> None: agent1 = StubAgent("test", "response1") agent2 = StubAgent("test", "response2") - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) with pytest.raises(ValueError, match="Duplicate participant name 
'test'"): builder.participants([agent1, agent2]) @@ -394,77 +354,35 @@ async def _stream() -> AsyncIterable[AgentRunResponseUpdate]: agent = AgentWithoutName() - def selector(state: GroupChatStateSnapshot) -> str | None: - return None + def selector(state: GroupChatState) -> str: + return "agent" - builder = GroupChatBuilder().set_select_speakers_func(selector) + builder = GroupChatBuilder().with_select_speaker_func(selector) - with pytest.raises(ValueError, match="must define a non-empty 'name' attribute"): + with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"): builder.participants([agent]) def test_empty_participant_name_raises_error(self) -> None: """Test that empty participant name raises ValueError.""" - agent = StubAgent("test", "response") - - def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - builder = GroupChatBuilder().set_select_speakers_func(selector) - - with pytest.raises(ValueError, match="participant names must be non-empty strings"): - builder.participants({"": agent}) - - def test_assemble_group_chat_respects_existing_start_executor(self) -> None: - """Ensure assemble_group_chat_workflow does not override preconfigured start executor.""" - - async def manager(_: GroupChatStateSnapshot) -> GroupChatDirective: - return GroupChatDirective(finish=True) + agent = StubAgent("", "response") # Agent with empty name - builder = CountingWorkflowBuilder() - entry = PassthroughExecutor(id="entry") - builder = builder.set_start_executor(entry) - - participant = PassthroughExecutor(id="participant") - participant_spec = GroupChatParticipantSpec( - name="participant", - participant=participant, - description="participant", - ) - - wiring = _GroupChatConfig( - manager=manager, - manager_participant=None, - manager_name="manager", - participants={"participant": participant_spec}, - max_rounds=None, - termination_condition=None, - participant_aliases={}, - participant_executors={"participant": participant}, - ) + def selector(state: GroupChatState) -> str: + return "agent" - result = assemble_group_chat_workflow( - wiring=wiring, - participant_factory=_default_participant_factory, - orchestrator_factory=_default_orchestrator_factory, - builder=builder, - return_builder=True, - ) + builder = GroupChatBuilder().with_select_speaker_func(selector) - assert isinstance(result, tuple) - assembled_builder, _ = result - assert assembled_builder is builder - assert builder.start_calls == 1 - assert assembled_builder._start_executor is entry # type: ignore + with pytest.raises(ValueError, match="AgentProtocol participants must have a non-empty name"): + builder.participants([agent]) -class TestGroupChatOrchestrator: - """Tests for GroupChatOrchestratorExecutor core functionality.""" +class TestGroupChatWorkflow: + """Tests for GroupChat workflow functionality.""" async def test_max_rounds_enforcement(self) -> None: """Test that max_rounds properly limits conversation rounds.""" call_count = {"value": 0} - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: call_count["value"] += 1 # Always return the agent name to try to continue indefinitely return "agent" @@ -473,7 +391,7 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_max_rounds(2) # Limit to 2 rounds .build() @@ -492,12 +410,12 @@ def selector(state: GroupChatStateSnapshot) -> 
str | None: conversation = outputs[-1] assert len(conversation) >= 1 final_output = conversation[-1] - assert "round limit" in final_output.text.lower() + assert "maximum number of rounds" in final_output.text.lower() async def test_termination_condition_halts_conversation(self) -> None: """Test that a custom termination condition stops the workflow.""" - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: return "agent" def termination_condition(conversation: list[ChatMessage]) -> bool: @@ -508,7 +426,7 @@ def termination_condition(conversation: list[ChatMessage]) -> bool: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_termination_condition(termination_condition) .build() @@ -526,46 +444,17 @@ def termination_condition(conversation: list[ChatMessage]) -> bool: agent_replies = [msg for msg in conversation if msg.author_name == "agent" and msg.role == Role.ASSISTANT] assert len(agent_replies) == 2 final_output = conversation[-1] - assert final_output.author_name == "manager" + # The orchestrator uses its ID as author_name by default assert "termination condition" in final_output.text.lower() - async def test_termination_condition_uses_manager_final_message(self) -> None: - """Test that manager-provided final message is used on termination.""" - - async def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - agent = StubAgent("agent", "response") - final_text = "manager summary on termination" - - workflow = ( - GroupChatBuilder() - .set_select_speakers_func(selector, final_message=final_text) - .participants([agent]) - .with_termination_condition(lambda _: True) - .build() - ) - - outputs: list[list[ChatMessage]] = [] - async for event in workflow.run_stream("test task"): - if isinstance(event, WorkflowOutputEvent): - data = event.data - if isinstance(data, list): - outputs.append(cast(list[ChatMessage], data)) - - assert outputs, "Expected termination to yield output" - conversation = outputs[-1] - assert conversation[-1].text == final_text - assert conversation[-1].author_name == "manager" - async def test_termination_condition_agent_manager_finalizes(self) -> None: - """Test that agent-based manager can provide final message on termination.""" + """Test that termination condition with agent orchestrator produces default termination message.""" manager = StubManagerAgent() worker = StubAgent("agent", "response") workflow = ( GroupChatBuilder() - .set_manager(manager, display_name="Manager") + .with_agent_orchestrator(manager) .participants([worker]) .with_termination_condition(lambda conv: any(msg.author_name == "agent" for msg in conv)) .build() @@ -580,167 +469,23 @@ async def test_termination_condition_agent_manager_finalizes(self) -> None: assert outputs, "Expected termination to yield output" conversation = outputs[-1] - assert conversation[-1].text == "agent manager final" - assert conversation[-1].author_name == "Manager" + assert conversation[-1].text == BaseGroupChatOrchestrator.TERMINATION_CONDITION_MET_MESSAGE + assert conversation[-1].author_name == manager.name async def test_unknown_participant_error(self) -> None: - """Test that _apply_directive raises error for unknown participants.""" + """Test that unknown participant selection raises error.""" - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: return "unknown_agent" # Return non-existent participant 
agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).build() - with pytest.raises(ValueError, match="Manager selected unknown participant 'unknown_agent'"): + with pytest.raises(RuntimeError, match="Selection function returned unknown participant 'unknown_agent'"): async for _ in workflow.run_stream("test task"): pass - async def test_directive_without_agent_name_raises_error(self) -> None: - """Test that directive without agent_name raises error when finish=False.""" - - def bad_selector(state: GroupChatStateSnapshot) -> GroupChatDirective: - # Return a GroupChatDirective object instead of string to trigger error - return GroupChatDirective(finish=False, agent_name=None) # type: ignore - - agent = StubAgent("agent", "response") - - # The _SpeakerSelectorAdapter will catch this and raise TypeError - workflow = GroupChatBuilder().set_select_speakers_func(bad_selector).participants([agent]).build() # type: ignore - - # This should raise a TypeError because selector doesn't return str or None - with pytest.raises(TypeError, match="must return a participant name \\(str\\) or None"): - async for _ in workflow.run_stream("test"): - pass - - async def test_handle_empty_conversation_raises_error(self) -> None: - """Test that empty conversation list raises ValueError.""" - - def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - agent = StubAgent("agent", "response") - - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() - - with pytest.raises(ValueError, match="requires at least one chat message"): - async for _ in workflow.run_stream([]): - pass - - async def test_unknown_participant_response_raises_error(self) -> None: - """Test that responses from unknown participants raise errors.""" - - def selector(state: GroupChatStateSnapshot) -> str | None: - return "agent" - - # Create orchestrator to test _ingest_participant_message directly - orchestrator = GroupChatOrchestratorExecutor( - manager=selector, # type: ignore - participants={"agent": "test agent"}, - manager_name="test_manager", # type: ignore - ) - - # Mock the workflow context - class MockContext: - async def yield_output(self, message: ChatMessage) -> None: - pass - - ctx = MockContext() - - # Initialize orchestrator state - orchestrator._task_message = ChatMessage(role=Role.USER, text="test") # type: ignore - orchestrator._conversation = [orchestrator._task_message] # type: ignore - orchestrator._history = [] # type: ignore - orchestrator._pending_agent = None # type: ignore - orchestrator._round_index = 0 # type: ignore - - # Test with unknown participant - message = ChatMessage(role=Role.ASSISTANT, text="response") - - with pytest.raises(ValueError, match="Received response from unknown participant 'unknown'"): - await orchestrator._ingest_participant_message("unknown", message, ctx) # type: ignore - - async def test_state_build_before_initialization_raises_error(self) -> None: - """Test that _build_state raises error before task message initialization.""" - - def selector(state: GroupChatStateSnapshot) -> str | None: - return None - - orchestrator = GroupChatOrchestratorExecutor( - manager=selector, # type: ignore - participants={"agent": "test agent"}, - manager_name="test_manager", # type: ignore - ) - - with pytest.raises(RuntimeError, match="state not initialized with task message"): - 
orchestrator._build_state() # type: ignore - - -class TestSpeakerSelectorAdapter: - """Tests for _SpeakerSelectorAdapter functionality.""" - - async def test_selector_returning_list_with_multiple_items_raises_error(self) -> None: - """Test that selector returning list with multiple items raises error.""" - - def bad_selector(state: GroupChatStateSnapshot) -> list[str]: - return ["agent1", "agent2"] # Multiple items - - adapter = _SpeakerSelectorAdapter(bad_selector, manager_name="manager") - - state = { - "participants": {"agent1": "desc1", "agent2": "desc2"}, - "task": ChatMessage(role=Role.USER, text="test"), - "conversation": (), - "history": (), - "round_index": 0, - "pending_agent": None, - } - - with pytest.raises(ValueError, match="must return a single participant name"): - await adapter(state) - - async def test_selector_returning_non_string_raises_error(self) -> None: - """Test that selector returning non-string raises TypeError.""" - - def bad_selector(state: GroupChatStateSnapshot) -> int: - return 42 # Not a string - - adapter = _SpeakerSelectorAdapter(bad_selector, manager_name="manager") - - state = { - "participants": {"agent": "desc"}, - "task": ChatMessage(role=Role.USER, text="test"), - "conversation": (), - "history": (), - "round_index": 0, - "pending_agent": None, - } - - with pytest.raises(TypeError, match="must return a participant name \\(str\\) or None"): - await adapter(state) - - async def test_selector_returning_empty_list_finishes(self) -> None: - """Test that selector returning empty list finishes conversation.""" - - def empty_selector(state: GroupChatStateSnapshot) -> list[str]: - return [] # Empty list should finish - - adapter = _SpeakerSelectorAdapter(empty_selector, manager_name="manager") - - state = { - "participants": {"agent": "desc"}, - "task": ChatMessage(role=Role.USER, text="test"), - "conversation": (), - "history": (), - "round_index": 0, - "pending_agent": None, - } - - directive = await adapter(state) - assert directive.finish is True - assert directive.final_message is not None - class TestCheckpointing: """Tests for checkpointing functionality.""" @@ -748,9 +493,7 @@ class TestCheckpointing: async def test_workflow_with_checkpointing(self) -> None: """Test that workflow works with checkpointing enabled.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - if state["round_index"] >= 1: - return None + def selector(state: GroupChatState) -> str: return "agent" agent = StubAgent("agent", "response") @@ -758,8 +501,9 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) + .with_max_rounds(1) .with_checkpointing(storage) .build() ) @@ -774,89 +518,40 @@ def selector(state: GroupChatStateSnapshot) -> str | None: assert len(outputs) == 1 # Should complete normally -class TestAgentManagerConfiguration: - """Tests for agent-based manager configuration.""" - - async def test_set_manager_configures_response_format(self) -> None: - """Ensure ChatAgent managers receive default ManagerSelectionResponse formatting.""" - from unittest.mock import MagicMock - - from agent_framework import ChatAgent - - chat_client = MagicMock() - manager_agent = ChatAgent(chat_client=chat_client, name="Coordinator") - assert manager_agent.default_options.get("response_format") is None - - worker = StubAgent("worker", "response") - - builder = GroupChatBuilder().set_manager(manager_agent).participants([worker]) - - assert 
manager_agent.default_options.get("response_format") is ManagerSelectionResponse - assert builder._manager_participant is manager_agent # type: ignore[attr-defined] - - async def test_set_manager_accepts_agent_manager(self) -> None: - """Verify agent-based manager can be set and workflow builds.""" - from unittest.mock import MagicMock - - from agent_framework import ChatAgent - - chat_client = MagicMock() - manager_agent = ChatAgent(chat_client=chat_client, name="Coordinator") - worker = StubAgent("worker", "response") - - builder = GroupChatBuilder().set_manager(manager_agent, display_name="Orchestrator") - builder = builder.participants([worker]).with_max_rounds(1) - - assert builder._manager_participant is manager_agent # type: ignore[attr-defined] - assert "worker" in builder._participants # type: ignore[attr-defined] +class TestConversationHandling: + """Tests for different conversation input types.""" - async def test_set_manager_rejects_custom_response_format(self) -> None: - """Reject custom response_format on ChatAgent managers.""" - from unittest.mock import MagicMock + async def test_handle_empty_conversation_raises_error(self) -> None: + """Test that empty conversation list raises ValueError.""" - from agent_framework import ChatAgent + def selector(state: GroupChatState) -> str: + return "agent" - class CustomResponse(BaseModel): - value: str + agent = StubAgent("agent", "response") - chat_client = MagicMock() - manager_agent = ChatAgent( - chat_client=chat_client, name="Coordinator", default_options={"response_format": CustomResponse} + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() ) - worker = StubAgent("worker", "response") - - with pytest.raises(ValueError, match="response_format must be ManagerSelectionResponse"): - GroupChatBuilder().set_manager(manager_agent).participants([worker]) - - assert manager_agent.default_options.get("response_format") is CustomResponse - - -class TestFactoryFunctions: - """Tests for factory functions.""" - - def test_default_orchestrator_factory_without_manager_raises_error(self) -> None: - """Test that default factory requires manager to be set.""" - config = _GroupChatConfig(manager=None, manager_participant=None, manager_name="test", participants={}) - - with pytest.raises(RuntimeError, match="requires a manager to be configured"): - _default_orchestrator_factory(config) - -class TestConversationHandling: - """Tests for different conversation input types.""" + with pytest.raises(ValueError, match="At least one ChatMessage is required to start the group chat workflow."): + async for _ in workflow.run_stream([]): + pass async def test_handle_string_input(self) -> None: """Test handling string input creates proper ChatMessage.""" - def selector(state: GroupChatStateSnapshot) -> str | None: - # Verify the task was properly converted - assert state["task"].role == Role.USER - assert state["task"].text == "test string" - return None + def selector(state: GroupChatState) -> str: + # Verify the conversation has the user message + assert len(state.conversation) > 0 + assert state.conversation[0].role == Role.USER + assert state.conversation[0].text == "test string" + return "agent" agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) outputs: list[list[ChatMessage]] = [] 
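        # Drain the stream, collecting the conversation lists yielded as workflow output.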
async for event in workflow.run_stream("test string"): @@ -871,14 +566,17 @@ async def test_handle_chat_message_input(self) -> None: """Test handling ChatMessage input directly.""" task_message = ChatMessage(role=Role.USER, text="test message") - def selector(state: GroupChatStateSnapshot) -> str | None: - # Verify the task message was preserved - assert state["task"] == task_message - return None + def selector(state: GroupChatState) -> str: + # Verify the task message was preserved in conversation + assert len(state.conversation) > 0 + assert state.conversation[0] == task_message + return "agent" agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) outputs: list[list[ChatMessage]] = [] async for event in workflow.run_stream(task_message): @@ -896,15 +594,17 @@ async def test_handle_conversation_list_input(self) -> None: ChatMessage(role=Role.USER, text="user message"), ] - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: # Verify conversation context is preserved - assert len(state["conversation"]) == 2 - assert state["task"].text == "user message" - return None + assert len(state.conversation) >= 2 + assert state.conversation[-1].text == "user message" + return "agent" agent = StubAgent("agent", "response") - workflow = GroupChatBuilder().set_select_speakers_func(selector).participants([agent]).build() + workflow = ( + GroupChatBuilder().with_select_speaker_func(selector).participants([agent]).with_max_rounds(1).build() + ) outputs: list[list[ChatMessage]] = [] async for event in workflow.run_stream(conversation): @@ -920,10 +620,10 @@ class TestRoundLimitEnforcement: """Tests for round limit checking functionality.""" async def test_round_limit_in_apply_directive(self) -> None: - """Test round limit enforcement in _apply_directive.""" + """Test round limit enforcement.""" rounds_called = {"count": 0} - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: rounds_called["count"] += 1 # Keep trying to select agent to test limit enforcement return "agent" @@ -932,7 +632,7 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_max_rounds(1) # Very low limit .build() @@ -951,13 +651,13 @@ def selector(state: GroupChatStateSnapshot) -> str | None: conversation = outputs[-1] assert len(conversation) >= 1 final_output = conversation[-1] - assert "round limit" in final_output.text.lower() + assert "maximum number of rounds" in final_output.text.lower() async def test_round_limit_in_ingest_participant_message(self) -> None: """Test round limit enforcement after participant response.""" responses_received = {"count": 0} - def selector(state: GroupChatStateSnapshot) -> str | None: + def selector(state: GroupChatState) -> str: responses_received["count"] += 1 if responses_received["count"] == 1: return "agent" # First call selects agent @@ -967,7 +667,7 @@ def selector(state: GroupChatStateSnapshot) -> str | None: workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) .participants([agent]) .with_max_rounds(1) # Hit limit after first response .build() @@ -986,25 +686,29 @@ def selector(state: 
GroupChatStateSnapshot) -> str | None: conversation = outputs[-1] assert len(conversation) >= 1 final_output = conversation[-1] - assert "round limit" in final_output.text.lower() + assert "maximum number of rounds" in final_output.text.lower() async def test_group_chat_checkpoint_runtime_only() -> None: """Test checkpointing configured ONLY at runtime, not at build time.""" - from agent_framework import WorkflowRunState, WorkflowStatusEvent - storage = InMemoryCheckpointStorage() agent_a = StubAgent("agentA", "Reply from A") agent_b = StubAgent("agentB", "Reply from B") selector = make_sequence_selector() - wf = GroupChatBuilder().participants([agent_a, agent_b]).set_select_speakers_func(selector).build() + wf = ( + GroupChatBuilder() + .participants([agent_a, agent_b]) + .with_select_speaker_func(selector) + .with_max_rounds(2) + .build() + ) baseline_output: list[ChatMessage] | None = None async for ev in wf.run_stream("runtime checkpoint test", checkpoint_storage=storage): if isinstance(ev, WorkflowOutputEvent): - baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None + baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None # type: ignore if isinstance(ev, WorkflowStatusEvent) and ev.state in ( WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, @@ -1022,7 +726,6 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: import tempfile with tempfile.TemporaryDirectory() as temp_dir1, tempfile.TemporaryDirectory() as temp_dir2: - from agent_framework import WorkflowRunState, WorkflowStatusEvent from agent_framework._workflows._checkpoint import FileCheckpointStorage buildtime_storage = FileCheckpointStorage(temp_dir1) @@ -1035,15 +738,15 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: wf = ( GroupChatBuilder() .participants([agent_a, agent_b]) - .set_select_speakers_func(selector) + .with_select_speaker_func(selector) + .with_max_rounds(2) .with_checkpointing(buildtime_storage) .build() ) - baseline_output: list[ChatMessage] | None = None async for ev in wf.run_stream("override test", checkpoint_storage=runtime_storage): if isinstance(ev, WorkflowOutputEvent): - baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None + baseline_output = cast(list[ChatMessage], ev.data) if isinstance(ev.data, list) else None # type: ignore if isinstance(ev, WorkflowStatusEvent) and ev.state in ( WorkflowRunState.IDLE, WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, @@ -1059,37 +762,8 @@ async def test_group_chat_checkpoint_runtime_overrides_buildtime() -> None: assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden" -class _StubExecutor(Executor): - """Minimal executor used to satisfy workflow wiring in tests.""" - - def __init__(self, id: str) -> None: - super().__init__(id=id) - - @handler - async def handle(self, message: object, ctx: WorkflowContext[ChatMessage]) -> None: - await ctx.yield_output(message) - - -def test_set_manager_builds_with_agent_manager() -> None: - """GroupChatBuilder should build when using an agent-based manager.""" - - manager = _StubExecutor("manager_executor") - participant = _StubExecutor("participant_executor") - - workflow = ( - GroupChatBuilder().set_manager(manager, display_name="Moderator").participants({"worker": participant}).build() - ) - - orchestrator = workflow.get_start_executor() - - assert isinstance(orchestrator, GroupChatOrchestratorExecutor) - 
assert orchestrator._is_manager_agent() - - async def test_group_chat_with_request_info_filtering(): """Test that with_request_info(agents=[...]) only pauses before specified agents run.""" - from agent_framework import AgentInputRequest, RequestInfoEvent - # Create agents - we want to verify only beta triggers pause alpha = StubAgent("alpha", "response from alpha") beta = StubAgent("beta", "response from beta") @@ -1097,19 +771,21 @@ async def test_group_chat_with_request_info_filtering(): # Manager that selects alpha first, then beta, then finishes call_count = 0 - async def selector(state: GroupChatStateSnapshot) -> str | None: + async def selector(state: GroupChatState) -> str: nonlocal call_count call_count += 1 if call_count == 1: return "alpha" if call_count == 2: return "beta" - return None + # Return to alpha to continue + return "alpha" workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha, beta=beta) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha, beta]) + .with_max_rounds(2) .with_request_info(agents=["beta"]) # Only pause before beta runs .build() ) @@ -1117,7 +793,7 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: # Run until we get a request info event (should be before beta, not alpha) request_events: list[RequestInfoEvent] = [] async for event in workflow.run_stream("test task"): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentInputRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentExecutorResponse): request_events.append(event) # Don't break - let stream complete naturally when paused @@ -1125,13 +801,15 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: assert len(request_events) == 1 request_event = request_events[0] - # The target agent should be beta's executor ID (groupchat_agent:beta) - assert request_event.data.target_agent_id is not None - assert "beta" in request_event.data.target_agent_id + # The target agent should be beta's executor ID + assert isinstance(request_event.data, AgentExecutorResponse) + assert request_event.source_executor_id == "beta" # Continue the workflow with a response outputs: list[WorkflowOutputEvent] = [] - async for event in workflow.send_responses_streaming({request_event.request_id: "continue please"}): + async for event in workflow.send_responses_streaming({ + request_event.request_id: AgentRequestInfoResponse.approve() + }): if isinstance(event, WorkflowOutputEvent): outputs.append(event) @@ -1141,25 +819,25 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: async def test_group_chat_with_request_info_no_filter_pauses_all(): """Test that with_request_info() without agents pauses before all participants.""" - from agent_framework import AgentInputRequest, RequestInfoEvent - # Create agents alpha = StubAgent("alpha", "response from alpha") # Manager selects alpha then finishes call_count = 0 - async def selector(state: GroupChatStateSnapshot) -> str | None: + async def selector(state: GroupChatState) -> str: nonlocal call_count call_count += 1 if call_count == 1: return "alpha" - return None + # Keep returning alpha to continue + return "alpha" workflow = ( GroupChatBuilder() - .set_select_speakers_func(selector, display_name="manager", final_message="done") - .participants(alpha=alpha) + .with_select_speaker_func(selector, orchestrator_name="manager") + .participants([alpha]) + 
.with_max_rounds(1) .with_request_info() # No filter - pause for all .build() ) @@ -1167,14 +845,13 @@ async def selector(state: GroupChatStateSnapshot) -> str | None: # Run until we get a request info event request_events: list[RequestInfoEvent] = [] async for event in workflow.run_stream("test task"): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentInputRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentExecutorResponse): request_events.append(event) break # Should pause before alpha assert len(request_events) == 1 - assert request_events[0].data.target_agent_id is not None - assert "alpha" in request_events[0].data.target_agent_id + assert request_events[0].source_executor_id == "alpha" def test_group_chat_builder_with_request_info_returns_self(): diff --git a/python/packages/core/tests/workflow/test_handoff.py b/python/packages/core/tests/workflow/test_handoff.py index c39cc3ef08..268f89d513 100644 --- a/python/packages/core/tests/workflow/test_handoff.py +++ b/python/packages/core/tests/workflow/test_handoff.py @@ -1,151 +1,78 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, AsyncIterator -from dataclasses import dataclass +from collections.abc import AsyncIterable from typing import Any, cast -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from agent_framework import ( - AgentRunResponse, - AgentRunResponseUpdate, - BaseAgent, ChatAgent, ChatMessage, + ChatResponse, + ChatResponseUpdate, FunctionCallContent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, RequestInfoEvent, Role, TextContent, WorkflowEvent, WorkflowOutputEvent, + resolve_agent_id, + use_function_invocation, ) -from agent_framework._mcp import MCPTool -from agent_framework._workflows import AgentRunEvent -from agent_framework._workflows import _handoff as handoff_module # type: ignore -from agent_framework._workflows._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value -from agent_framework._workflows._handoff import ( - _clone_chat_agent, # type: ignore[reportPrivateUsage] - _ConversationWithUserInput, - _UserInputGateway, -) -from agent_framework._workflows._workflow_builder import WorkflowBuilder - - -class _CountingWorkflowBuilder(WorkflowBuilder): - created: list["_CountingWorkflowBuilder"] = [] - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.start_calls = 0 - _CountingWorkflowBuilder.created.append(self) - - def set_start_executor(self, executor: Any) -> "_CountingWorkflowBuilder": # type: ignore[override] - self.start_calls += 1 - return cast("_CountingWorkflowBuilder", super().set_start_executor(executor)) - - -@dataclass -class _ComplexMetadata: - reason: str - payload: dict[str, str] - - -@pytest.fixture -def complex_metadata() -> _ComplexMetadata: - return _ComplexMetadata(reason="route", payload={"code": "X1"}) - - -def _metadata_from_conversation(conversation: list[ChatMessage], key: str) -> list[object]: - return [msg.additional_properties[key] for msg in conversation if key in msg.additional_properties] -def _conversation_debug(conversation: list[ChatMessage]) -> list[tuple[str, str | None, str]]: - return [ - (msg.role.value if hasattr(msg.role, "value") else str(msg.role), msg.author_name, msg.text) - for msg in conversation - ] +@use_function_invocation +class MockChatClient: + """Mock chat client for testing handoff workflows.""" + 
+    additional_properties: dict[str, Any]
 
-class _RecordingAgent(BaseAgent):
     def __init__(
         self,
-        *,
         name: str,
+        *,
         handoff_to: str | None = None,
-        text_handoff: bool = False,
-        extra_properties: dict[str, object] | None = None,
     ) -> None:
-        super().__init__(id=name, name=name)
-        self._agent_name = name
-        self.handoff_to = handoff_to
-        self.calls: list[list[ChatMessage]] = []
-        self._text_handoff = text_handoff
-        self._extra_properties = dict(extra_properties or {})
+        """Initialize the mock chat client.
+
+        Args:
+            name: The name of the agent using this chat client.
+            handoff_to: The name of the agent to hand off to, or None for no handoff.
+                This is hardcoded for testing purposes so that the agent always attempts to hand off.
+        """
+        self._name = name
+        self._handoff_to = handoff_to
         self._call_index = 0
 
-    async def run(  # type: ignore[override]
-        self,
-        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
-        *,
-        thread: Any = None,
-        **kwargs: Any,
-    ) -> AgentRunResponse:
-        conversation = _normalise(messages)
-        self.calls.append(conversation)
-        additional_properties = _merge_additional_properties(
-            self.handoff_to, self._text_handoff, self._extra_properties
-        )
-        contents = _build_reply_contents(self._agent_name, self.handoff_to, self._text_handoff, self._next_call_id())
+    async def get_response(self, messages: Any, **kwargs: Any) -> ChatResponse:
+        contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id())
         reply = ChatMessage(
             role=Role.ASSISTANT,
             contents=contents,
-            author_name=self.name,
-            additional_properties=additional_properties,
         )
-        return AgentRunResponse(messages=[reply])
+        return ChatResponse(messages=reply, response_id="mock_response")
 
-    async def run_stream(  # type: ignore[override]
-        self,
-        messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None,
-        *,
-        thread: Any = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[AgentRunResponseUpdate]:
-        conversation = _normalise(messages)
-        self.calls.append(conversation)
-        additional_props = _merge_additional_properties(self.handoff_to, self._text_handoff, self._extra_properties)
-        contents = _build_reply_contents(self._agent_name, self.handoff_to, self._text_handoff, self._next_call_id())
-        yield AgentRunResponseUpdate(
-            contents=contents,
-            role=Role.ASSISTANT,
-            additional_properties=additional_props,
-        )
+    def get_streaming_response(self, messages: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]:
+        async def _stream() -> AsyncIterable[ChatResponseUpdate]:
+            contents = _build_reply_contents(self._name, self._handoff_to, self._next_call_id())
+            yield ChatResponseUpdate(contents=contents, role=Role.ASSISTANT)
+
+        return _stream()
 
     def _next_call_id(self) -> str | None:
-        if not self.handoff_to:
+        if not self._handoff_to:
             return None
-        call_id = f"{self.id}-handoff-{self._call_index}"
+        call_id = f"{self._name}-handoff-{self._call_index}"
         self._call_index += 1
         return call_id
 
 
-def _merge_additional_properties(
-    handoff_to: str | None, use_text_hint: bool, extras: dict[str, object]
-) -> dict[str, object]:
-    additional_properties: dict[str, object] = {}
-    if handoff_to and not use_text_hint:
-        additional_properties["handoff_to"] = handoff_to
-    additional_properties.update(extras)
-    return additional_properties
-
-
 def _build_reply_contents(
     agent_name: str,
     handoff_to: str | None,
-    use_text_hint: bool,
     call_id: str | None,
 ) -> list[TextContent | FunctionCallContent]:
     contents: list[TextContent | FunctionCallContent] = []
@@ -154,161 +81,89 @@ def _build_reply_contents(
             FunctionCallContent(call_id=call_id, name=f"handoff_to_{handoff_to}", arguments={"handoff_to": handoff_to})
         )
     text = f"{agent_name} reply"
-    if use_text_hint and handoff_to:
-        text += f"\nHANDOFF_TO: {handoff_to}"
     contents.append(TextContent(text=text))
     return contents
 
 
-def _normalise(messages: str | ChatMessage | list[str] | list[ChatMessage] | None) -> list[ChatMessage]:
-    if isinstance(messages, list):
-        result: list[ChatMessage] = []
-        for msg in messages:
-            if isinstance(msg, ChatMessage):
-                result.append(msg)
-            elif isinstance(msg, str):
-                result.append(ChatMessage(Role.USER, text=msg))
-        return result
-    if isinstance(messages, ChatMessage):
-        return [messages]
-    if isinstance(messages, str):
-        return [ChatMessage(Role.USER, text=messages)]
-    return []
+class MockHandoffAgent(ChatAgent):
+    """Mock agent that can hand off to another agent."""
+
+    def __init__(
+        self,
+        *,
+        name: str,
+        handoff_to: str | None = None,
+    ) -> None:
+        """Initialize the mock handoff agent.
+
+        Args:
+            name: The name of the agent.
+            handoff_to: The name of the agent to hand off to, or None for no handoff.
+                This is hardcoded for testing purposes so that the agent always attempts to hand off.
+        """
+        super().__init__(chat_client=MockChatClient(name, handoff_to=handoff_to), name=name, id=name)
 
 
 async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]:
     return [event async for event in stream]
 
 
-async def test_specialist_to_specialist_handoff():
-    """Test that specialists can hand off to other specialists via .add_handoff() configuration."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist")
-    specialist = _RecordingAgent(name="specialist", handoff_to="escalation")
-    escalation = _RecordingAgent(name="escalation")
+async def test_handoff():
+    """Test that agents can hand off to each other."""
+
+    # `triage` hands off to `specialist`, who then hands off to `escalation`.
+    # `escalation` has no handoff, so the workflow should request user input to continue.
+    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
+    specialist = MockHandoffAgent(name="specialist", handoff_to="escalation")
+    escalation = MockHandoffAgent(name="escalation")
 
+    # Without explicitly defining handoffs, the builder will create connections
+    # between all agents.
     workflow = (
         HandoffBuilder(participants=[triage, specialist, escalation])
-        .set_coordinator(triage)
-        .add_handoff(triage, [specialist, escalation])
-        .add_handoff(specialist, escalation)
+        .with_start_agent(triage)
         .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
         .build()
     )
 
-    # Start conversation - triage hands off to specialist
+    # Start conversation - triage hands off to specialist then escalation
+    # escalation won't trigger a handoff, so the response from it will become
+    # a request for user input because autonomous mode is not enabled by default.
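+    # For reference: in these mocks a handoff is signaled by a FunctionCallContent
+    # named "handoff_to_<target>" (see _build_reply_contents above), e.g.
+    #   FunctionCallContent(call_id="triage-handoff-0", name="handoff_to_specialist",
+    #                       arguments={"handoff_to": "specialist"})
+    # which the workflow presumably resolves to the target agent's executor.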
     events = await _drain(workflow.run_stream("Need technical support"))
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Specialist should have been called
-    assert len(specialist.calls) > 0
-
-    # Second user message - specialist hands off to escalation
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "This is complex"}))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs
-
-    # Escalation should have been called
-    assert len(escalation.calls) > 0
-
-
-async def test_handoff_preserves_complex_additional_properties(complex_metadata: _ComplexMetadata):
-    triage = _RecordingAgent(name="triage", handoff_to="specialist", extra_properties={"complex": complex_metadata})
-    specialist = _RecordingAgent(name="specialist")
-
-    # Sanity check: agent response contains complex metadata before entering workflow
-    triage_response = await triage.run([ChatMessage(role=Role.USER, text="Need help with a return")])
-    assert triage_response.messages
-    assert "complex" in triage_response.messages[0].additional_properties
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist])
-        .set_coordinator(triage)
-        .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role == Role.USER) >= 2)
-        .build()
-    )
-
-    # Initial run should preserve complex metadata in the triage response
-    events = await _drain(workflow.run_stream("Need help with a return"))
-    agent_events = [ev for ev in events if isinstance(ev, AgentRunEvent)]
-    if agent_events:
-        first_agent_event = agent_events[0]
-        first_agent_event_data = first_agent_event.data
-        if first_agent_event_data and first_agent_event_data.messages:
-            first_agent_message = first_agent_event_data.messages[0]
-            assert "complex" in first_agent_message.additional_properties, "Agent event lost complex metadata"
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests, "Workflow should request additional user input"
-
-    request_data = requests[-1].data
-    assert isinstance(request_data, HandoffUserInputRequest)
-    conversation_snapshot = request_data.conversation
-    metadata_values = _metadata_from_conversation(conversation_snapshot, "complex")
-    assert metadata_values, (
-        "Expected triage message in conversation, found "
-        f"additional_properties={[msg.additional_properties for msg in conversation_snapshot]},"
-        f" messages={_conversation_debug(conversation_snapshot)}"
-    )
-    assert any(isinstance(value, _ComplexMetadata) for value in metadata_values), (
-        "Complex metadata lost after first hop"
-    )
-    restored_meta = next(value for value in metadata_values if isinstance(value, _ComplexMetadata))
-    assert restored_meta.payload["code"] == "X1"
-
-    # Respond and ensure metadata survives subsequent cycles
-    follow_up_events = await _drain(
-        workflow.send_responses_streaming({requests[-1].request_id: "Here are more details"})
-    )
-    follow_up_requests = [ev for ev in follow_up_events if isinstance(ev, RequestInfoEvent)]
-    outputs = [ev for ev in follow_up_events if isinstance(ev, WorkflowOutputEvent)]
-
-    follow_up_conversation: list[ChatMessage]
-    if follow_up_requests:
-        follow_up_request_data = follow_up_requests[-1].data
-        assert isinstance(follow_up_request_data, HandoffUserInputRequest)
-        follow_up_conversation = follow_up_request_data.conversation
-    else:
-        assert outputs, "Workflow produced neither follow-up request nor output"
-        output_data = outputs[-1].data
-        follow_up_conversation = cast(list[ChatMessage], output_data) if isinstance(output_data, list) else []
-
-    metadata_values_after = _metadata_from_conversation(follow_up_conversation, "complex")
-    assert metadata_values_after, "Expected triage message after follow-up"
-    assert any(isinstance(value, _ComplexMetadata) for value in metadata_values_after), (
-        "Complex metadata lost after restore"
-    )
-
-    restored_meta_after = next(value for value in metadata_values_after if isinstance(value, _ComplexMetadata))
-    assert restored_meta_after.payload["code"] == "X1"
+    assert requests
+    assert len(requests) == 1
 
-
-async def test_tool_call_handoff_detection_with_text_hint():
-    triage = _RecordingAgent(name="triage", handoff_to="specialist", text_handoff=True)
-    specialist = _RecordingAgent(name="specialist")
-
-    workflow = HandoffBuilder(participants=[triage, specialist]).set_coordinator(triage).build()
-
-    await _drain(workflow.run_stream("Package arrived broken"))
-
-    assert specialist.calls, "Specialist should be invoked using handoff tool call"
-    assert len(specialist.calls[0]) >= 2
+    request = requests[0]
+    assert isinstance(request.data, HandoffAgentUserRequest)
+    assert request.source_executor_id == escalation.name
 
 
-async def test_autonomous_interaction_mode_yields_output_without_user_request():
+async def test_autonomous_mode_yields_output_without_user_request():
     """Ensure autonomous interaction mode yields output without requesting user input."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist")
-    specialist = _RecordingAgent(name="specialist")
+    triage = MockHandoffAgent(name="triage", handoff_to="specialist")
+    specialist = MockHandoffAgent(name="specialist")
 
     workflow = (
         HandoffBuilder(participants=[triage, specialist])
-        .set_coordinator(triage)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=1)
+        .with_start_agent(triage)
+        # Since specialist has no handoff, it will generate normal responses.
+        # With autonomous mode, this should continue until the termination condition is met.
+        .with_autonomous_mode(
+            agents=[specialist],
+            turn_limits={resolve_agent_id(specialist): 1},
+        )
+        # This termination condition ensures the workflow runs through both agents.
+        # First message is the user message to triage, second is triage's response, which
+        # is a handoff to specialist, third is specialist's response that should not request
+        # user input due to autonomous mode. The fourth message will come from the specialist
+        # again and will trigger termination.
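+        # Worked out: conv grows user(1) -> triage(2) -> specialist(3) -> specialist(4),
+        # so len(conv) >= 4 fires after the specialist's second reply. turn_limits is
+        # keyed by resolved agent id; MockHandoffAgent passes id=name, so
+        # resolve_agent_id(specialist) should evaluate to "specialist" here.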
+        .with_termination_condition(lambda conv: len(conv) >= 4)
         .build()
     )
 
     events = await _drain(workflow.run_stream("Package arrived broken"))
 
-    assert len(triage.calls) == 1
-    assert len(specialist.calls) == 1
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
     assert not requests, "Autonomous mode should not request additional user input"
@@ -323,117 +178,31 @@ async def test_autonomous_interaction_mode_yields_output_without_user_request():
     )
 
 
-async def test_autonomous_continues_without_handoff_until_termination():
-    """Autonomous mode should keep invoking the same agent when no handoff occurs."""
-    worker = _RecordingAgent(name="worker")
-
-    workflow = (
-        HandoffBuilder(participants=[worker])
-        .set_coordinator(worker)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=3)
-        .with_termination_condition(lambda conv: False)
-        .build()
-    )
-
-    events = await _drain(workflow.run_stream("Start"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Autonomous mode should yield output after termination condition"
-    assert len(worker.calls) == 3, "Worker should be invoked multiple times without user input"
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert not requests, "Autonomous mode should not request user input"
-
-
-async def test_autonomous_turn_limit_stops_loop():
-    """Autonomous mode should stop when the configured turn limit is reached."""
-    worker = _RecordingAgent(name="worker")
+async def test_autonomous_mode_resumes_user_input_on_turn_limit():
+    """Autonomous mode should resume user input request when turn limit is reached."""
+    triage = MockHandoffAgent(name="triage", handoff_to="worker")
+    worker = MockHandoffAgent(name="worker")
 
     workflow = (
-        HandoffBuilder(participants=[worker])
-        .set_coordinator(worker)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=2)
+        HandoffBuilder(participants=[triage, worker])
+        .with_start_agent(triage)
+        .with_autonomous_mode(agents=[worker], turn_limits={resolve_agent_id(worker): 2})
         .with_termination_condition(lambda conv: False)
         .build()
     )
 
     events = await _drain(workflow.run_stream("Start"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Turn limit should force a workflow output"
-    assert len(worker.calls) == 2, "Worker should stop after reaching the turn limit"
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert not requests, "Autonomous mode should not request user input"
-
-
-async def test_autonomous_routes_back_to_coordinator_when_specialist_stops():
-    """Specialist without handoff should route back to coordinator in autonomous mode."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist")
-    specialist = _RecordingAgent(name="specialist")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist])
-        .set_coordinator(triage)
-        .add_handoff(triage, specialist)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=3)
-        .with_termination_condition(lambda conv: len(conv) >= 4)
-        .build()
-    )
-
-    events = await _drain(workflow.run_stream("Issue"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Workflow should complete without user input"
-    assert len(specialist.calls) >= 1, "Specialist should run without handoff"
-
-
-async def test_autonomous_mode_with_inline_turn_limit():
-    """Autonomous mode should respect turn limit passed via with_interaction_mode."""
-    worker = _RecordingAgent(name="worker")
-
-    workflow = (
-        HandoffBuilder(participants=[worker])
-        .set_coordinator(worker)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=2)
-        .with_termination_condition(lambda conv: False)
-        .build()
-    )
-
-    events = await _drain(workflow.run_stream("Start"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Turn limit should force a workflow output"
-    assert len(worker.calls) == 2, "Worker should stop after reaching the inline turn limit"
-
-
-def test_autonomous_turn_limit_ignored_in_human_in_loop_mode(caplog):
-    """Verify that autonomous_turn_limit logs a warning when mode is human_in_loop."""
-    worker = _RecordingAgent(name="worker")
-
-    # Should not raise, but should log a warning
-    HandoffBuilder(participants=[worker]).set_coordinator(worker).with_interaction_mode(
-        "human_in_loop", autonomous_turn_limit=10
-    )
-
-    assert "autonomous_turn_limit=10 was provided but interaction_mode is 'human_in_loop'; ignoring." in caplog.text
+    assert requests and len(requests) == 1, "Turn limit should force a user input request"
+    assert requests[0].source_executor_id == worker.name
 
 
-def test_autonomous_turn_limit_must_be_positive():
-    """Verify that autonomous_turn_limit raises an error when <= 0."""
-    worker = _RecordingAgent(name="worker")
+def test_build_fails_without_start_agent():
+    """Verify that build() raises ValueError when with_start_agent() was not called."""
+    triage = MockHandoffAgent(name="triage")
+    specialist = MockHandoffAgent(name="specialist")
 
-    with pytest.raises(ValueError, match="autonomous_turn_limit must be positive"):
-        HandoffBuilder(participants=[worker]).set_coordinator(worker).with_interaction_mode(
-            "autonomous", autonomous_turn_limit=0
-        )
-
-    with pytest.raises(ValueError, match="autonomous_turn_limit must be positive"):
-        HandoffBuilder(participants=[worker]).set_coordinator(worker).with_interaction_mode(
-            "autonomous", autonomous_turn_limit=-5
-        )
-
-
-def test_build_fails_without_coordinator():
-    """Verify that build() raises ValueError when set_coordinator() was not called."""
-    triage = _RecordingAgent(name="triage")
-    specialist = _RecordingAgent(name="specialist")
-
-    with pytest.raises(ValueError, match=r"Must call set_coordinator\(...\) before building the workflow."):
+    with pytest.raises(ValueError, match=r"Must call with_start_agent\(...\) before building the workflow."):
         HandoffBuilder(participants=[triage, specialist]).build()
@@ -453,11 +222,12 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
         user_count = sum(1 for msg in conv if msg.role == Role.USER)
         return user_count >= 2
 
-    coordinator = _RecordingAgent(name="coordinator")
+    coordinator = MockHandoffAgent(name="coordinator", handoff_to="worker")
+    worker = MockHandoffAgent(name="worker")
 
     workflow = (
-        HandoffBuilder(participants=[coordinator])
-        .set_coordinator(coordinator)
+        HandoffBuilder(participants=[coordinator, worker])
+        .with_start_agent(coordinator)
         .with_termination_condition(async_termination)
         .build()
     )
@@ -466,7 +236,11 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
     assert requests
 
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Second user message"}))
+    events = await _drain(
+        workflow.send_responses_streaming({
+            requests[-1].request_id: [ChatMessage(role=Role.USER, text="Second user message")]
+        })
+    )
     outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
     assert len(outputs) == 1
@@ -478,188 +252,8 @@ async def async_termination(conv: list[ChatMessage]) -> bool:
     assert termination_call_count > 0
 
 
-async def test_clone_chat_agent_preserves_mcp_tools() -> None:
-    """Test that _clone_chat_agent preserves MCP tools when cloning an agent."""
-    mock_chat_client = MagicMock()
-
-    mock_mcp_tool = MagicMock(spec=MCPTool)
-    mock_mcp_tool.name = "test_mcp_tool"
-
-    def sample_function() -> str:
-        return "test"
-
-    original_agent = ChatAgent(
-        chat_client=mock_chat_client,
-        name="TestAgent",
-        instructions="Test instructions",
-        tools=[mock_mcp_tool, sample_function],
-    )
-
-    assert hasattr(original_agent, "_local_mcp_tools")
-    assert len(original_agent._local_mcp_tools) == 1  # type: ignore[reportPrivateUsage]
-    assert original_agent._local_mcp_tools[0] == mock_mcp_tool  # type: ignore[reportPrivateUsage]
-
-    cloned_agent = _clone_chat_agent(original_agent)
-
-    assert hasattr(cloned_agent, "_local_mcp_tools")
-    assert len(cloned_agent._local_mcp_tools) == 1  # type: ignore[reportPrivateUsage]
-    assert cloned_agent._local_mcp_tools[0] == mock_mcp_tool  # type: ignore[reportPrivateUsage]
-    assert cloned_agent.default_options.get("tools") is not None
-    assert len(cloned_agent.default_options.get("tools")) == 1
-
-
-async def test_return_to_previous_routing():
-    """Test that return-to-previous routes back to the current specialist handling the conversation."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
-    specialist_b = _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
-        .add_handoff(triage, [specialist_a, specialist_b])
-        .add_handoff(specialist_a, specialist_b)
-        .enable_return_to_previous(True)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 4)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-    assert len(specialist_a.calls) > 0
-
-    # Specialist_a should have been called with initial request
-    initial_specialist_a_calls = len(specialist_a.calls)
-
-    # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Specialist_b should have been called
-    assert len(specialist_b.calls) > 0
-    initial_specialist_b_calls = len(specialist_b.calls)
-
-    # Third user message - with return_to_previous, should route back to specialist_b (current agent)
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
-    third_requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-
-    # Specialist_b should have been called again (return-to-previous routes to current agent)
-    assert len(specialist_b.calls) > initial_specialist_b_calls, (
-        "Specialist B should be called again due to return-to-previous routing to current agent"
-    )
-
-    # Specialist_a should NOT be called again (it's no longer the current agent)
-    assert len(specialist_a.calls) == initial_specialist_a_calls, (
-        "Specialist A should not be called again - specialist_b is the current agent"
-    )
-
-    # Triage should only have been called once at the start
-    assert len(triage.calls) == 1, "Triage should only be called once (initial routing)"
-
-    # Verify awaiting_agent_id is set to specialist_b (the agent that just responded)
-    if third_requests:
-        user_input_req = third_requests[-1].data
-        assert isinstance(user_input_req, HandoffUserInputRequest)
-        assert user_input_req.awaiting_agent_id == "specialist_b", (
-            f"Expected awaiting_agent_id 'specialist_b' but got '{user_input_req.awaiting_agent_id}'"
-        )
-
-
-async def test_return_to_previous_disabled_routes_to_coordinator():
-    """Test that with return-to-previous disabled, routing goes back to coordinator."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
-    specialist_b = _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
-        .add_handoff(triage, [specialist_a, specialist_b])
-        .add_handoff(specialist_a, specialist_b)
-        .enable_return_to_previous(False)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-    assert len(triage.calls) == 1
-
-    # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need more help"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Third user message - without return_to_previous, should route back to triage
-    await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
-
-    # Triage should have been called twice total: initial + after specialist_b responds
-    assert len(triage.calls) == 2, "Triage should be called twice (initial + default routing to coordinator)"
-
-
-async def test_return_to_previous_enabled():
-    """Verify that enable_return_to_previous() keeps control with the current specialist."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a")
-    specialist_b = _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
-        .enable_return_to_previous(True)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-    assert len(triage.calls) == 1
-    assert len(specialist_a.calls) == 1
-
-    # Second user message - with return_to_previous, should route to specialist_a (not triage)
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up question"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Triage should only have been called once (initial) - specialist_a handles follow-up
-    assert len(triage.calls) == 1, "Triage should only be called once (initial)"
-    assert len(specialist_a.calls) == 2, "Specialist A should handle follow-up with return_to_previous enabled"
-
-
-def test_handoff_builder_sets_start_executor_once(monkeypatch: pytest.MonkeyPatch) -> None:
-    """Ensure HandoffBuilder.build sets the start executor only once when assembling the workflow."""
-    _CountingWorkflowBuilder.created.clear()
-    monkeypatch.setattr(handoff_module, "WorkflowBuilder", _CountingWorkflowBuilder)
-
-    coordinator = _RecordingAgent(name="coordinator")
-    specialist = _RecordingAgent(name="specialist")
-
-    workflow = (
-        HandoffBuilder(participants=[coordinator, specialist])
-        .set_coordinator(coordinator)
-        .with_termination_condition(lambda conv: len(conv) > 0)
-        .build()
-    )
-
-    assert workflow is not None
-    assert _CountingWorkflowBuilder.created, "Expected CountingWorkflowBuilder to be instantiated"
-    builder = _CountingWorkflowBuilder.created[-1]
-    assert builder.start_calls == 1, "set_start_executor should be invoked exactly once"
-
-
 async def test_tool_choice_preserved_from_agent_config():
     """Verify that agent-level tool_choice configuration is preserved and not overridden."""
-    from unittest.mock import AsyncMock
-
-    from agent_framework import ChatResponse
-
     # Create a mock chat client that records the tool_choice used
     recorded_tool_choices: list[Any] = []
@@ -691,96 +285,6 @@ async def mock_get_response(messages: Any, options: dict[str, Any] | None = None
     assert last_tool_choice == {"mode": "required"}, f"Expected 'required', got {last_tool_choice}"
 
 
-async def test_handoff_builder_with_request_info():
-    """Test that HandoffBuilder supports request info via with_request_info()."""
-    from agent_framework import AgentInputRequest, RequestInfoEvent
-
-    # Create test agents
-    coordinator = _RecordingAgent(name="coordinator")
-    specialist = _RecordingAgent(name="specialist")
-
-    # Build workflow with request info enabled
-    workflow = (
-        HandoffBuilder(participants=[coordinator, specialist])
-        .set_coordinator(coordinator)
-        .with_termination_condition(lambda conv: len([m for m in conv if m.role == Role.USER]) >= 1)
-        .with_request_info()
-        .build()
-    )
-
-    # Run workflow until it pauses for request info
-    request_event: RequestInfoEvent | None = None
-    async for event in workflow.run_stream("Hello"):
-        if isinstance(event, RequestInfoEvent) and isinstance(event.data, AgentInputRequest):
-            request_event = event
-
-    # Verify request info was emitted
-    assert request_event is not None, "Request info should have been emitted"
-    assert isinstance(request_event.data, AgentInputRequest)
-
-    # Provide response and continue
-    output_events: list[WorkflowOutputEvent] = []
-    async for event in workflow.send_responses_streaming({request_event.request_id: "approved"}):
-        if isinstance(event, WorkflowOutputEvent):
-            output_events.append(event)
-
-    # Verify we got output events
-    assert len(output_events) > 0, "Should produce output events after response"
-
-
-async def test_handoff_builder_with_request_info_method_chaining():
-    """Test that with_request_info returns self for method chaining."""
-    coordinator = _RecordingAgent(name="coordinator")
-
-    builder = HandoffBuilder(participants=[coordinator])
-    result = builder.with_request_info()
-
-    assert result is builder, "with_request_info should return self for chaining"
-    assert builder._request_info_enabled is True  # type: ignore
-
-
-async def test_return_to_previous_state_serialization():
-    """Test that return_to_previous state is properly serialized/deserialized for checkpointing."""
-    from agent_framework._workflows._handoff import _HandoffCoordinator  # type: ignore[reportPrivateUsage]
-
-    # Create a coordinator with return_to_previous enabled
-    coordinator = _HandoffCoordinator(
-        starting_agent_id="triage",
-        specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
-        input_gateway_id="gateway",
-        termination_condition=lambda conv: False,
-        id="test-coordinator",
-        return_to_previous=True,
-    )
-
-    # Set the current agent (simulating a handoff scenario)
-    coordinator._current_agent_id = "specialist_a"  # type: ignore[reportPrivateUsage]
-
-    # Snapshot the state
-    state = await coordinator.on_checkpoint_save()
-
-    # Verify pattern metadata includes current_agent_id
-    assert "metadata" in state
-    assert "current_agent_id" in state["metadata"]
-    assert state["metadata"]["current_agent_id"] == "specialist_a"
-
-    # Create a new coordinator and restore state
-    coordinator2 = _HandoffCoordinator(
-        starting_agent_id="triage",
-        specialist_ids={"specialist_a": "specialist_a", "specialist_b": "specialist_b"},
-        input_gateway_id="gateway",
-        termination_condition=lambda conv: False,
-        id="test-coordinator",
-        return_to_previous=True,
-    )
-
-    # Restore state
-    await coordinator2.on_checkpoint_restore(state)
-
-    # Verify current_agent_id was restored
-    assert coordinator2._current_agent_id == "specialist_a", "Current agent should be restored from checkpoint"  # type: ignore[reportPrivateUsage]
-
-
 # region Participant Factory Tests
@@ -796,43 +300,43 @@ def test_handoff_builder_rejects_empty_participant_factories():
 
 def test_handoff_builder_rejects_mixing_participants_and_factories():
     """Test that mixing participants and participant_factories in __init__ raises an error."""
-    triage = _RecordingAgent(name="triage")
+    triage = MockHandoffAgent(name="triage")
 
     with pytest.raises(ValueError, match="Cannot mix .participants"):
         HandoffBuilder(participants=[triage], participant_factories={"triage": lambda: triage})
 
 
def test_handoff_builder_rejects_mixing_participants_and_participant_factories_methods():
     """Test that mixing .participants() and .participant_factories() raises an error."""
-    triage = _RecordingAgent(name="triage")
+    triage = MockHandoffAgent(name="triage")
 
     # Case 1: participants first, then participant_factories
     with pytest.raises(ValueError, match="Cannot mix .participants"):
         HandoffBuilder(participants=[triage]).participant_factories({
-            "specialist": lambda: _RecordingAgent(name="specialist")
+            "specialist": lambda: MockHandoffAgent(name="specialist")
         })
 
     # Case 2: participant_factories first, then participants
     with pytest.raises(ValueError, match="Cannot mix .participants"):
         HandoffBuilder(participant_factories={"triage": lambda: triage}).participants([
-            _RecordingAgent(name="specialist")
+            MockHandoffAgent(name="specialist")
         ])
 
     # Case 3: participants(), then participant_factories()
     with pytest.raises(ValueError, match="Cannot mix .participants"):
         HandoffBuilder().participants([triage]).participant_factories({
-            "specialist": lambda: _RecordingAgent(name="specialist")
+            "specialist": lambda: MockHandoffAgent(name="specialist")
         })
 
     # Case 4: participant_factories(), then participants()
     with pytest.raises(ValueError, match="Cannot mix .participants"):
         HandoffBuilder().participant_factories({"triage": lambda: triage}).participants([
-            _RecordingAgent(name="specialist")
+            MockHandoffAgent(name="specialist")
         ])
 
     # Case 5: mix during initialization
     with pytest.raises(ValueError, match="Cannot mix .participants"):
         HandoffBuilder(
-            participants=[triage], participant_factories={"specialist": lambda: _RecordingAgent(name="specialist")}
+            participants=[triage], participant_factories={"specialist": lambda: MockHandoffAgent(name="specialist")}
         )
@@ -841,60 +345,49 @@ def test_handoff_builder_rejects_multiple_calls_to_participant_factories():
     with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"):
         (
             HandoffBuilder()
-            .participant_factories({"agent1": lambda: _RecordingAgent(name="agent1")})
-            .participant_factories({"agent2": lambda: _RecordingAgent(name="agent2")})
+            .participant_factories({"agent1": lambda: MockHandoffAgent(name="agent1")})
+            .participant_factories({"agent2": lambda: MockHandoffAgent(name="agent2")})
         )
 
 
 def test_handoff_builder_rejects_multiple_calls_to_participants():
     """Test that multiple calls to .participants() raises an error."""
     with pytest.raises(ValueError, match="participants have already been assigned"):
-        (HandoffBuilder().participants([_RecordingAgent(name="agent1")]).participants([_RecordingAgent(name="agent2")]))
-
-
-def test_handoff_builder_rejects_duplicate_factories():
-    """Test that multiple calls to participant_factories are rejected."""
-    factories = {
-        "triage": lambda: _RecordingAgent(name="triage"),
-        "specialist": lambda: _RecordingAgent(name="specialist"),
-    }
-
-    # Multiple calls to participant_factories should fail
-    builder = HandoffBuilder(participant_factories=factories)
-    with pytest.raises(ValueError, match=r"participant_factories\(\) has already been called"):
-        builder.participant_factories({"triage": lambda: _RecordingAgent(name="triage2")})
+        (
+            HandoffBuilder()
+            .participants([MockHandoffAgent(name="agent1")])
+            .participants([MockHandoffAgent(name="agent2")])
+        )
 
 
 def test_handoff_builder_rejects_instance_coordinator_with_factories():
     """Test that using an agent instance for set_coordinator when using factories raises an error."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     # Create an agent instance
-    coordinator_instance = _RecordingAgent(name="coordinator")
+    coordinator_instance = MockHandoffAgent(name="coordinator")
 
-    with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before coordinator\(\.\.\.\)"):
+    with pytest.raises(ValueError, match=r"Call participants\(\.\.\.\) before with_start_agent\(\.\.\.\)"):
         (
             HandoffBuilder(
                 participant_factories={"triage": create_triage, "specialist": create_specialist}
-            ).set_coordinator(coordinator_instance)  # Instance, not factory name
+            ).with_start_agent(coordinator_instance)  # Instance, not factory name
         )
 
 
 def test_handoff_builder_rejects_factory_name_coordinator_with_instances():
     """Test that using a factory name for set_coordinator when using instances raises an error."""
-    triage = _RecordingAgent(name="triage")
-    specialist = _RecordingAgent(name="specialist")
+    triage = MockHandoffAgent(name="triage")
+    specialist = MockHandoffAgent(name="specialist")
 
-    with pytest.raises(
-        ValueError, match="coordinator factory name 'triage' is not part of the participant_factories list"
-    ):
+    with pytest.raises(ValueError, match="Call participant_factories.*before with_start_agent"):
         (
-            HandoffBuilder(participants=[triage, specialist]).set_coordinator(
+            HandoffBuilder(participants=[triage, specialist]).with_start_agent(
                 "triage"
             )  # String factory name, not instance
         )
@@ -902,28 +395,28 @@ def test_handoff_builder_rejects_mixed_types_in_add_handoff_source():
     """Test that add_handoff rejects factory name source with instance-based participants."""
-    triage = _RecordingAgent(name="triage")
-    specialist = _RecordingAgent(name="specialist")
+    triage = MockHandoffAgent(name="triage")
+    specialist = MockHandoffAgent(name="specialist")
 
-    with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol/Executor instances"):
+    with pytest.raises(TypeError, match="Cannot mix factory names \\(str\\) and AgentProtocol.*instances"):
         (
             HandoffBuilder(participants=[triage, specialist])
-            .set_coordinator(triage)
-            .add_handoff("triage", specialist)  # String source with instance participants
+            .with_start_agent(triage)
+            .add_handoff("triage", [specialist])  # String source with instance participants
         )
 
 
 def test_handoff_builder_accepts_all_factory_names_in_add_handoff():
     """Test that add_handoff accepts all factory names when using participant_factories."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage")
 
-    def create_specialist_a() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_a")
+    def create_specialist_a() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_a")
 
-    def create_specialist_b() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_b")
+    def create_specialist_b() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_b")
 
     # This should work - all strings with participant_factories
     builder = (
@@ -934,7 +427,7 @@ def create_specialist_b() -> _RecordingAgent:
                 "specialist_b": create_specialist_b,
             }
         )
-        .set_coordinator("triage")
+        .with_start_agent("triage")
         .add_handoff("triage", ["specialist_a", "specialist_b"])
     )
@@ -946,14 +439,14 @@ def create_specialist_b() -> _RecordingAgent:
 
 def test_handoff_builder_accepts_all_instances_in_add_handoff():
     """Test that add_handoff accepts all instances when using participants."""
-    triage = _RecordingAgent(name="triage", handoff_to="specialist_a")
-    specialist_a = _RecordingAgent(name="specialist_a")
-    specialist_b = _RecordingAgent(name="specialist_b")
+    triage = MockHandoffAgent(name="triage", handoff_to="specialist_a")
+    specialist_a = MockHandoffAgent(name="specialist_a")
+    specialist_b = MockHandoffAgent(name="specialist_b")
 
     # This should work - all instances with participants
     builder = (
         HandoffBuilder(participants=[triage, specialist_a, specialist_b])
-        .set_coordinator(triage)
+        .with_start_agent(triage)
         .add_handoff(triage, [specialist_a, specialist_b])
     )
@@ -967,19 +460,19 @@ async def test_handoff_with_participant_factories():
     """Test workflow creation using participant_factories."""
     call_count = 0
 
-    def create_triage() -> _RecordingAgent:
+    def create_triage() -> MockHandoffAgent:
         nonlocal call_count
         call_count += 1
-        return _RecordingAgent(name="triage", handoff_to="specialist")
+        return MockHandoffAgent(name="triage", handoff_to="specialist")
 
-    def create_specialist() -> _RecordingAgent:
+    def create_specialist() -> MockHandoffAgent:
         nonlocal call_count
         call_count += 1
-        return _RecordingAgent(name="specialist")
+        return MockHandoffAgent(name="specialist")
 
     workflow = (
         HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-        .set_coordinator("triage")
+        .with_start_agent("triage")
         .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
         .build()
     )
@@ -992,7 +485,9 @@ def create_specialist() -> _RecordingAgent:
     assert requests
 
     # Follow-up message
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "More details"}))
+    events = await _drain(
+        workflow.send_responses_streaming({requests[-1].request_id: [ChatMessage(role=Role.USER, text="More details")]})
+    )
     outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
     assert outputs
@@ -1001,19 +496,19 @@ async def test_handoff_participant_factories_reusable_builder():
     """Test that the builder can be reused to build multiple workflows with factories."""
     call_count = 0
 
-    def create_triage() -> _RecordingAgent:
+    def create_triage() -> MockHandoffAgent:
         nonlocal call_count
         call_count += 1
-        return _RecordingAgent(name="triage", handoff_to="specialist")
+        return MockHandoffAgent(name="triage", handoff_to="specialist")
 
-    def create_specialist() -> _RecordingAgent:
+    def create_specialist() -> MockHandoffAgent:
         nonlocal call_count
         call_count += 1
-        return _RecordingAgent(name="specialist")
+        return MockHandoffAgent(name="specialist")
 
     builder = HandoffBuilder(
         participant_factories={"triage": create_triage, "specialist": create_specialist}
-    ).set_coordinator("triage")
+    ).with_start_agent("triage")
 
     # Build first workflow
     wf1 = builder.build()
@@ -1031,14 +526,14 @@ def create_specialist() -> _RecordingAgent:
 async def test_handoff_with_participant_factories_and_add_handoff():
     """Test that .add_handoff() works correctly with participant_factories."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist_a")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage", handoff_to="specialist_a")
 
-    def create_specialist_a() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
+    def create_specialist_a() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_a", handoff_to="specialist_b")
 
-    def create_specialist_b() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_b")
+    def create_specialist_b() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_b")
 
     workflow = (
         HandoffBuilder(
@@ -1048,9 +543,9 @@ def create_specialist_b() -> _RecordingAgent:
                 "specialist_b": create_specialist_b,
            }
         )
-        .set_coordinator("triage")
+        .with_start_agent("triage")
         .add_handoff("triage", ["specialist_a", "specialist_b"])
-        .add_handoff("specialist_a", "specialist_b")
+        .add_handoff("specialist_a", ["specialist_b"])
         .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
         .build()
     )
@@ -1064,7 +559,11 @@ def create_specialist_b() -> _RecordingAgent:
     assert "specialist_a" in workflow.executors
 
     # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"}))
+    events = await _drain(
+        workflow.send_responses_streaming({
+            requests[-1].request_id: [ChatMessage(role=Role.USER, text="Need escalation")]
+        })
+    )
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
     assert requests
@@ -1078,15 +577,15 @@ async def test_handoff_participant_factories_with_checkpointing():
 
     storage = InMemoryCheckpointStorage()
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage", handoff_to="specialist")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     workflow = (
         HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-        .set_coordinator("triage")
+        .with_start_agent("triage")
         .with_checkpointing(storage)
         .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 2)
         .build()
@@ -1097,7 +596,9 @@ def create_specialist() -> _RecordingAgent:
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
     assert requests
 
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "follow up"}))
+    events = await _drain(
+        workflow.send_responses_streaming({requests[-1].request_id: [ChatMessage(role=Role.USER, text="follow up")]})
+    )
     outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
     assert outputs, "Should have workflow output after termination condition is met"
@@ -1109,15 +610,15 @@ def create_specialist() -> _RecordingAgent:
 def test_handoff_set_coordinator_with_factory_name():
     """Test that set_coordinator accepts factory name as string."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     builder = HandoffBuilder(
         participant_factories={"triage": create_triage, "specialist": create_specialist}
-    ).set_coordinator("triage")
+    ).with_start_agent("triage")
 
     workflow = builder.build()
     assert "triage" in workflow.executors
@@ -1126,14 +627,14 @@ def create_specialist() -> _RecordingAgent:
 def test_handoff_add_handoff_with_factory_names():
     """Test that add_handoff accepts factory names as strings."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist_a")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage", handoff_to="specialist_a")
 
-    def create_specialist_a() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_a")
+    def create_specialist_a() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_a")
 
-    def create_specialist_b() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_b")
+    def create_specialist_b() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist_b")
 
     builder = (
         HandoffBuilder(
@@ -1143,7 +644,7 @@ def create_specialist_b() -> _RecordingAgent:
                 "specialist_b": create_specialist_b,
             }
         )
-        .set_coordinator("triage")
+        .with_start_agent("triage")
         .add_handoff("triage", ["specialist_a", "specialist_b"])
     )
@@ -1156,516 +657,53 @@ def create_specialist_b() -> _RecordingAgent:
 async def test_handoff_participant_factories_autonomous_mode():
     """Test autonomous mode with participant_factories."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage", handoff_to="specialist")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     workflow = (
         HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-        .set_coordinator("triage")
-        .with_interaction_mode("autonomous", autonomous_turn_limit=2)
+        .with_start_agent("triage")
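+        # With participant factories, with_autonomous_mode is keyed by the factory
+        # names directly (plain strings), unlike the instance-based tests above,
+        # which key turn_limits by resolve_agent_id(<agent instance>).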
+        .with_autonomous_mode(agents=["specialist"], turn_limits={"specialist": 1})
         .build()
     )
 
     events = await _drain(workflow.run_stream("Issue"))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs, "Autonomous mode should yield output"
     requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert not requests, "Autonomous mode should not request user input"
-
-
-async def test_handoff_participant_factories_with_request_info():
-    """Test that .with_request_info() works with participant_factories."""
-
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
-
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
-
-    builder = (
-        HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-        .set_coordinator("triage")
-        .with_request_info(agents=["triage"])
-    )
-
-    workflow = builder.build()
-    assert "triage" in workflow.executors
+    assert requests and len(requests) == 1
+    assert requests[0].source_executor_id == "specialist"
 
 
 def test_handoff_participant_factories_invalid_coordinator_name():
     """Test that set_coordinator raises error for non-existent factory name."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
+        return MockHandoffAgent(name="triage")
 
     with pytest.raises(
-        ValueError, match="coordinator factory name 'nonexistent' is not part of the participant_factories list"
+        ValueError, match="Start agent factory name 'nonexistent' is not in the participant_factories list"
     ):
-        (HandoffBuilder(participant_factories={"triage": create_triage}).set_coordinator("nonexistent").build())
+        (HandoffBuilder(participant_factories={"triage": create_triage}).with_start_agent("nonexistent").build())
 
 
 def test_handoff_participant_factories_invalid_handoff_target():
     """Test that add_handoff raises error for non-existent target factory name."""
 
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage")
+    def create_triage() -> MockHandoffAgent:
        return MockHandoffAgent(name="triage")
 
-    def create_specialist() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist")
+    def create_specialist() -> MockHandoffAgent:
+        return MockHandoffAgent(name="specialist")
 
     with pytest.raises(ValueError, match="Target factory name 'nonexistent' is not in the participant_factories list"):
         (
             HandoffBuilder(participant_factories={"triage": create_triage, "specialist": create_specialist})
-            .set_coordinator("triage")
-            .add_handoff("triage", "nonexistent")
+            .with_start_agent("triage")
+            .add_handoff("triage", ["nonexistent"])
             .build()
         )
 
 
-async def test_handoff_participant_factories_enable_return_to_previous():
-    """Test return_to_previous works with participant_factories."""
-
-    def create_triage() -> _RecordingAgent:
-        return _RecordingAgent(name="triage", handoff_to="specialist_a")
-
-    def create_specialist_a() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_a", handoff_to="specialist_b")
-
-    def create_specialist_b() -> _RecordingAgent:
-        return _RecordingAgent(name="specialist_b")
-
-    workflow = (
-        HandoffBuilder(
-            participant_factories={
-                "triage": create_triage,
-                "specialist_a": create_specialist_a,
-                "specialist_b": create_specialist_b,
-            }
-        )
-        .set_coordinator("triage")
-        .add_handoff("triage", ["specialist_a", "specialist_b"])
-        .add_handoff("specialist_a", "specialist_b")
-        .enable_return_to_previous(True)
-        .with_termination_condition(lambda conv: sum(1 for m in conv if m.role == Role.USER) >= 3)
-        .build()
-    )
-
-    # Start conversation - triage hands off to specialist_a
-    events = await _drain(workflow.run_stream("Initial request"))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Second user message - specialist_a hands off to specialist_b
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Need escalation"}))
-    requests = [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-    assert requests
-
-    # Third user message - should route back to specialist_b (return to previous)
-    events = await _drain(workflow.send_responses_streaming({requests[-1].request_id: "Follow up"}))
-    outputs = [ev for ev in events if isinstance(ev, WorkflowOutputEvent)]
-    assert outputs or [ev for ev in events if isinstance(ev, RequestInfoEvent)]
-
-
 # endregion Participant Factory Tests
 
 
-async def test_handoff_user_input_request_checkpoint_excludes_conversation():
-    """Test that HandoffUserInputRequest serialization excludes conversation to prevent duplication.
-
-    Issue #2667: When checkpointing a workflow with a pending HandoffUserInputRequest,
-    the conversation field gets serialized twice: once in the RequestInfoEvent's data
-    and once in the coordinator's conversation state. On restore, this causes duplicate
-    messages.
-
-    The fix is to exclude the conversation field during checkpoint serialization since
-    the conversation is already preserved in the coordinator's state.
-    """
-    # Create a conversation history
-    conversation = [
-        ChatMessage(role=Role.USER, text="Hello"),
-        ChatMessage(role=Role.ASSISTANT, text="Hi there!"),
-        ChatMessage(role=Role.USER, text="Help me"),
-    ]
-
-    # Create a HandoffUserInputRequest with the conversation
-    request = HandoffUserInputRequest(
-        conversation=conversation,
-        awaiting_agent_id="specialist_agent",
-        prompt="Please provide your input",
-        source_executor_id="gateway",
-    )
-
-    # Encode the request (simulating checkpoint save)
-    encoded = encode_checkpoint_value(request)
-
-    # Verify conversation is NOT in the encoded output
-    # The fix should exclude conversation from serialization
-    assert isinstance(encoded, dict)
-
-    # If using MODEL_MARKER strategy (to_dict/from_dict)
-    if "__af_model__" in encoded or "__af_dataclass__" in encoded:
-        value = encoded.get("value", {})
-        assert "conversation" not in value, "conversation should be excluded from checkpoint serialization"
-
-    # Decode the request (simulating checkpoint restore)
-    decoded = decode_checkpoint_value(encoded)
-
-    # Verify the decoded request is a HandoffUserInputRequest
-    assert isinstance(decoded, HandoffUserInputRequest)
-
-    # Verify other fields are preserved
-    assert decoded.awaiting_agent_id == "specialist_agent"
-    assert decoded.prompt == "Please provide your input"
-    assert decoded.source_executor_id == "gateway"
-
-    # Conversation should be an empty list after deserialization
-    # (will be reconstructed from coordinator state on restore)
-    assert decoded.conversation == []
-
-
-async def test_handoff_user_input_request_roundtrip_preserves_metadata():
-    """Test that non-conversation fields survive checkpoint roundtrip."""
-    request = HandoffUserInputRequest(
-        conversation=[ChatMessage(role=Role.USER, text="test")],
-        awaiting_agent_id="test_agent",
-        prompt="Enter your response",
-        source_executor_id="test_gateway",
-    )
-
-    # Roundtrip through checkpoint encoding
-    encoded = encode_checkpoint_value(request)
-    decoded = decode_checkpoint_value(encoded)
-
-    assert isinstance(decoded, HandoffUserInputRequest)
-    assert decoded.awaiting_agent_id == request.awaiting_agent_id
-    assert decoded.prompt == request.prompt
-    assert decoded.source_executor_id == request.source_executor_id
-
-
-async def test_request_info_event_with_handoff_user_input_request():
-    """Test RequestInfoEvent serialization with HandoffUserInputRequest data."""
-    conversation = [
-        ChatMessage(role=Role.USER, text="Hello"),
-        ChatMessage(role=Role.ASSISTANT, text="How can I help?"),
-    ]
-
-    request = HandoffUserInputRequest(
-        conversation=conversation,
-        awaiting_agent_id="specialist",
-        prompt="Provide input",
-        source_executor_id="gateway",
-    )
-
-    # Create a RequestInfoEvent wrapping the request
-    event = RequestInfoEvent(
-        request_id="test-request-123",
-        source_executor_id="gateway",
-        request_data=request,
-        response_type=object,
-    )
-
-    # Serialize the event
-    event_dict = event.to_dict()
-
-    # Verify the data field doesn't contain conversation
-    data_encoded = event_dict["data"]
-    if isinstance(data_encoded, dict) and ("__af_model__" in data_encoded or "__af_dataclass__" in data_encoded):
-        value = data_encoded.get("value", {})
-        assert "conversation" not in value
-
-    # Deserialize and verify
-    restored_event = RequestInfoEvent.from_dict(event_dict)
-    assert isinstance(restored_event.data, HandoffUserInputRequest)
-    assert restored_event.data.awaiting_agent_id == "specialist"
-    assert restored_event.data.conversation == []
-
-
-async def test_handoff_user_input_request_to_dict_excludes_conversation():
-    """Test that to_dict() method excludes conversation field."""
-    conversation = [
-        ChatMessage(role=Role.USER, text="Hello"),
-        ChatMessage(role=Role.ASSISTANT, text="Hi!"),
-    ]
-
-    request = HandoffUserInputRequest(
-        conversation=conversation,
-        awaiting_agent_id="agent1",
-        prompt="Enter input",
-        source_executor_id="gateway",
-    )
-
-    # Call to_dict directly
-    data = request.to_dict()
-
-    # Verify conversation is excluded
-    assert "conversation" not in data
-    assert data["awaiting_agent_id"] == "agent1"
-    assert data["prompt"] == "Enter input"
-    assert data["source_executor_id"] == "gateway"
-
-
-async def test_handoff_user_input_request_from_dict_creates_empty_conversation():
-    """Test that from_dict() creates an instance with empty conversation."""
-    data = {
-        "awaiting_agent_id": "agent1",
-        "prompt": "Enter input",
-        "source_executor_id": "gateway",
-    }
-
-    request = HandoffUserInputRequest.from_dict(data)
-
-    assert request.conversation == []
-    assert request.awaiting_agent_id == "agent1"
-    assert request.prompt == "Enter input"
-    assert request.source_executor_id == "gateway"
-
-
-async def test_user_input_gateway_resume_handles_empty_conversation():
-    """Test that _UserInputGateway.resume_from_user handles post-restore scenario.
-
-    After checkpoint restore, the HandoffUserInputRequest will have an empty
-    conversation. The gateway should handle this by sending only the new user
-    messages to the coordinator.
-    """
-    from unittest.mock import AsyncMock
-
-    # Create a gateway
-    gateway = _UserInputGateway(
-        starting_agent_id="coordinator",
-        prompt="Enter input",
-        id="test-gateway",
-    )
-
-    # Simulate post-restore: request with empty conversation
-    restored_request = HandoffUserInputRequest(
-        conversation=[],  # Empty after restore
-        awaiting_agent_id="specialist",
-        prompt="Enter input",
-        source_executor_id="test-gateway",
-    )
-
-    # Create mock context
-    mock_ctx = MagicMock()
-    mock_ctx.send_message = AsyncMock()
-
-    # Call resume_from_user with a user response
-    await gateway.resume_from_user(restored_request, "New user message", mock_ctx)
-
-    # Verify send_message was called
-    mock_ctx.send_message.assert_called_once()
-
-    # Get the message that was sent
-    call_args = mock_ctx.send_message.call_args
-    sent_message = call_args[0][0]
-
-    # Verify it's a _ConversationWithUserInput
-    assert isinstance(sent_message, _ConversationWithUserInput)
-
-    # Verify it contains only the new user message (not any history)
-    assert len(sent_message.full_conversation) == 1
-    assert sent_message.full_conversation[0].role == Role.USER
-    assert sent_message.full_conversation[0].text == "New user message"
-
-
-async def test_user_input_gateway_resume_with_full_conversation():
-    """Test that _UserInputGateway.resume_from_user handles normal flow correctly.
-
-    In normal flow (no checkpoint restore), the HandoffUserInputRequest has
-    the full conversation. The gateway should send the full conversation
-    plus the new user messages.
-    """
-    from unittest.mock import AsyncMock
-
-    # Create a gateway
-    gateway = _UserInputGateway(
-        starting_agent_id="coordinator",
-        prompt="Enter input",
-        id="test-gateway",
-    )
-
-    # Normal flow: request with full conversation
-    normal_request = HandoffUserInputRequest(
-        conversation=[
-            ChatMessage(role=Role.USER, text="Hello"),
-            ChatMessage(role=Role.ASSISTANT, text="Hi!"),
-        ],
-        awaiting_agent_id="specialist",
-        prompt="Enter input",
-        source_executor_id="test-gateway",
-    )
-
-    # Create mock context
-    mock_ctx = MagicMock()
-    mock_ctx.send_message = AsyncMock()
-
-    # Call resume_from_user with a user response
-    await gateway.resume_from_user(normal_request, "Follow up message", mock_ctx)
-
-    # Verify send_message was called
-    mock_ctx.send_message.assert_called_once()
-
-    # Get the message that was sent
-    call_args = mock_ctx.send_message.call_args
-    sent_message = call_args[0][0]
-
-    # Verify it's a _ConversationWithUserInput
-    assert isinstance(sent_message, _ConversationWithUserInput)
-
-    # Verify it contains the full conversation plus new user message
-    assert len(sent_message.full_conversation) == 3
-    assert sent_message.full_conversation[0].text == "Hello"
-    assert sent_message.full_conversation[1].text == "Hi!"
-    assert sent_message.full_conversation[2].text == "Follow up message"
-
-
-async def test_coordinator_handle_user_input_post_restore():
-    """Test that _HandoffCoordinator.handle_user_input handles post-restore correctly.
-
-    After checkpoint restore, the coordinator has its conversation restored,
-    and the gateway sends only the new user messages. The coordinator should
-    append these to its existing conversation rather than replacing.
-    """
-    from unittest.mock import AsyncMock
-
-    from agent_framework._workflows._handoff import _HandoffCoordinator
-
-    # Create a coordinator with pre-existing conversation (simulating restored state)
-    coordinator = _HandoffCoordinator(
-        starting_agent_id="triage",
-        specialist_ids={"specialist_a": "specialist_a"},
-        input_gateway_id="gateway",
-        termination_condition=lambda conv: False,
-        id="test-coordinator",
-    )
-
-    # Simulate restored conversation
-    coordinator._conversation = [
-        ChatMessage(role=Role.USER, text="Hello"),
-        ChatMessage(role=Role.ASSISTANT, text="Hi there!"),
-        ChatMessage(role=Role.USER, text="Help me"),
-        ChatMessage(role=Role.ASSISTANT, text="Sure, what do you need?"),
-    ]
-
-    # Create mock context
-    mock_ctx = MagicMock()
-    mock_ctx.send_message = AsyncMock()
-
-    # Simulate post-restore: only new user message with explicit flag
-    incoming = _ConversationWithUserInput(
-        full_conversation=[ChatMessage(role=Role.USER, text="I need shipping help")],
-        is_post_restore=True,
-    )
-
-    # Handle the user input
-    await coordinator.handle_user_input(incoming, mock_ctx)
-
-    # Verify conversation was appended, not replaced
-    assert len(coordinator._conversation) == 5
-    assert coordinator._conversation[0].text == "Hello"
-    assert coordinator._conversation[1].text == "Hi there!"
-    assert coordinator._conversation[2].text == "Help me"
-    assert coordinator._conversation[3].text == "Sure, what do you need?"
-    assert coordinator._conversation[4].text == "I need shipping help"
-
-
-async def test_coordinator_handle_user_input_normal_flow():
-    """Test that _HandoffCoordinator.handle_user_input handles normal flow correctly.
-
-    In normal flow (no restore), the gateway sends the full conversation.
-    The coordinator should replace its conversation with the incoming one.
-    """
-    from unittest.mock import AsyncMock
-
-    from agent_framework._workflows._handoff import _HandoffCoordinator
-
-    # Create a coordinator
-    coordinator = _HandoffCoordinator(
-        starting_agent_id="triage",
-        specialist_ids={"specialist_a": "specialist_a"},
-        input_gateway_id="gateway",
-        termination_condition=lambda conv: False,
-        id="test-coordinator",
-    )
-
-    # Set some initial conversation
-    coordinator._conversation = [
-        ChatMessage(role=Role.USER, text="Old message"),
-    ]
-
-    # Create mock context
-    mock_ctx = MagicMock()
-    mock_ctx.send_message = AsyncMock()
-
-    # Normal flow: full conversation including new user message (is_post_restore=False by default)
-    incoming = _ConversationWithUserInput(
-        full_conversation=[
-            ChatMessage(role=Role.USER, text="Hello"),
-            ChatMessage(role=Role.ASSISTANT, text="Hi!"),
-            ChatMessage(role=Role.USER, text="New message"),
-        ],
-        is_post_restore=False,
-    )
-
-    # Handle the user input
-    await coordinator.handle_user_input(incoming, mock_ctx)
-
-    # Verify conversation was replaced (normal flow with full history)
-    assert len(coordinator._conversation) == 3
-    assert coordinator._conversation[0].text == "Hello"
-    assert coordinator._conversation[1].text == "Hi!"
-    assert coordinator._conversation[2].text == "New message"
-
-
-async def test_coordinator_handle_user_input_multiple_consecutive_user_messages():
-    """Test that multiple consecutive USER messages in normal flow are handled correctly.
-
-    This is a regression test for the edge case where a user submits multiple consecutive
-    USER messages. The explicit is_post_restore flag ensures this doesn't get incorrectly
-    detected as a post-restore scenario.
- """ - from unittest.mock import AsyncMock - - from agent_framework._workflows._handoff import _HandoffCoordinator - - # Create a coordinator with existing conversation - coordinator = _HandoffCoordinator( - starting_agent_id="triage", - specialist_ids={"specialist_a": "specialist_a"}, - input_gateway_id="gateway", - termination_condition=lambda conv: False, - id="test-coordinator", - ) - - # Set existing conversation with 4 messages - coordinator._conversation = [ - ChatMessage(role=Role.USER, text="Original message 1"), - ChatMessage(role=Role.ASSISTANT, text="Response 1"), - ChatMessage(role=Role.USER, text="Original message 2"), - ChatMessage(role=Role.ASSISTANT, text="Response 2"), - ] - - # Create mock context - mock_ctx = MagicMock() - mock_ctx.send_message = AsyncMock() - - # Normal flow: User sends multiple consecutive USER messages - # This should REPLACE the conversation, not append to it - incoming = _ConversationWithUserInput( - full_conversation=[ - ChatMessage(role=Role.USER, text="New user message 1"), - ChatMessage(role=Role.USER, text="New user message 2"), - ], - is_post_restore=False, # Explicit flag - this is normal flow - ) - - # Handle the user input - await coordinator.handle_user_input(incoming, mock_ctx) - - # Verify conversation was REPLACED (not appended) - # Without the explicit flag, the old heuristic might incorrectly append - assert len(coordinator._conversation) == 2 - assert coordinator._conversation[0].text == "New user message 1" - assert coordinator._conversation[1].text == "New user message 2" diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index 4ee16ddb5f..8cfa8d0bea 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -3,44 +3,42 @@ import sys from collections.abc import AsyncIterable from dataclasses import dataclass -from typing import Any, cast +from typing import Any, ClassVar, cast import pytest from agent_framework import ( + AgentProtocol, AgentRunResponse, AgentRunResponseUpdate, AgentRunUpdateEvent, + AgentThread, BaseAgent, ChatMessage, Executor, + GroupChatRequestMessage, MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, + MagenticContext, MagenticManagerBase, + MagenticOrchestrator, + MagenticOrchestratorEvent, + MagenticPlanReviewRequest, + MagenticProgressLedger, + MagenticProgressLedgerItem, RequestInfoEvent, Role, + StandardMagenticManager, TextContent, + Workflow, WorkflowCheckpoint, WorkflowContext, - WorkflowEvent, # type: ignore # noqa: E402 + WorkflowEvent, WorkflowOutputEvent, WorkflowRunState, WorkflowStatusEvent, handler, ) -from agent_framework._workflows import _group_chat as group_chat_module # type: ignore from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage -from agent_framework._workflows._magentic import ( # type: ignore[reportPrivateUsage] - MagenticAgentExecutor, - MagenticContext, - MagenticOrchestratorExecutor, - _MagenticProgressLedger, # type: ignore - _MagenticProgressLedgerItem, # type: ignore - _MagenticStartMessage, # type: ignore -) -from agent_framework._workflows._workflow_builder import WorkflowBuilder if sys.version_info >= (3, 12): from typing import override @@ -48,40 +46,9 @@ from typing_extensions import override -def test_magentic_start_message_from_string(): - msg = _MagenticStartMessage.from_string("Do the thing") - assert isinstance(msg, 
_MagenticStartMessage) - assert isinstance(msg.task, ChatMessage) - assert msg.task.role == Role.USER - assert msg.task.text == "Do the thing" - - -def test_human_intervention_request_defaults_and_reply_variants(): - from agent_framework._workflows._magentic import MagenticHumanInterventionKind - - req = MagenticHumanInterventionRequest(kind=MagenticHumanInterventionKind.PLAN_REVIEW) - assert hasattr(req, "request_id") - assert req.task_text == "" and req.facts_text == "" and req.plan_text == "" - assert isinstance(req.round_index, int) and req.round_index == 0 - - # Replies: approve, revise with comments, revise with edited text - approve = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) - revise_comments = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.REVISE, comments="Tighten scope" - ) - revise_text = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.REVISE, - edited_plan_text="- Step 1\n- Step 2", - ) - - assert approve.decision == MagenticHumanInterventionDecision.APPROVE - assert revise_comments.comments == "Tighten scope" - assert revise_text.edited_plan_text is not None and revise_text.edited_plan_text.startswith("- Step 1") - - def test_magentic_context_reset_behavior(): ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="task"), + task="task", participant_descriptions={"Alice": "Researcher"}, ) # seed context state @@ -105,10 +72,24 @@ class _SimpleLedger: class FakeManager(MagenticManagerBase): """Deterministic manager for tests that avoids real LLM calls.""" - task_ledger: _SimpleLedger | None = None - satisfied_after_signoff: bool = True - next_speaker_name: str = "agentA" - instruction_text: str = "Proceed with step 1" + FINAL_ANSWER: ClassVar[str] = "FINAL" + + def __init__( + self, + *, + max_stall_count: int = 3, + max_reset_count: int | None = None, + max_round_count: int | None = None, + ) -> None: + super().__init__( + max_stall_count=max_stall_count, + max_reset_count=max_reset_count, + max_round_count=max_round_count, + ) + self.name = "magentic_manager" + self.task_ledger: _SimpleLedger | None = None + self.next_speaker_name: str = "agentA" + self.instruction_text: str = "Proceed with step 1" @override def on_checkpoint_save(self) -> dict[str, Any]: @@ -141,47 +122,117 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: facts = ChatMessage(role=Role.ASSISTANT, text="GIVEN OR VERIFIED FACTS\n- A\n") plan = ChatMessage(role=Role.ASSISTANT, text="- Do X\n- Do Y\n") self.task_ledger = _SimpleLedger(facts=facts, plan=plan) - combined = f"Task: {magentic_context.task.text}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}" - return ChatMessage(role=Role.ASSISTANT, text=combined, author_name="magentic_manager") + combined = f"Task: {magentic_context.task}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}" + return ChatMessage(role=Role.ASSISTANT, text=combined, author_name=self.name) async def replan(self, magentic_context: MagenticContext) -> ChatMessage: facts = ChatMessage(role=Role.ASSISTANT, text="GIVEN OR VERIFIED FACTS\n- A2\n") plan = ChatMessage(role=Role.ASSISTANT, text="- Do Z\n") self.task_ledger = _SimpleLedger(facts=facts, plan=plan) - combined = f"Task: {magentic_context.task.text}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}" - return ChatMessage(role=Role.ASSISTANT, text=combined, author_name="magentic_manager") - - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: - is_satisfied = 
self.satisfied_after_signoff and len(magentic_context.chat_history) > 0 - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="test", answer=is_satisfied), - is_in_loop=_MagenticProgressLedgerItem(reason="test", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="test", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="test", answer=self.next_speaker_name), - instruction_or_question=_MagenticProgressLedgerItem(reason="test", answer=self.instruction_text), + combined = f"Task: {magentic_context.task}\n\nFacts:\n{facts.text}\n\nPlan:\n{plan.text}" + return ChatMessage(role=Role.ASSISTANT, text=combined, author_name=self.name) + + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: + # At least two messages in chat history means request is satisfied for testing + is_satisfied = len(magentic_context.chat_history) > 1 + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="test", answer=is_satisfied), + is_in_loop=MagenticProgressLedgerItem(reason="test", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="test", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="test", answer=self.next_speaker_name), + instruction_or_question=MagenticProgressLedgerItem(reason="test", answer=self.instruction_text), ) async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: - return ChatMessage(role=Role.ASSISTANT, text="FINAL", author_name="magentic_manager") + return ChatMessage(role=Role.ASSISTANT, text=self.FINAL_ANSWER, author_name=self.name) + + +class StubAgent(BaseAgent): + def __init__(self, agent_name: str, reply_text: str, **kwargs: Any) -> None: + super().__init__(name=agent_name, description=f"Stub agent {agent_name}", **kwargs) + self._reply_text = reply_text + + async def run( # type: ignore[override] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + response = ChatMessage(role=Role.ASSISTANT, text=self._reply_text, author_name=self.name) + return AgentRunResponse(messages=[response]) + def run_stream( # type: ignore[override] + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + async def _stream() -> AsyncIterable[AgentRunResponseUpdate]: + yield AgentRunResponseUpdate( + contents=[TextContent(text=self._reply_text)], role=Role.ASSISTANT, author_name=self.name + ) -class _CountingWorkflowBuilder(WorkflowBuilder): - created: list["_CountingWorkflowBuilder"] = [] + return _stream() - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.start_calls = 0 - _CountingWorkflowBuilder.created.append(self) - def set_start_executor(self, executor: Any) -> "_CountingWorkflowBuilder": # type: ignore[override] - self.start_calls += 1 - return cast("_CountingWorkflowBuilder", super().set_start_executor(executor)) +class DummyExec(Executor): + def __init__(self, name: str) -> None: + super().__init__(name) + + @handler + async def _noop( + self, message: GroupChatRequestMessage, ctx: WorkflowContext[ChatMessage] + ) -> None: # pragma: no cover - not called + pass + + +async def test_magentic_builder_returns_workflow_and_runs() -> None: + manager = FakeManager() + agent = 
StubAgent(manager.next_speaker_name, "first draft") + + workflow = MagenticBuilder().participants([agent]).with_standard_manager(manager).build() + + assert isinstance(workflow, Workflow) + + outputs: list[ChatMessage] = [] + orchestrator_event_count = 0 + async for event in workflow.run_stream("compose summary"): + if isinstance(event, WorkflowOutputEvent): + msg = event.data + if isinstance(msg, list): + outputs.extend(cast(list[ChatMessage], msg)) + elif isinstance(event, MagenticOrchestratorEvent): + orchestrator_event_count += 1 + + assert outputs, "Expected a final output message" + assert len(outputs) >= 1 + final = outputs[-1] + assert final.text == manager.FINAL_ANSWER + assert final.author_name == manager.name + assert orchestrator_event_count > 0, "Expected orchestrator events to be emitted" + + +async def test_magentic_as_agent_does_not_accept_conversation() -> None: + manager = FakeManager() + writer = StubAgent(manager.next_speaker_name, "summary response") + + workflow = MagenticBuilder().participants([writer]).with_standard_manager(manager).build() + + agent = workflow.as_agent(name="magentic-agent") + conversation = [ + ChatMessage(role=Role.SYSTEM, text="Guidelines", author_name="system"), + ChatMessage(role=Role.USER, text="Summarize the findings", author_name="requester"), + ] + with pytest.raises(ValueError, match="Magentic only support a single task message to start the workflow."): + await agent.run(conversation) async def test_standard_manager_plan_and_replan_combined_ledger(): - manager = FakeManager(max_round_count=10, max_stall_count=3, max_reset_count=2) + manager = FakeManager() ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="demo task"), + task="demo task", participant_descriptions={"agentA": "Agent A"}, ) @@ -193,55 +244,34 @@ async def test_standard_manager_plan_and_replan_combined_ledger(): assert "A2" in replanned.text or "Do Z" in replanned.text -async def test_standard_manager_progress_ledger_and_fallback(): - manager = FakeManager(max_round_count=10) - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="demo"), - participant_descriptions={"agentA": "Agent A"}, - ) - - ledger = await manager.create_progress_ledger(ctx.clone()) - assert isinstance(ledger, _MagenticProgressLedger) - assert ledger.next_speaker.answer == "agentA" - - manager.satisfied_after_signoff = False - ledger2 = await manager.create_progress_ledger(ctx.clone()) - assert ledger2.is_request_satisfied.answer is False - - async def test_magentic_workflow_plan_review_approval_to_completion(): - manager = FakeManager(max_round_count=10) - wf = ( - MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) - .with_standard_manager(manager) - .with_plan_review() - .build() - ) + manager = FakeManager() + wf = MagenticBuilder().participants([DummyExec("agentA")]).with_standard_manager(manager).with_plan_review().build() req_event: RequestInfoEvent | None = None async for ev in wf.run_stream("do work"): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) completed = False output: list[ChatMessage] | None = None - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) - async for ev in wf.send_responses_streaming(responses={req_event.request_id: reply}): + async for ev in 
wf.send_responses_streaming(responses={req_event.request_id: req_event.data.approve()}): if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: completed = True elif isinstance(ev, WorkflowOutputEvent): output = ev.data # type: ignore[assignment] if completed and output is not None: break + assert completed assert output is not None assert isinstance(output, list) assert all(isinstance(msg, ChatMessage) for msg in output) -async def test_magentic_plan_review_approve_with_comments_replans_and_proceeds(): +async def test_magentic_plan_review_with_revise(): class CountingManager(FakeManager): # Declare as a model field so assignment is allowed under Pydantic replan_count: int = 0 @@ -253,10 +283,10 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ self.replan_count += 1 return await super().replan(magentic_context) - manager = CountingManager(max_round_count=10) + manager = CountingManager() wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec(name=manager.next_speaker_name)]) .with_standard_manager(manager) .with_plan_review() .build() @@ -265,30 +295,32 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ # Wait for the initial plan review request req_event: RequestInfoEvent | None = None async for ev in wf.run_stream("do work"): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) - # Reply APPROVE with comments (no edited text). Expect one replan and no second review round. + # Send a revise response saw_second_review = False completed = False async for ev in wf.send_responses_streaming( - responses={ - req_event.request_id: MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.APPROVE, - comments="Looks good; consider Z", - ) - } + responses={req_event.request_id: req_event.data.revise("Looks good; consider Z")} ): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: saw_second_review = True + req_event = ev + + # Approve the second review + async for ev in wf.send_responses_streaming( + responses={req_event.request_id: req_event.data.approve()} # type: ignore[union-attr] + ): if isinstance(ev, WorkflowStatusEvent) and ev.state == WorkflowRunState.IDLE: completed = True break assert completed assert manager.replan_count >= 1 - assert saw_second_review is False + assert saw_second_review is True # Replan from FakeManager updates facts/plan to include A2 / Do Z assert manager.task_ledger is not None combined_text = (manager.task_ledger.facts.text or "") + (manager.task_ledger.plan.text or "") @@ -297,16 +329,16 @@ async def replan(self, magentic_context: MagenticContext) -> ChatMessage: # typ async def test_magentic_orchestrator_round_limit_produces_partial_result(): manager = FakeManager(max_round_count=1) - manager.satisfied_after_signoff = False - wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() - - from agent_framework import WorkflowEvent # type: ignore + wf = ( + MagenticBuilder() + .participants([DummyExec(name=manager.next_speaker_name)]) + .with_standard_manager(manager) + .build() + ) events: list[WorkflowEvent] = 
[] async for ev in wf.run_stream("round limit test"): events.append(ev) - if len(events) > 50: - break idle_status = next( (e for e in events if isinstance(e, WorkflowStatusEvent) and e.state == WorkflowRunState.IDLE), None @@ -317,18 +349,18 @@ async def test_magentic_orchestrator_round_limit_produces_partial_result(): assert output_event is not None data = output_event.data assert isinstance(data, list) - assert all(isinstance(msg, ChatMessage) for msg in data) - assert len(data) > 0 - assert data[-1].role == Role.ASSISTANT + assert len(data) > 0 # type: ignore + assert data[-1].role == Role.ASSISTANT # type: ignore + assert all(isinstance(msg, ChatMessage) for msg in data) # type: ignore async def test_magentic_checkpoint_resume_round_trip(): storage = InMemoryCheckpointStorage() - manager1 = FakeManager(max_round_count=10) + manager1 = FakeManager() wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec(name=manager1.next_speaker_name)]) .with_standard_manager(manager1) .with_plan_review() .with_checkpointing(storage) @@ -338,99 +370,52 @@ async def test_magentic_checkpoint_resume_round_trip(): task_text = "checkpoint task" req_event: RequestInfoEvent | None = None async for ev in wf.run_stream(task_text): - if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticHumanInterventionRequest: + if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) checkpoints = await storage.list_checkpoints() assert checkpoints checkpoints.sort(key=lambda cp: cp.timestamp) resume_checkpoint = checkpoints[-1] - manager2 = FakeManager(max_round_count=10) + manager2 = FakeManager() wf_resume = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec(name=manager2.next_speaker_name)]) .with_standard_manager(manager2) .with_plan_review() .with_checkpointing(storage) .build() ) - orchestrator = next(exec for exec in wf_resume.executors.values() if isinstance(exec, MagenticOrchestratorExecutor)) - - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) completed: WorkflowOutputEvent | None = None req_event = None async for event in wf_resume.run_stream( resume_checkpoint.checkpoint_id, ): - if isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: + if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: req_event = event assert req_event is not None + assert isinstance(req_event.data, MagenticPlanReviewRequest) - responses = {req_event.request_id: reply} + responses = {req_event.request_id: req_event.data.approve()} async for event in wf_resume.send_responses_streaming(responses=responses): if isinstance(event, WorkflowOutputEvent): completed = event assert completed is not None - assert orchestrator._context is not None # type: ignore[reportPrivateUsage] - assert orchestrator._context.chat_history # type: ignore[reportPrivateUsage] + orchestrator = next(exec for exec in wf_resume.executors.values() if isinstance(exec, MagenticOrchestrator)) + assert orchestrator._magentic_context is not None # type: ignore[reportPrivateUsage] + assert orchestrator._magentic_context.chat_history # type: ignore[reportPrivateUsage] assert orchestrator._task_ledger is not None # type: ignore[reportPrivateUsage] assert manager2.task_ledger is not None # Latest entry in chat history should be the 
task ledger plan - assert orchestrator._context.chat_history[-1].text == orchestrator._task_ledger.text # type: ignore[reportPrivateUsage] - - -class _DummyExec(Executor): - def __init__(self, name: str) -> None: - super().__init__(name) - - @handler - async def _noop(self, message: object, ctx: WorkflowContext[object]) -> None: # pragma: no cover - not called - pass - - -def test_magentic_builder_sets_start_executor_once(monkeypatch: pytest.MonkeyPatch) -> None: - """Ensure MagenticBuilder wiring sets the start executor only once.""" - _CountingWorkflowBuilder.created.clear() - monkeypatch.setattr(group_chat_module, "WorkflowBuilder", _CountingWorkflowBuilder) - - manager = FakeManager() - - workflow = ( - MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager=manager).build() - ) + assert orchestrator._magentic_context.chat_history[-1].text == orchestrator._task_ledger.text # type: ignore[reportPrivateUsage] - assert workflow is not None - assert _CountingWorkflowBuilder.created, "Expected CountingWorkflowBuilder to be instantiated" - builder = _CountingWorkflowBuilder.created[-1] - assert builder.start_calls == 1, "set_start_executor should be called exactly once" - -async def test_magentic_agent_executor_on_checkpoint_save_and_restore_roundtrip(): - backing_executor = _DummyExec("backing") - agent_exec = MagenticAgentExecutor(backing_executor, "agentA") - agent_exec._chat_history.extend([ # type: ignore[reportPrivateUsage] - ChatMessage(role=Role.USER, text="hello"), - ChatMessage(role=Role.ASSISTANT, text="world", author_name="agentA"), - ]) - - state = await agent_exec.on_checkpoint_save() - - restored_executor = MagenticAgentExecutor(_DummyExec("backing2"), "agentA") - await restored_executor.on_checkpoint_restore(state) - - assert len(restored_executor._chat_history) == 2 # type: ignore[reportPrivateUsage] - assert restored_executor._chat_history[0].text == "hello" # type: ignore[reportPrivateUsage] - assert restored_executor._chat_history[1].author_name == "agentA" # type: ignore[reportPrivateUsage] - - -from agent_framework import StandardMagenticManager # noqa: E402 - - -class _StubManagerAgent(BaseAgent): +class StubManagerAgent(BaseAgent): """Stub agent for testing StandardMagenticManager.""" async def run( @@ -456,7 +441,7 @@ async def _gen() -> AsyncIterable[AgentRunResponseUpdate]: async def test_standard_manager_plan_and_replan_via_complete_monkeypatch(): - mgr = StandardMagenticManager(agent=_StubManagerAgent()) + mgr = StandardMagenticManager(StubManagerAgent()) async def fake_complete_plan(messages: list[ChatMessage], **kwargs: Any) -> ChatMessage: # Return a different response depending on call order length @@ -467,10 +452,7 @@ async def fake_complete_plan(messages: list[ChatMessage], **kwargs: Any) -> Chat # First, patch to produce facts then plan mgr._complete = fake_complete_plan # type: ignore[attr-defined] - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="T"), - participant_descriptions={"A": "desc"}, - ) + ctx = MagenticContext(task="T", participant_descriptions={"A": "desc"}) combined = await mgr.plan(ctx.clone()) # Assert structural headings and that steps appear in the combined ledger output. 
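The StandardMagenticManager tests that follow keep planning deterministic by swapping the manager's completion call for a stub. The pattern in isolation; `_complete` is a private hook, so this is a test-only sketch rather than public API:

```python
def patch_manager_completion(mgr, canned_replies: list[str]) -> None:
    """Replace the manager's LLM call with canned replies (test-only sketch)."""
    replies = iter(canned_replies)

    async def fake_complete(messages: list, **kwargs) -> ChatMessage:
        # Each call consumes the next canned reply in order.
        return ChatMessage(role=Role.ASSISTANT, text=next(replies))

    mgr._complete = fake_complete  # type: ignore[attr-defined]
```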
assert "We are working to address the following user request:" in combined.text @@ -489,11 +471,8 @@ async def fake_complete_replan(messages: list[ChatMessage], **kwargs: Any) -> Ch async def test_standard_manager_progress_ledger_success_and_error(): - mgr = StandardMagenticManager(agent=_StubManagerAgent()) - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="task"), - participant_descriptions={"alice": "desc"}, - ) + mgr = StandardMagenticManager(agent=StubManagerAgent()) + ctx = MagenticContext(task="task", participant_descriptions={"alice": "desc"}) # Success path: valid JSON async def fake_complete_ok(messages: list[ChatMessage], **kwargs: Any) -> ChatMessage: @@ -530,24 +509,24 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage: async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="re-ledger") - async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: if not self._invoked: # First round: ask agentA to respond self._invoked = True - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="r", answer=False), - is_in_loop=_MagenticProgressLedgerItem(reason="r", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="r", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="r", answer="agentA"), - instruction_or_question=_MagenticProgressLedgerItem(reason="r", answer="say hi"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False), + is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), + instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="say hi"), ) # Next round: mark satisfied so run can conclude - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(reason="r", answer=True), - is_in_loop=_MagenticProgressLedgerItem(reason="r", answer=False), - is_progress_being_made=_MagenticProgressLedgerItem(reason="r", answer=True), - next_speaker=_MagenticProgressLedgerItem(reason="r", answer="agentA"), - instruction_or_question=_MagenticProgressLedgerItem(reason="r", answer="done"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=True), + is_in_loop=MagenticProgressLedgerItem(reason="r", answer=False), + is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=True), + next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"), + instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="done"), ) async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: @@ -555,15 +534,18 @@ async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatM class StubThreadAgent(BaseAgent): + def __init__(self, name: str | None = None) -> None: + super().__init__(name=name or "agentA") + async def run_stream(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] yield AgentRunResponseUpdate( contents=[TextContent(text="thread-ok")], - author_name="agentA", + author_name=self.name, role=Role.ASSISTANT, ) async def run(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] - 
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="thread-ok", author_name="agentA")]) + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="thread-ok", author_name=self.name)]) class StubAssistantsClient: @@ -574,29 +556,26 @@ class StubAssistantsAgent(BaseAgent): chat_client: object | None = None # allow assignment via Pydantic field def __init__(self) -> None: - super().__init__() + super().__init__(name="agentA") self.chat_client = StubAssistantsClient() # type name contains 'AssistantsClient' async def run_stream(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] yield AgentRunResponseUpdate( contents=[TextContent(text="assistants-ok")], - author_name="agentA", + author_name=self.name, role=Role.ASSISTANT, ) async def run(self, messages=None, *, thread=None, **kwargs): # type: ignore[override] - return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="assistants-ok", author_name="agentA")]) + return AgentRunResponse( + messages=[ChatMessage(role=Role.ASSISTANT, text="assistants-ok", author_name=self.name)] + ) -async def _collect_agent_responses_setup(participant_obj: object): +async def _collect_agent_responses_setup(participant: AgentProtocol) -> list[ChatMessage]: captured: list[ChatMessage] = [] - wf = ( - MagenticBuilder() - .participants(agentA=participant_obj) # type: ignore[arg-type] - .with_standard_manager(InvokeOnceManager()) - .build() - ) + wf = MagenticBuilder().participants([participant]).with_standard_manager(InvokeOnceManager()).build() # Run a bounded stream to allow one invoke and then completion events: list[WorkflowEvent] = [] @@ -604,27 +583,27 @@ async def _collect_agent_responses_setup(participant_obj: object): events.append(ev) if isinstance(ev, WorkflowOutputEvent): break - if isinstance(ev, AgentRunUpdateEvent) and ev.data is not None: + if isinstance(ev, AgentRunUpdateEvent): captured.append( ChatMessage( role=ev.data.role or Role.ASSISTANT, text=ev.data.text or "", author_name=ev.data.author_name ) ) - if len(events) > 50: - break return captured async def test_agent_executor_invoke_with_thread_chat_client(): - captured = await _collect_agent_responses_setup(StubThreadAgent()) + agent = StubThreadAgent() + captured = await _collect_agent_responses_setup(agent) # Should have at least one response from agentA via _MagenticAgentExecutor path - assert any((m.author_name == "agentA" and "ok" in (m.text or "")) for m in captured) + assert any((m.author_name == agent.name and "ok" in (m.text or "")) for m in captured) async def test_agent_executor_invoke_with_assistants_client_messages(): - captured = await _collect_agent_responses_setup(StubAssistantsAgent()) - assert any((m.author_name == "agentA" and "ok" in (m.text or "")) for m in captured) + agent = StubAssistantsAgent() + captured = await _collect_agent_responses_setup(agent) + assert any((m.author_name == agent.name and "ok" in (m.text or "")) for m in captured) async def _collect_checkpoints(storage: InMemoryCheckpointStorage) -> list[WorkflowCheckpoint]: @@ -639,7 +618,7 @@ async def test_magentic_checkpoint_resume_inner_loop_superstep(): workflow = ( MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) .with_standard_manager(InvokeOnceManager()) .with_checkpointing(storage) .build() @@ -654,7 +633,7 @@ async def test_magentic_checkpoint_resume_inner_loop_superstep(): resumed = ( MagenticBuilder() - .participants(agentA=StubThreadAgent()) + .participants([StubThreadAgent()]) 
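The checkpoint tests below share one driver pattern: run until output, list what the storage captured, rebuild an identically wired workflow, and pass the checkpoint id to `run_stream`. A condensed, hypothetical helper showing only the calls these tests exercise:

```python
async def resume_latest(storage, build_workflow):
    """Resume an identically wired workflow from the newest checkpoint (sketch)."""
    checkpoints = await storage.list_checkpoints()
    checkpoints.sort(key=lambda cp: cp.timestamp)

    resumed = build_workflow()  # must match the original participants/manager
    events = []
    async for event in resumed.run_stream(checkpoints[-1].checkpoint_id):
        events.append(event)
    return events
```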
 
 
 async def _collect_checkpoints(storage: InMemoryCheckpointStorage) -> list[WorkflowCheckpoint]:
@@ -639,7 +618,7 @@ async def test_magentic_checkpoint_resume_inner_loop_superstep():
     workflow = (
         MagenticBuilder()
-        .participants(agentA=StubThreadAgent())
+        .participants([StubThreadAgent()])
         .with_standard_manager(InvokeOnceManager())
         .with_checkpointing(storage)
         .build()
@@ -654,7 +633,7 @@
     resumed = (
         MagenticBuilder()
-        .participants(agentA=StubThreadAgent())
+        .participants([StubThreadAgent()])
         .with_standard_manager(InvokeOnceManager())
         .with_checkpointing(storage)
         .build()
@@ -668,7 +647,8 @@
 
     assert completed is not None
 
 
-async def test_magentic_checkpoint_resume_after_reset():
+async def test_magentic_checkpoint_resume_from_saved_state():
+    """Test that we can resume workflow execution from a saved checkpoint."""
     storage = InMemoryCheckpointStorage()
 
     # Use the working InvokeOnceManager first to get a completed workflow
@@ -676,27 +656,24 @@
     workflow = (
         MagenticBuilder()
-        .participants(agentA=StubThreadAgent())
+        .participants([StubThreadAgent()])
         .with_standard_manager(manager)
         .with_checkpointing(storage)
         .build()
     )
 
-    async for event in workflow.run_stream("reset task"):
+    async for event in workflow.run_stream("checkpoint resume task"):
         if isinstance(event, WorkflowOutputEvent):
             break
 
     checkpoints = await _collect_checkpoints(storage)
 
-    # For this test, we just need to verify that we can resume from any checkpoint
-    # The original test intention was to test resuming after a reset has occurred
-    # Since we can't easily simulate a reset in the test environment without causing hangs,
-    # we'll test the basic checkpoint resume functionality which is the core requirement
+    # Verify we can resume from the last saved checkpoint
     resumed_state = checkpoints[-1]  # Use the last checkpoint
 
     resumed_workflow = (
         MagenticBuilder()
-        .participants(agentA=StubThreadAgent())
+        .participants([StubThreadAgent()])
         .with_standard_manager(InvokeOnceManager())
         .with_checkpointing(storage)
         .build()
@@ -717,7 +694,7 @@ async def test_magentic_checkpoint_resume_rejects_participant_renames():
     workflow = (
         MagenticBuilder()
-        .participants(agentA=StubThreadAgent())
+        .participants([StubThreadAgent()])
         .with_standard_manager(manager)
         .with_plan_review()
         .with_checkpointing(storage)
@@ -726,17 +703,18 @@
 
     req_event: RequestInfoEvent | None = None
     async for event in workflow.run_stream("task"):
-        if isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest:
+        if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest:
             req_event = event
 
     assert req_event is not None
+    assert isinstance(req_event.data, MagenticPlanReviewRequest)
 
     checkpoints = await _collect_checkpoints(storage)
     target_checkpoint = checkpoints[-1]
 
     renamed_workflow = (
         MagenticBuilder()
-        .participants(agentB=StubThreadAgent())
+        .participants([StubThreadAgent(name="renamedAgent")])
         .with_standard_manager(InvokeOnceManager())
         .with_plan_review()
         .with_checkpointing(storage)
@@ -761,23 +739,23 @@ async def plan(self, magentic_context: MagenticContext) -> ChatMessage:
 
     async def replan(self, magentic_context: MagenticContext) -> ChatMessage:
         return ChatMessage(role=Role.ASSISTANT, text="re-ledger")
 
-    async def create_progress_ledger(self, magentic_context: MagenticContext) -> _MagenticProgressLedger:
-        return _MagenticProgressLedger(
-            is_request_satisfied=_MagenticProgressLedgerItem(reason="r", answer=False),
-            is_in_loop=_MagenticProgressLedgerItem(reason="r", answer=True),
-            is_progress_being_made=_MagenticProgressLedgerItem(reason="r", answer=False),
-            next_speaker=_MagenticProgressLedgerItem(reason="r", answer="agentA"),
-            instruction_or_question=_MagenticProgressLedgerItem(reason="r", answer="done"),
+    async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger:
+        return MagenticProgressLedger(
+            is_request_satisfied=MagenticProgressLedgerItem(reason="r", answer=False),
+            is_in_loop=MagenticProgressLedgerItem(reason="r", answer=True),
+            is_progress_being_made=MagenticProgressLedgerItem(reason="r", answer=False),
+            next_speaker=MagenticProgressLedgerItem(reason="r", answer="agentA"),
+            instruction_or_question=MagenticProgressLedgerItem(reason="r", answer="done"),
         )
 
     async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage:
         return ChatMessage(role=Role.ASSISTANT, text="final")
 
 
-async def test_magentic_stall_and_reset_successfully():
+async def test_magentic_stall_and_reset_reach_limits():
     manager = NotProgressingManager(max_round_count=10, max_stall_count=0, max_reset_count=1)
-    wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build()
+    wf = MagenticBuilder().participants([DummyExec("agentA")]).with_standard_manager(manager).build()
 
     events: list[WorkflowEvent] = []
     async for ev in wf.run_stream("test limits"):
@@ -790,10 +768,10 @@
 
     output_event = next((e for e in events if isinstance(e, WorkflowOutputEvent)), None)
     assert output_event is not None
     assert isinstance(output_event.data, list)
-    assert all(isinstance(msg, ChatMessage) for msg in output_event.data)
-    assert len(output_event.data) > 0
-    assert output_event.data[-1].text is not None
-    assert output_event.data[-1].text == "re-ledger"
+    assert all(isinstance(msg, ChatMessage) for msg in output_event.data)  # type: ignore
+    assert len(output_event.data) > 0  # type: ignore
+    assert output_event.data[-1].text is not None  # type: ignore
+    assert output_event.data[-1].text == "Workflow terminated due to reaching maximum reset count."  # type: ignore
 
 
 async def test_magentic_checkpoint_runtime_only() -> None:
@@ -801,8 +779,7 @@
     storage = InMemoryCheckpointStorage()
     manager = FakeManager(max_round_count=10)
-    manager.satisfied_after_signoff = True
-    wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build()
+    wf = MagenticBuilder().participants([DummyExec("agentA")]).with_standard_manager(manager).build()
 
     baseline_output: ChatMessage | None = None
     async for ev in wf.run_stream("runtime checkpoint test", checkpoint_storage=storage):
@@ -831,10 +808,9 @@ async def test_magentic_checkpoint_runtime_overrides_buildtime() -> None:
     runtime_storage = FileCheckpointStorage(temp_dir2)
 
     manager = FakeManager(max_round_count=10)
-    manager.satisfied_after_signoff = True
     wf = (
         MagenticBuilder()
-        .participants(agentA=_DummyExec("agentA"))
+        .participants([DummyExec("agentA")])
         .with_standard_manager(manager)
         .with_checkpointing(buildtime_storage)
         .build()
@@ -859,127 +835,12 @@
 
     assert len(buildtime_checkpoints) == 0, "Build-time storage should have no checkpoints when overridden"
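The two storage tests above separate build-time wiring from a per-run override: `.with_checkpointing(...)` configures a default, while a `checkpoint_storage=` argument to `run_stream` wins for that run. Both spellings side by side, reusing this file's test doubles in a rough sketch:

```python
async def checkpoints_from_runtime_override():
    buildtime = InMemoryCheckpointStorage()
    runtime = InMemoryCheckpointStorage()
    wf = (
        MagenticBuilder()
        .participants([DummyExec("agentA")])
        .with_standard_manager(FakeManager())
        .with_checkpointing(buildtime)  # default storage
        .build()
    )
    # The runtime argument overrides the build-time storage for this run only.
    async for ev in wf.run_stream("task", checkpoint_storage=runtime):
        if isinstance(ev, WorkflowOutputEvent):
            break
    return await runtime.list_checkpoints()  # buildtime stays empty
```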
- """ - builder = MagenticBuilder() - - # MagenticBuilder should NOT have the generic human input hook mixin - assert not hasattr(builder, "with_human_input_hook"), ( - "MagenticBuilder should not have with_human_input_hook - " - "use with_plan_review() or with_human_input_on_stall() instead" - ) - - # region Message Deduplication Tests -async def test_magentic_no_duplicate_messages_with_conversation_history(): - """Test that passing list[ChatMessage] does not create duplicate messages in chat_history. - - When a frontend passes conversation history as list[ChatMessage], the last message - (task) should not be duplicated in the orchestrator's chat_history. - """ - manager = FakeManager(max_round_count=10) - manager.satisfied_after_signoff = True # Complete immediately after first agent response - - wf = MagenticBuilder().participants(agentA=_DummyExec("agentA")).with_standard_manager(manager).build() - - # Simulate frontend passing conversation history - conversation: list[ChatMessage] = [ - ChatMessage(role=Role.USER, text="previous question"), - ChatMessage(role=Role.ASSISTANT, text="previous answer"), - ChatMessage(role=Role.USER, text="current task"), - ] - - # Get orchestrator to inspect chat_history after run - orchestrator = None - for executor in wf.executors.values(): - if isinstance(executor, MagenticOrchestratorExecutor): - orchestrator = executor - break - - events: list[WorkflowEvent] = [] - async for event in wf.run_stream(conversation): - events.append(event) - if isinstance(event, WorkflowStatusEvent) and event.state in ( - WorkflowRunState.IDLE, - WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, - ): - break - - assert orchestrator is not None - assert orchestrator._context is not None # type: ignore[reportPrivateUsage] - - # Count occurrences of each message text in chat_history - history = orchestrator._context.chat_history # type: ignore[reportPrivateUsage] - user_task_count = sum(1 for msg in history if msg.text == "current task") - prev_question_count = sum(1 for msg in history if msg.text == "previous question") - prev_answer_count = sum(1 for msg in history if msg.text == "previous answer") - - # Each input message should appear exactly once (no duplicates) - assert prev_question_count == 1, f"Expected 1 'previous question', got {prev_question_count}" - assert prev_answer_count == 1, f"Expected 1 'previous answer', got {prev_answer_count}" - assert user_task_count == 1, f"Expected 1 'current task', got {user_task_count}" - - -async def test_magentic_agent_executor_no_duplicate_messages_on_broadcast(): - """Test that MagenticAgentExecutor does not duplicate messages from broadcasts. - - When the orchestrator broadcasts the task ledger to all agents, each agent - should receive it exactly once, not multiple times. 
- """ - backing_executor = _DummyExec("backing") - agent_exec = MagenticAgentExecutor(backing_executor, "agentA") - - # Simulate orchestrator sending a broadcast message - broadcast_msg = ChatMessage( - role=Role.ASSISTANT, - text="Task ledger content", - author_name="magentic_manager", - ) - - # Simulate the same message being received multiple times (e.g., from checkpoint restore + live) - from agent_framework._workflows._magentic import _MagenticResponseMessage - - response1 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) - response2 = _MagenticResponseMessage(body=broadcast_msg, broadcast=True) - - # Create a mock context - from unittest.mock import AsyncMock, MagicMock - - mock_context = MagicMock() - mock_context.send_message = AsyncMock() - - # Call the handler twice with the same message - await agent_exec.handle_response_message(response1, mock_context) # type: ignore[arg-type] - await agent_exec.handle_response_message(response2, mock_context) # type: ignore[arg-type] - - # Count how many times the broadcast message appears - history = agent_exec._chat_history # type: ignore[reportPrivateUsage] - broadcast_count = sum(1 for msg in history if msg.text == "Task ledger content") - - # Each broadcast should be recorded (this is expected behavior - broadcasts are additive) - # The test documents current behavior. If dedup is needed, this assertion would change. - assert broadcast_count == 2, ( - f"Expected 2 broadcasts (current behavior is additive), got {broadcast_count}. " - "If deduplication is required, update the handler logic." - ) - - async def test_magentic_context_no_duplicate_on_reset(): """Test that MagenticContext.reset() clears chat_history without leaving duplicates.""" - ctx = MagenticContext( - task=ChatMessage(role=Role.USER, text="task"), - participant_descriptions={"Alice": "Researcher"}, - ) + ctx = MagenticContext(task="task", participant_descriptions={"Alice": "Researcher"}) # Add some history ctx.chat_history.append(ChatMessage(role=Role.ASSISTANT, text="response1")) @@ -997,24 +858,6 @@ async def test_magentic_context_no_duplicate_on_reset(): assert len(ctx.chat_history) == 1, "Should have exactly 1 message after adding to reset context" -async def test_magentic_start_message_messages_list_integrity(): - """Test that _MagenticStartMessage preserves message list without internal duplication.""" - conversation: list[ChatMessage] = [ - ChatMessage(role=Role.USER, text="msg1"), - ChatMessage(role=Role.ASSISTANT, text="msg2"), - ChatMessage(role=Role.USER, text="msg3"), - ] - - start_msg = _MagenticStartMessage(conversation) - - # Verify messages list is preserved - assert len(start_msg.messages) == 3, f"Expected 3 messages, got {len(start_msg.messages)}" - - # Verify task is the last message (not a copy) - assert start_msg.task is start_msg.messages[-1], "task should be the same object as messages[-1]" - assert start_msg.task.text == "msg3" - - async def test_magentic_checkpoint_restore_no_duplicate_history(): """Test that checkpoint restore does not create duplicate messages in chat_history.""" manager = FakeManager(max_round_count=10) @@ -1022,7 +865,7 @@ async def test_magentic_checkpoint_restore_no_duplicate_history(): wf = ( MagenticBuilder() - .participants(agentA=_DummyExec("agentA")) + .participants([DummyExec("agentA")]) .with_standard_manager(manager) .with_checkpointing(storage) .build() @@ -1030,7 +873,6 @@ async def test_magentic_checkpoint_restore_no_duplicate_history(): # Run with conversation history to create initial checkpoint 
 
 
-async def test_magentic_start_message_messages_list_integrity():
-    """Test that _MagenticStartMessage preserves message list without internal duplication."""
-    conversation: list[ChatMessage] = [
-        ChatMessage(role=Role.USER, text="msg1"),
-        ChatMessage(role=Role.ASSISTANT, text="msg2"),
-        ChatMessage(role=Role.USER, text="msg3"),
-    ]
-
-    start_msg = _MagenticStartMessage(conversation)
-
-    # Verify messages list is preserved
-    assert len(start_msg.messages) == 3, f"Expected 3 messages, got {len(start_msg.messages)}"
-
-    # Verify task is the last message (not a copy)
-    assert start_msg.task is start_msg.messages[-1], "task should be the same object as messages[-1]"
-    assert start_msg.task.text == "msg3"
-
-
 async def test_magentic_checkpoint_restore_no_duplicate_history():
     """Test that checkpoint restore does not create duplicate messages in chat_history."""
     manager = FakeManager(max_round_count=10)
@@ -1022,7 +865,7 @@
 
     wf = (
         MagenticBuilder()
-        .participants(agentA=_DummyExec("agentA"))
+        .participants([DummyExec("agentA")])
         .with_standard_manager(manager)
         .with_checkpointing(storage)
         .build()
@@ -1030,7 +873,6 @@
 
     # Run with conversation history to create initial checkpoint
     conversation: list[ChatMessage] = [
-        ChatMessage(role=Role.USER, text="history_msg"),
         ChatMessage(role=Role.USER, text="task_msg"),
     ]
@@ -1054,18 +896,18 @@
 
     # Check the magentic_context in the checkpoint
     for _, executor_state in checkpoint_data.metadata.items():
         if isinstance(executor_state, dict) and "magentic_context" in executor_state:
-            ctx_data = executor_state["magentic_context"]
-            chat_history = ctx_data.get("chat_history", [])
+            ctx_data: dict[str, Any] = executor_state["magentic_context"]  # type: ignore
+            chat_history = ctx_data.get("chat_history", [])  # type: ignore
 
             # Count unique messages by text
-            texts = [
-                msg.get("text") or (msg.get("contents", [{}])[0].get("text") if msg.get("contents") else None)
-                for msg in chat_history
+            texts = [  # type: ignore
+                msg.get("text") or (msg.get("contents", [{}])[0].get("text") if msg.get("contents") else None)  # type: ignore
+                for msg in chat_history  # type: ignore
             ]
             text_counts: dict[str, int] = {}
-            for text in texts:
+            for text in texts:  # type: ignore
                 if text:
-                    text_counts[text] = text_counts.get(text, 0) + 1
+                    text_counts[text] = text_counts.get(text, 0) + 1  # type: ignore
 
             # Input messages should not be duplicated
             assert text_counts.get("history_msg", 0) <= 1, (
diff --git a/python/packages/core/tests/workflow/test_orchestration_request_info.py b/python/packages/core/tests/workflow/test_orchestration_request_info.py
index e5f4d7a11f..a47a666ddf 100644
--- a/python/packages/core/tests/workflow/test_orchestration_request_info.py
+++ b/python/packages/core/tests/workflow/test_orchestration_request_info.py
@@ -1,59 +1,51 @@
 # Copyright (c) Microsoft. All rights reserved.
 
-"""Unit tests for request info support in high-level builders."""
+"""Unit tests for orchestration request info support."""
 
+from collections.abc import AsyncIterable
 from typing import Any
-from unittest.mock import MagicMock
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
 
 from agent_framework import (
-    AgentInputRequest,
     AgentProtocol,
-    AgentResponseReviewRequest,
+    AgentRunResponse,
+    AgentRunResponseUpdate,
+    AgentThread,
     ChatMessage,
-    RequestInfoInterceptor,
     Role,
 )
-from agent_framework._workflows._executor import Executor, handler
-from agent_framework._workflows._orchestration_request_info import resolve_request_info_filter
+from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
+from agent_framework._workflows._orchestration_request_info import (
+    AgentApprovalExecutor,
+    AgentRequestInfoExecutor,
+    AgentRequestInfoResponse,
+    resolve_request_info_filter,
+)
 from agent_framework._workflows._workflow_context import WorkflowContext
 
 
-class DummyExecutor(Executor):
-    """Dummy executor with a handler for testing."""
-
-    @handler
-    async def handle(self, data: str, ctx: WorkflowContext[Any, Any]) -> None:
-        pass
-
-
 class TestResolveRequestInfoFilter:
     """Tests for resolve_request_info_filter function."""
 
-    def test_returns_none_for_none_input(self):
-        """Test that None input returns None (no filtering)."""
+    def test_returns_empty_set_for_none_input(self):
+        """Test that None input returns empty set (no filtering)."""
         result = resolve_request_info_filter(None)
-        assert result is None
+        assert result == set()
 
-    def test_returns_none_for_empty_list(self):
-        """Test that empty list returns None."""
+    def test_returns_empty_set_for_empty_list(self):
+        """Test that empty list returns empty set."""
         result = resolve_request_info_filter([])
-        assert result is None
+        assert result == set()
 
     def test_resolves_string_names(self):
         """Test resolving string agent names."""
         result = resolve_request_info_filter(["agent1", "agent2"])
         assert result == {"agent1", "agent2"}
 
-    def test_resolves_executor_ids(self):
-        """Test resolving Executor instances by ID."""
-        exec1 = DummyExecutor(id="executor1")
-        exec2 = DummyExecutor(id="executor2")
-
-        result = resolve_request_info_filter([exec1, exec2])
-        assert result == {"executor1", "executor2"}
-
-    def test_resolves_agent_names(self):
-        """Test resolving AgentProtocol-like objects by name attribute."""
+    def test_resolves_agent_display_names(self):
+        """Test resolving AgentProtocol instances by name attribute."""
         agent1 = MagicMock(spec=AgentProtocol)
         agent1.name = "writer"
         agent2 = MagicMock(spec=AgentProtocol)
@@ -63,106 +55,205 @@ def test_resolves_agent_names(self):
         assert result == {"writer", "reviewer"}
 
     def test_mixed_types(self):
-        """Test resolving a mix of strings, agents, and executors."""
+        """Test resolving a mix of strings and agents."""
         agent = MagicMock(spec=AgentProtocol)
         agent.name = "writer"
-        executor = DummyExecutor(id="custom_exec")
 
-        result = resolve_request_info_filter(["manual_name", agent, executor])
-        assert result == {"manual_name", "writer", "custom_exec"}
+        result = resolve_request_info_filter(["manual_name", agent])
+        assert result == {"manual_name", "writer"}
+
+    def test_raises_on_unsupported_type(self):
+        """Test that unsupported types raise TypeError."""
+        with pytest.raises(TypeError, match="Unsupported type for request_info filter"):
+            resolve_request_info_filter([123])  # type: ignore
+
+
+class TestAgentRequestInfoResponse:
+    """Tests for AgentRequestInfoResponse dataclass."""
+
+    def test_create_response_with_messages(self):
+        """Test creating an AgentRequestInfoResponse with messages."""
+        messages = [ChatMessage(role=Role.USER, text="Additional info")]
+        response = AgentRequestInfoResponse(messages=messages)
+
+        assert response.messages == messages
+
+    def test_from_messages_factory(self):
+        """Test creating response from ChatMessage list."""
+        messages = [
+            ChatMessage(role=Role.USER, text="Message 1"),
+            ChatMessage(role=Role.USER, text="Message 2"),
+        ]
+        response = AgentRequestInfoResponse.from_messages(messages)
+
+        assert response.messages == messages
 
-    def test_skips_agent_without_name(self):
-        """Test that agents without names are skipped."""
-        agent_with_name = MagicMock(spec=AgentProtocol)
-        agent_with_name.name = "valid"
-        agent_without_name = MagicMock(spec=AgentProtocol)
-        agent_without_name.name = None
+    def test_from_strings_factory(self):
+        """Test creating response from string list."""
+        texts = ["First message", "Second message"]
+        response = AgentRequestInfoResponse.from_strings(texts)
 
-        result = resolve_request_info_filter([agent_with_name, agent_without_name])
-        assert result == {"valid"}
+        assert len(response.messages) == 2
+        assert response.messages[0].role == Role.USER
+        assert response.messages[0].text == "First message"
+        assert response.messages[1].role == Role.USER
+        assert response.messages[1].text == "Second message"
 
+    def test_approve_factory(self):
+        """Test creating an approval response (empty messages)."""
+        response = AgentRequestInfoResponse.approve()
 
-class TestAgentInputRequest:
-    """Tests for AgentInputRequest dataclass (formerly AgentResponseReviewRequest)."""
+        assert response.messages == []
 
-    def test_create_request(self):
-        """Test creating an AgentInputRequest with all fields."""
-        conversation = [ChatMessage(role=Role.USER, text="Hello")]
-        request = AgentInputRequest(
-            target_agent_id="test_agent",
-            conversation=conversation,
-            instruction="Review this",
-            metadata={"key": "value"},
+
+class TestAgentRequestInfoExecutor:
+    """Tests for AgentRequestInfoExecutor."""
+
+    @pytest.mark.asyncio
+    async def test_request_info_handler(self):
+        """Test that request_info handler calls ctx.request_info."""
+        executor = AgentRequestInfoExecutor(id="test_executor")
+
+        agent_run_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Agent response")])
+        agent_response = AgentExecutorResponse(
+            executor_id="test_agent",
+            agent_run_response=agent_run_response,
+        )
+
+        ctx = MagicMock(spec=WorkflowContext)
+        ctx.request_info = AsyncMock()
+
+        await executor.request_info(agent_response, ctx)
+
+        ctx.request_info.assert_called_once_with(agent_response, AgentRequestInfoResponse)
+
+    @pytest.mark.asyncio
+    async def test_handle_request_info_response_with_messages(self):
+        """Test response handler when user provides additional messages."""
+        executor = AgentRequestInfoExecutor(id="test_executor")
+
+        agent_run_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Original")])
+        original_request = AgentExecutorResponse(
+            executor_id="test_agent",
+            agent_run_response=agent_run_response,
        )
 
-        assert request.target_agent_id == "test_agent"
-        assert request.conversation == conversation
-        assert request.instruction == "Review this"
-        assert request.metadata == {"key": "value"}
-
-    def test_create_request_defaults(self):
-        """Test creating an AgentInputRequest with default values."""
-        request = AgentInputRequest(target_agent_id="test_agent")
-
-        assert request.target_agent_id == "test_agent"
-        assert request.conversation == []
-        assert request.instruction is None
-        assert request.metadata == {}
-
-    def test_backward_compatibility_alias(self):
-        """Test that AgentResponseReviewRequest is an alias for AgentInputRequest."""
-        assert AgentResponseReviewRequest is AgentInputRequest
-
-
-class TestRequestInfoInterceptor:
-    """Tests for RequestInfoInterceptor executor."""
-
-    def test_interceptor_creation_generates_unique_id(self):
-        """Test creating a RequestInfoInterceptor generates unique IDs."""
-        interceptor1 = RequestInfoInterceptor()
-        interceptor2 = RequestInfoInterceptor()
-        assert interceptor1.id.startswith("request_info_interceptor-")
-        assert interceptor2.id.startswith("request_info_interceptor-")
-        assert interceptor1.id != interceptor2.id
-
-    def test_interceptor_with_custom_id(self):
-        """Test creating a RequestInfoInterceptor with custom ID."""
-        interceptor = RequestInfoInterceptor(executor_id="custom_review")
-        assert interceptor.id == "custom_review"
-
-    def test_interceptor_with_agent_filter(self):
-        """Test creating a RequestInfoInterceptor with agent filter."""
-        agent_filter = {"agent1", "agent2"}
-        interceptor = RequestInfoInterceptor(
-            executor_id="filtered_review",
-            agent_filter=agent_filter,
+        response = AgentRequestInfoResponse.from_strings(["Additional input"])
+
+        ctx = MagicMock(spec=WorkflowContext)
+        ctx.send_message = AsyncMock()
+
+        await executor.handle_request_info_response(original_request, response, ctx)
+
+        # Should send new request with additional messages
+        ctx.send_message.assert_called_once()
+        call_args = ctx.send_message.call_args[0][0]
+        assert isinstance(call_args, AgentExecutorRequest)
+        assert call_args.should_respond is True
+        assert len(call_args.messages) == 1
+        assert call_args.messages[0].text == "Additional
input" + + @pytest.mark.asyncio + async def test_handle_request_info_response_approval(self): + """Test response handler when user approves (no additional messages).""" + executor = AgentRequestInfoExecutor(id="test_executor") + + agent_run_response = AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Original")]) + original_request = AgentExecutorResponse( + executor_id="test_agent", + agent_run_response=agent_run_response, ) - assert interceptor.id == "filtered_review" - assert interceptor._agent_filter == agent_filter - - def test_should_pause_for_agent_no_filter(self): - """Test that interceptor pauses for all agents when no filter is set.""" - interceptor = RequestInfoInterceptor() - assert interceptor._should_pause_for_agent("any_agent") is True - assert interceptor._should_pause_for_agent("another_agent") is True - assert interceptor._should_pause_for_agent(None) is True - - def test_should_pause_for_agent_with_filter(self): - """Test that interceptor only pauses for agents in the filter.""" - agent_filter = {"writer", "reviewer"} - interceptor = RequestInfoInterceptor(agent_filter=agent_filter) - - assert interceptor._should_pause_for_agent("writer") is True - assert interceptor._should_pause_for_agent("reviewer") is True - assert interceptor._should_pause_for_agent("drafter") is False - assert interceptor._should_pause_for_agent(None) is False - - def test_should_pause_for_agent_with_prefixed_id(self): - """Test that filter matches agent names in prefixed executor IDs.""" - agent_filter = {"writer"} - interceptor = RequestInfoInterceptor(agent_filter=agent_filter) - - # Should match the name portion after the colon - assert interceptor._should_pause_for_agent("groupchat_agent:writer") is True - assert interceptor._should_pause_for_agent("request_info:writer") is True - assert interceptor._should_pause_for_agent("groupchat_agent:editor") is False + + response = AgentRequestInfoResponse.approve() + + ctx = MagicMock(spec=WorkflowContext) + ctx.yield_output = AsyncMock() + + await executor.handle_request_info_response(original_request, response, ctx) + + # Should yield original response without modification + ctx.yield_output.assert_called_once_with(original_request) + + +class _TestAgent: + """Simple test agent implementation.""" + + def __init__(self, id: str, name: str | None = None, description: str | None = None): + self._id = id + self._name = name + self._description = description + + @property + def id(self) -> str: + return self._id + + @property + def name(self) -> str | None: + return self._name + + @property + def display_name(self) -> str: + return self._name or self._id + + @property + def description(self) -> str | None: + return self._description + + async def run( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AgentRunResponse: + """Dummy run method.""" + return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="Test response")]) + + def run_stream( + self, + messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, + *, + thread: AgentThread | None = None, + **kwargs: Any, + ) -> AsyncIterable[AgentRunResponseUpdate]: + """Dummy run_stream method.""" + + async def generator(): + yield AgentRunResponseUpdate(messages=[ChatMessage(role=Role.ASSISTANT, text="Test response stream")]) + + return generator() + + def get_new_thread(self, **kwargs: Any) -> AgentThread: + """Creates a new conversation thread for the agent.""" 
+ return AgentThread(**kwargs) + + +class TestAgentApprovalExecutor: + """Tests for AgentApprovalExecutor.""" + + def test_initialization(self): + """Test that AgentApprovalExecutor initializes correctly.""" + agent = _TestAgent(id="test_id", name="test_agent", description="Test agent description") + + executor = AgentApprovalExecutor(agent) + + assert executor.id == "test_agent" + assert executor.description == "Test agent description" + + def test_builds_workflow_with_agent_and_request_info_executors(self): + """Test that the internal workflow is created successfully.""" + agent = _TestAgent(id="test_id", name="test_agent", description="Test description") + + executor = AgentApprovalExecutor(agent) + + # Verify the executor has a workflow + assert executor.workflow is not None + assert executor.id == "test_agent" + + def test_propagate_request_enabled(self): + """Test that AgentApprovalExecutor has propagate_request enabled.""" + agent = _TestAgent(id="test_id", name="test_agent", description="Test description") + + executor = AgentApprovalExecutor(agent) + + assert executor._propagate_request is True # type: ignore diff --git a/python/packages/core/tests/workflow/test_sequential.py b/python/packages/core/tests/workflow/test_sequential.py index 15d6e0c822..4a3076188c 100644 --- a/python/packages/core/tests/workflow/test_sequential.py +++ b/python/packages/core/tests/workflow/test_sequential.py @@ -6,6 +6,7 @@ import pytest from agent_framework import ( + AgentExecutorResponse, AgentRunResponse, AgentRunResponseUpdate, AgentThread, @@ -52,7 +53,8 @@ class _SummarizerExec(Executor): """Custom executor that summarizes by appending a short assistant message.""" @handler - async def summarize(self, conversation: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None: + async def summarize(self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None: + conversation = agent_response.full_conversation or [] user_texts = [m.text for m in conversation if m.role == Role.USER] agents = [m.author_name or m.role for m in conversation if m.role == Role.ASSISTANT] summary = ChatMessage(role=Role.ASSISTANT, text=f"Summary of users:{len(user_texts)} agents:{len(agents)}") diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index d0140dc893..27638f8fe1 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -3,6 +3,8 @@ from collections.abc import AsyncIterable from typing import Annotated, Any +import pytest + from agent_framework import ( AgentRunResponse, AgentRunResponseUpdate, @@ -11,7 +13,7 @@ ChatMessage, ConcurrentBuilder, GroupChatBuilder, - GroupChatStateSnapshot, + GroupChatState, HandoffBuilder, Role, SequentialBuilder, @@ -26,11 +28,6 @@ _received_kwargs: list[dict[str, Any]] = [] -def _reset_received_kwargs() -> None: - """Reset the kwargs tracker before each test.""" - _received_kwargs.clear() - - @ai_function def tool_with_kwargs( action: Annotated[str, "The action to perform"], @@ -73,28 +70,6 @@ async def run_stream( yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.name} response")]) -class _EchoAgent(BaseAgent): - """Simple agent that echoes back for workflow completion.""" - - async def run( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AgentRunResponse: - 
return AgentRunResponse(messages=[ChatMessage(role=Role.ASSISTANT, text=f"{self.name} reply")]) - - async def run_stream( - self, - messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, - *, - thread: AgentThread | None = None, - **kwargs: Any, - ) -> AsyncIterable[AgentRunResponseUpdate]: - yield AgentRunResponseUpdate(contents=[TextContent(text=f"{self.name} reply")]) - - # region Sequential Builder Tests @@ -200,17 +175,21 @@ async def test_groupchat_kwargs_flow_to_agents() -> None: - # Simple selector that takes GroupChatStateSnapshot + # Simple selector that takes GroupChatState turn_count = 0 - def simple_selector(state: GroupChatStateSnapshot) -> str | None: + def simple_selector(state: GroupChatState) -> str: nonlocal turn_count turn_count += 1 - if turn_count > 2: # Stop after 2 turns - return None + if turn_count > 2: # Loop after two turns for test + turn_count = 0 - # state is a Mapping - access via dict syntax - names = list(state["participants"].keys()) + # state.participants maps names to participants; rotate through them + names = list(state.participants.keys()) return names[(turn_count - 1) % len(names)] workflow = ( - GroupChatBuilder().participants(chat1=agent1, chat2=agent2).set_select_speakers_func(simple_selector).build() + GroupChatBuilder() + .participants([agent1, agent2]) + .with_select_speaker_func(simple_selector) + .with_max_rounds(2) # Limit rounds to prevent infinite loop + .build() ) custom_data = {"session_id": "group123"} @@ -359,6 +338,7 @@ async def test_kwargs_preserved_across_workflow_reruns() -> None: # region Handoff Builder Tests +@pytest.mark.xfail(reason="Handoff workflow does not yet propagate kwargs to agents") async def test_handoff_kwargs_flow_to_agents() -> None: """Test that kwargs flow to agents in a handoff workflow.""" agent1 = _KwargsCapturingAgent(name="coordinator") @@ -367,8 +347,9 @@ workflow = ( HandoffBuilder() .participants([agent1, agent2]) - .set_coordinator(agent1) - .with_interaction_mode("autonomous") + .with_start_agent(agent1) + .with_autonomous_mode() + .with_termination_condition(lambda conv: len(conv) >= 4) .build() ) @@ -395,8 +376,8 @@ async def test_magentic_kwargs_flow_to_agents() -> None: """Test that kwargs flow to agents in a magentic workflow.""" from agent_framework._workflows._magentic import ( MagenticContext, MagenticManagerBase, - _MagenticProgressLedger, - _MagenticProgressLedgerItem, + MagenticProgressLedger, + MagenticProgressLedgerItem, ) # Create a mock manager that completes after one round @@ -405,29 +386,29 @@ def __init__(self) -> None: super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=2) self.task_ledger = None - async def plan(self, context: MagenticContext) -> ChatMessage: + async def plan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Plan: Test task", author_name="manager") - async def replan(self, context: MagenticContext) -> ChatMessage: + async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Replan: Test task", author_name="manager") - async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: # Return completed on first call - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), - is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), - is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), -
instruction_or_question=_MagenticProgressLedgerItem(answer="Complete", reason="Done"), - next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=MagenticProgressLedgerItem(answer="Complete", reason="Done"), + next_speaker=MagenticProgressLedgerItem(answer="agent1", reason="First"), ) - async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Final answer", author_name="manager") agent = _KwargsCapturingAgent(name="agent1") manager = _MockManager() - workflow = MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + workflow = MagenticBuilder().participants([agent]).with_standard_manager(manager=manager).build() custom_data = {"session_id": "magentic123"} @@ -446,8 +427,8 @@ async def test_magentic_kwargs_stored_in_shared_state() -> None: from agent_framework._workflows._magentic import ( MagenticContext, MagenticManagerBase, - _MagenticProgressLedger, - _MagenticProgressLedgerItem, + MagenticProgressLedger, + MagenticProgressLedgerItem, ) class _MockManager(MagenticManagerBase): @@ -455,28 +436,28 @@ def __init__(self) -> None: super().__init__(max_stall_count=3, max_reset_count=None, max_round_count=1) self.task_ledger = None - async def plan(self, context: MagenticContext) -> ChatMessage: + async def plan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Plan", author_name="manager") - async def replan(self, context: MagenticContext) -> ChatMessage: + async def replan(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Replan", author_name="manager") - async def create_progress_ledger(self, context: MagenticContext) -> _MagenticProgressLedger: - return _MagenticProgressLedger( - is_request_satisfied=_MagenticProgressLedgerItem(answer=True, reason="Done"), - is_progress_being_made=_MagenticProgressLedgerItem(answer=True, reason="Progress"), - is_in_loop=_MagenticProgressLedgerItem(answer=False, reason="Not looping"), - instruction_or_question=_MagenticProgressLedgerItem(answer="Done", reason="Done"), - next_speaker=_MagenticProgressLedgerItem(answer="agent1", reason="First"), + async def create_progress_ledger(self, magentic_context: MagenticContext) -> MagenticProgressLedger: + return MagenticProgressLedger( + is_request_satisfied=MagenticProgressLedgerItem(answer=True, reason="Done"), + is_progress_being_made=MagenticProgressLedgerItem(answer=True, reason="Progress"), + is_in_loop=MagenticProgressLedgerItem(answer=False, reason="Not looping"), + instruction_or_question=MagenticProgressLedgerItem(answer="Done", reason="Done"), + next_speaker=MagenticProgressLedgerItem(answer="agent1", reason="First"), ) - async def prepare_final_answer(self, context: MagenticContext) -> ChatMessage: + async def prepare_final_answer(self, magentic_context: MagenticContext) -> ChatMessage: return ChatMessage(role=Role.ASSISTANT, text="Final", author_name="manager") agent = _KwargsCapturingAgent(name="agent1") manager = _MockManager() - magentic_workflow = 
MagenticBuilder().participants(agent1=agent).with_standard_manager(manager=manager).build() + magentic_workflow = MagenticBuilder().participants([agent]).with_standard_manager(manager=manager).build() # Use MagenticWorkflow.run_stream() which goes through the kwargs attachment path custom_data = {"magentic_key": "magentic_value"} diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index 33178389c8..8ca5e0f4bc 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -115,18 +115,14 @@ For additional observability samples in Agent Framework, see the [observability | Concurrent Orchestration (Custom Aggregator) | [orchestration/concurrent_custom_aggregator.py](./orchestration/concurrent_custom_aggregator.py) | Override aggregator via callback; summarize results with an LLM | | Concurrent Orchestration (Custom Agent Executors) | [orchestration/concurrent_custom_agent_executors.py](./orchestration/concurrent_custom_agent_executors.py) | Child executors own ChatAgents; concurrent fan-out/fan-in via ConcurrentBuilder | | Concurrent Orchestration (Participant Factory) | [orchestration/concurrent_participant_factory.py](./orchestration/concurrent_participant_factory.py) | Use participant factories for state isolation between workflow instances | -| Group Chat with Agent Manager | [orchestration/group_chat_agent_manager.py](./orchestration/group_chat_agent_manager.py) | Agent-based manager using `set_manager()` to select next speaker | +| Group Chat with Agent Manager | [orchestration/group_chat_agent_manager.py](./orchestration/group_chat_agent_manager.py) | Agent-based manager using `with_agent_orchestrator()` to select next speaker | | Group Chat Philosophical Debate | [orchestration/group_chat_philosophical_debate.py](./orchestration/group_chat_philosophical_debate.py) | Agent manager moderates long-form, multi-round debate across diverse participants | | Group Chat with Simple Function Selector | [orchestration/group_chat_simple_selector.py](./orchestration/group_chat_simple_selector.py) | Group chat with a simple function selector for next speaker | | Handoff (Simple) | [orchestration/handoff_simple.py](./orchestration/handoff_simple.py) | Single-tier routing: triage agent routes to specialists, control returns to user after each specialist response | -| Handoff (Specialist-to-Specialist) | [orchestration/handoff_specialist_to_specialist.py](./orchestration/handoff_specialist_to_specialist.py) | Multi-tier routing: specialists can hand off to other specialists using `.add_handoff()` fluent API | -| Handoff (Return-to-Previous) | [orchestration/handoff_return_to_previous.py](./orchestration/handoff_return_to_previous.py) | Return-to-previous routing: after user input, routes back to the previous specialist instead of coordinator using `.enable_return_to_previous()` | -| Handoff (Autonomous) | [orchestration/handoff_autonomous.py](./orchestration/handoff_autonomous.py) | Autonomous mode: specialists iterate independently until invoking a handoff tool using `.with_interaction_mode("autonomous", autonomous_turn_limit=N)` | +| Handoff (Autonomous) | [orchestration/handoff_autonomous.py](./orchestration/handoff_autonomous.py) | Autonomous mode: specialists iterate independently until invoking a handoff tool using `.with_autonomous_mode()` | | Handoff (Participant Factory) | [orchestration/handoff_participant_factory.py](./orchestration/handoff_participant_factory.py) | 
Use participant factories for state isolation between workflow instances | | Magentic Workflow (Multi-Agent) | [orchestration/magentic.py](./orchestration/magentic.py) | Orchestrate multiple agents with Magentic manager and streaming | -| Magentic + Human Plan Review | [orchestration/magentic_human_plan_update.py](./orchestration/magentic_human_plan_update.py) | Human reviews/updates the plan before execution | -| Magentic + Human Stall Intervention | [orchestration/magentic_human_replan.py](./orchestration/magentic_human_replan.py) | Human intervenes when workflow stalls with `with_human_input_on_stall()` | -| Magentic + Agent Clarification | [orchestration/magentic_agent_clarification.py](./orchestration/magentic_agent_clarification.py) | Agents ask clarifying questions via `ask_user` tool with `@ai_function(approval_mode="always_require")` | +| Magentic + Human Plan Review | [orchestration/magentic_human_plan_review.py](./orchestration/magentic_human_plan_review.py) | Human reviews/updates the plan before execution | | Magentic + Checkpoint Resume | [orchestration/magentic_checkpoint.py](./orchestration/magentic_checkpoint.py) | Resume Magentic orchestration from saved checkpoints | | Sequential Orchestration (Agents) | [orchestration/sequential_agents.py](./orchestration/sequential_agents.py) | Chain agents sequentially with shared conversation context | | Sequential Orchestration (Custom Executor) | [orchestration/sequential_custom_executors.py](./orchestration/sequential_custom_executors.py) | Mix agents with a summarizer that appends a compact summary | diff --git a/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py b/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py index c94b3004d8..2a1ab234f9 100644 --- a/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py +++ b/python/samples/getting_started/workflows/agents/group_chat_workflow_as_agent.py @@ -1,20 +1,16 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -import logging from agent_framework import ChatAgent, GroupChatBuilder from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient -logging.basicConfig(level=logging.INFO) - """ -Sample: Group Chat Orchestration (manager-directed) +Sample: Group Chat Orchestration What it does: -- Demonstrates the generic GroupChatBuilder with a language-model manager directing two agents. -- The manager coordinates a researcher (chat completions) and a writer (responses API) to solve a task. -- Uses the default group chat orchestration pipeline shared with Magentic. +- Demonstrates the generic GroupChatBuilder with an agent orchestrator directing two agents. +- The orchestrator coordinates a researcher (chat completions) and a writer (responses API) to solve a task. Prerequisites: - OpenAI environment variables configured for `OpenAIChatClient` and `OpenAIResponsesClient`.
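For quick orientation, a minimal sketch of the API surface this reworked sample relies on (the builder calls and the `as_agent()` wrapper are taken from this patch; the agent instructions and the task string are illustrative only, not part of the diff):

    import asyncio

    from agent_framework import GroupChatBuilder
    from agent_framework.openai import OpenAIChatClient

    async def demo() -> None:
        client = OpenAIChatClient()
        orchestrator = client.create_agent(
            name="Orchestrator",
            instructions="You coordinate a team conversation to solve the user's task.",
        )
        researcher = client.create_agent(name="Researcher", instructions="Gather concise facts.")
        writer = client.create_agent(name="Writer", instructions="Compose the final answer from the notes.")

        workflow = (
            GroupChatBuilder()
            .with_agent_orchestrator(orchestrator)  # agent-based speaker selection
            .participants([researcher, writer])  # participants are now passed as a list
            .build()
        )

        # Wrap the workflow so callers can drive it through the standard agent interface.
        agent = workflow.as_agent(name="GroupChatWorkflowAgent")
        response = await agent.run("Explain the benefits of solar energy in two short paragraphs.")
        print(response.text)

    asyncio.run(demo())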
@@ -38,8 +34,13 @@ async def main() -> None: workflow = ( GroupChatBuilder() - .set_manager(manager=OpenAIChatClient().create_agent(), display_name="Coordinator") - .participants(researcher=researcher, writer=writer) + .with_agent_orchestrator( + OpenAIChatClient().create_agent( + name="Orchestrator", + instructions="You coordinate a team conversation to solve the user's task.", + ) + ) + .participants([researcher, writer]) .build() ) diff --git a/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py b/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py index 0dd1d9e644..e4f1c1e6cb 100644 --- a/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py +++ b/python/samples/getting_started/workflows/agents/handoff_workflow_as_agent.py @@ -1,230 +1,224 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import Mapping -from typing import Any +from typing import Annotated from agent_framework import ( + AgentRunResponse, ChatAgent, ChatMessage, FunctionCallContent, FunctionResultContent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, Role, WorkflowAgent, + ai_function, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -""" -Sample: Handoff Workflow as Agent with Human-in-the-Loop +"""Sample: Handoff Workflow as Agent with Human-in-the-Loop. -Purpose: -This sample demonstrates how to use a HandoffBuilder workflow as an agent via -`.as_agent()`, enabling human-in-the-loop interactions through the standard -agent interface. The handoff pattern routes user requests through a triage agent -to specialist agents, with the workflow requesting user input as needed. +This sample demonstrates how to use a handoff workflow as an agent, enabling +human-in-the-loop interactions through the agent interface. -When using a handoff workflow as an agent: -1. The workflow emits `HandoffUserInputRequest` when it needs user input -2. `WorkflowAgent` converts this to a `FunctionCallContent` named "request_info" -3. The caller extracts `HandoffUserInputRequest` from the function call arguments -4. The caller provides a response via `FunctionResultContent` +A handoff workflow defines a pattern that assembles agents in a mesh topology, allowing +them to transfer control to each other based on the conversation context. -This differs from running the workflow directly: -- Direct workflow: Use `workflow.run_stream()` and `workflow.send_responses_streaming()` -- As agent: Use `agent.run()` with `FunctionCallContent`/`FunctionResultContent` messages +Prerequisites: + - `az login` (Azure CLI authentication) + - Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) Key Concepts: -- HandoffBuilder: Creates triage-to-specialist routing workflows -- WorkflowAgent: Wraps workflows to expose them as standard agents -- HandoffUserInputRequest: Contains conversation context and the awaiting agent -- FunctionCallContent/FunctionResultContent: Standard agent interface for HITL - -Prerequisites: -- `az login` (Azure CLI authentication) -- Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) 
+ - Auto-registered handoff tools: HandoffBuilder automatically creates handoff tools + for each participant, allowing the coordinator to transfer control to specialists + - Termination condition: Controls when the workflow stops requesting user input + - Request/response cycle: Workflow requests input, user responds, cycle continues """ +@ai_function +def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str: + """Simulated function to process a refund for a given order number.""" + return f"Refund processed successfully for order {order_number}." + + +@ai_function +def check_order_status(order_number: Annotated[str, "Order number to check status for"]) -> str: + """Simulated function to check the status of a given order number.""" + return f"Order {order_number} is currently being processed and will ship in 2 business days." + + +@ai_function +def process_return(order_number: Annotated[str, "Order number to process return for"]) -> str: + """Simulated function to process a return for a given order number.""" + return f"Return initiated successfully for order {order_number}. You will receive return instructions via email." + + def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]: """Create and configure the triage and specialist agents. - The triage agent dispatches requests to the appropriate specialist. - Specialists handle their domain-specific queries. + Args: + chat_client: The AzureOpenAIChatClient to use for creating agents. Returns: - Tuple of (triage_agent, refund_agent, order_agent, support_agent) + Tuple of (triage_agent, refund_agent, order_agent, return_agent) """ - triage = chat_client.create_agent( + # Triage agent: Acts as the frontline dispatcher + triage_agent = chat_client.create_agent( instructions=( - "You are frontline support triage. Read the latest user message and decide whether " - "to hand off to refund_agent, order_agent, or support_agent. Provide a brief natural-language " - "response for the user. When delegation is required, call the matching handoff tool " - "(`handoff_to_refund_agent`, `handoff_to_order_agent`, or `handoff_to_support_agent`)." + "You are frontline support triage. Route customer issues to the appropriate specialist agents " + "based on the problem described." ), name="triage_agent", ) - refund = chat_client.create_agent( - instructions=( - "You handle refund workflows. Ask for any order identifiers you require and outline the refund steps." - ), + # Refund specialist: Handles refund requests + refund_agent = chat_client.create_agent( + instructions="You process refund requests.", name="refund_agent", + # In a real application, an agent can have multiple tools; here we keep it simple + tools=[process_refund], ) - order = chat_client.create_agent( - instructions=( - "You resolve shipping and fulfillment issues. Clarify the delivery problem and describe the actions " - "you will take to remedy it." - ), + # Order/shipping specialist: Resolves delivery issues + order_agent = chat_client.create_agent( + instructions="You handle order and shipping inquiries.", name="order_agent", + # In a real application, an agent can have multiple tools; here we keep it simple + tools=[check_order_status], ) - support = chat_client.create_agent( - instructions=( - "You are a general support agent. Offer empathetic troubleshooting and gather missing details if the " - "issue does not match other specialists." 
- ), - name="support_agent", + # Return specialist: Handles return requests + return_agent = chat_client.create_agent( + instructions="You manage product return requests.", + name="return_agent", + # In a real application, an agent can have multiple tools; here we keep it simple + tools=[process_return], ) - return triage, refund, order, support + return triage_agent, refund_agent, order_agent, return_agent -def extract_handoff_request( - response_messages: list[ChatMessage], -) -> tuple[FunctionCallContent, HandoffUserInputRequest]: - """Extract the HandoffUserInputRequest from agent response messages. +def handle_response_and_requests(response: AgentRunResponse) -> dict[str, HandoffAgentUserRequest]: + """Process agent response messages and extract any user requests. - When a handoff workflow running as an agent needs user input, it emits a - FunctionCallContent with name="request_info" containing the HandoffUserInputRequest. + This function inspects the agent response and: + - Displays agent messages to the console + - Collects HandoffAgentUserRequest instances for response handling Args: - response_messages: Messages from the agent response + response: The AgentRunResponse from the agent run call. Returns: - Tuple of (function_call, handoff_request) - - Raises: - ValueError: If no request_info function call is found or payload is invalid + A dictionary mapping request IDs to HandoffAgentUserRequest instances. """ - for message in response_messages: + pending_requests: dict[str, HandoffAgentUserRequest] = {} + for message in response.messages: + if message.text: + print(f"- {message.author_name or message.role.value}: {message.text}") for content in message.contents: - if isinstance(content, FunctionCallContent) and content.name == WorkflowAgent.REQUEST_INFO_FUNCTION_NAME: - # Parse the function arguments to extract the HandoffUserInputRequest - args = content.arguments - if isinstance(args, str): - request_args = WorkflowAgent.RequestInfoFunctionArgs.from_json(args) - elif isinstance(args, Mapping): - request_args = WorkflowAgent.RequestInfoFunctionArgs.from_dict(dict(args)) + if isinstance(content, FunctionCallContent): + if isinstance(content.arguments, dict): + request = WorkflowAgent.RequestInfoFunctionArgs.from_dict(content.arguments) + elif isinstance(content.arguments, str): + request = WorkflowAgent.RequestInfoFunctionArgs.from_json(content.arguments) else: - raise ValueError("Unexpected argument type for request_info function call.") - - payload: Any = request_args.data - if not isinstance(payload, HandoffUserInputRequest): - raise ValueError( - f"Expected HandoffUserInputRequest in request_info payload, got {type(payload).__name__}" - ) - - return content, payload - - raise ValueError("No request_info function call found in response messages.") - - -def print_conversation(request: HandoffUserInputRequest) -> None: - """Display the conversation history from a HandoffUserInputRequest.""" - print("\n=== Conversation History ===") - for message in request.conversation: - speaker = message.author_name or message.role.value - print(f" [{speaker}]: {message.text}") - print(f" [Awaiting]: {request.awaiting_agent_id}") - print("============================") + raise ValueError("Invalid arguments type. Expecting a request info structure for this sample.") + if isinstance(request.data, HandoffAgentUserRequest): + pending_requests[request.request_id] = request.data + return pending_requests async def main() -> None: - """Main entry point demonstrating handoff workflow as agent. 
+ """Main entry point for the handoff workflow demo. - This demo: - 1. Builds a handoff workflow with triage and specialist agents - 2. Converts it to an agent using .as_agent() - 3. Runs a multi-turn conversation with scripted user responses - 4. Demonstrates the FunctionCallContent/FunctionResultContent pattern for HITL - """ - print("Starting Handoff Workflow as Agent Demo") - print("=" * 55) + This function demonstrates: + 1. Creating triage and specialist agents + 2. Building a handoff workflow with custom termination condition + 3. Running the workflow with scripted user responses + 4. Processing events and handling user input requests + The workflow uses scripted responses instead of interactive input to make + the demo reproducible and testable. In a production application, you would + replace the scripted_responses with actual user input collection. + """ # Initialize the Azure OpenAI chat client chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - # Create agents + # Create all agents: triage + specialists triage, refund, order, support = create_agents(chat_client) - # Build the handoff workflow and convert to agent - # Termination condition: stop after 4 user messages + # Build the handoff workflow + # - participants: All agents that can participate in the workflow + # - with_start_agent: The triage agent is designated as the start agent, which means + # it receives all user input first and orchestrates handoffs to specialists + # - with_termination_condition: Custom logic to stop the request/response loop. + # Without this, the default behavior continues requesting user input until max_turns + # is reached. Here we use a custom condition that checks if the conversation has ended + # naturally (when one of the agents says something like "you're welcome"). agent = ( HandoffBuilder( name="customer_support_handoff", participants=[triage, refund, order, support], ) - .set_coordinator("triage_agent") - .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 4) + .with_start_agent(triage) + .with_termination_condition( + # Custom termination: Check if one of the agents has provided a closing message. + # This looks for the last message containing "welcome", which indicates the + # conversation has concluded naturally. + lambda conversation: len(conversation) > 0 and "welcome" in conversation[-1].text.lower() + ) .build() .as_agent() # Convert workflow to agent interface ) # Scripted user responses for reproducible demo + # In a console application, replace this with: + # user_input = input("Your response: ") + # or integrate with a UI/chat interface scripted_responses = [ - "My order 1234 arrived damaged and the packaging was destroyed.", - "Yes, I'd like a refund if that's possible.", - "Thanks for your help!", + "My order 1234 arrived damaged and the packaging was destroyed. 
I'd like to return it.", + "Please also process a refund for order 1234.", + "Thanks for resolving this.", ] - # Start the conversation - print("\n[User]: Hello, I need assistance with my recent purchase.") - response = await agent.run("Hello, I need assistance with my recent purchase.") - - # Process conversation turns until workflow completes or responses exhausted - while True: - # Check if the agent is requesting user input - try: - function_call, handoff_request = extract_handoff_request(response.messages) - except ValueError: - # No request_info call found - workflow has completed - print("\n[Workflow completed - no pending requests]") - if response.messages: - final_text = response.messages[-1].text - if final_text: - print(f"[Final response]: {final_text}") - break - - # Display the conversation context - print_conversation(handoff_request) - - # Get the next scripted response - if not scripted_responses: - print("\n[No more scripted responses - ending conversation]") - break - - user_input = scripted_responses.pop(0) + # Start the workflow with the initial user message + print("[Starting workflow with initial user message...]\n") + initial_message = "Hello, I need assistance with my recent purchase." + print(f"- User: {initial_message}") + response = await agent.run(initial_message) + pending_requests = handle_response_and_requests(response) + + # Process the request/response cycle + # The workflow will continue requesting input until: + # 1. The termination condition is met, OR + # 2. We run out of scripted responses + while pending_requests: + for request in pending_requests.values(): + for message in request.agent_response.messages: + if message.text: + print(f"- {message.author_name or message.role.value}: {message.text}") - print(f"\n[User responding]: {user_input}") - - # Create the function result to send back to the agent - # The result is the user's text response which gets converted to ChatMessage - function_result = FunctionResultContent( - call_id=function_call.call_id, - result=user_input, - ) + if not scripted_responses: + # No more scripted responses; terminate the workflow + responses = {req_id: HandoffAgentUserRequest.terminate() for req_id in pending_requests} + else: + # Get the next scripted response + user_response = scripted_responses.pop(0) + print(f"\n- User: {user_response}") - # Send the response back to the agent - response = await agent.run(ChatMessage(role=Role.TOOL, contents=[function_result])) + # Send response(s) to all pending requests + # In this demo, there's typically one request per cycle, but the API supports multiple + responses = {req_id: HandoffAgentUserRequest.create_response(user_response) for req_id in pending_requests} - print("\n" + "=" * 55) - print("Demo completed!") + function_results = [ + FunctionResultContent(call_id=req_id, result=response) for req_id, response in responses.items() + ] + response = await agent.run(ChatMessage(role=Role.TOOL, contents=function_results)) + pending_requests = handle_response_and_requests(response) if __name__ == "__main__": - print("Initializing Handoff Workflow as Agent Sample...") asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py b/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py index f6dd8ca83d..f4e5b38e86 100644 --- a/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py +++ b/python/samples/getting_started/workflows/agents/magentic_workflow_as_agent.py @@ -1,20 +1,14 @@ # 
Copyright (c) Microsoft. All rights reserved. import asyncio -import logging from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, ChatAgent, HostedCodeInterpreterTool, MagenticBuilder, ) from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - """ Sample: Build a Magentic orchestration and wrap it as an agent. @@ -60,7 +54,7 @@ async def main() -> None: workflow = ( MagenticBuilder() - .participants(researcher=researcher_agent, coder=coder_agent) + .participants([researcher_agent, coder_agent]) .with_standard_manager( agent=manager_agent, max_round_count=10, @@ -87,20 +81,8 @@ print("\nWrapping workflow as an agent and running...") workflow_agent = workflow.as_agent(name="MagenticWorkflowAgent") async for response in workflow_agent.run_stream(task): - # AgentRunResponseUpdate objects contain the streaming agent data - # Check metadata to understand event type - props = response.additional_properties - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - print(f"\n[ORCHESTRATOR:{kind}] {response.text}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - if response.text: - print(response.text, end="", flush=True) - elif response.text: - # Fallback for any other events with text - print(response.text, end="", flush=True) + # Print any streamed text as it arrives + print(response.text, end="", flush=True) except Exception as e: print(f"Workflow execution failed: {e}") diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_basics.py b/python/samples/getting_started/workflows/composition/sub_workflow_basics.py index 826425a0ae..9189e70d29 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_basics.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_basics.py @@ -12,7 +12,7 @@ handler, ) from typing_extensions import Never - + """ Sample: Sub-Workflows (Basics) diff --git a/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py b/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py index c8c4f40e41..4d8ee96d06 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/concurrent_request_info.py @@ -4,17 +4,17 @@ Sample: Request Info with ConcurrentBuilder This sample demonstrates using the `.with_request_info()` method to pause a -ConcurrentBuilder workflow AFTER all parallel agents complete but BEFORE -aggregation, allowing human review and modification of the combined results. +ConcurrentBuilder workflow for specific agents, allowing human review and +modification of individual agent outputs before aggregation. Purpose: -Show how to use the request info API that pauses after concurrent agents run, -allowing review and steering of results before they are aggregated. +Show how to use the request info API that pauses for selected concurrent agents, +allowing review and steering of their results. 
Demonstrate: -- Configuring request info with `.with_request_info()` -- Reviewing outputs from multiple concurrent agents -- Injecting human guidance after agents execute but before aggregation +- Configuring request info with `.with_request_info()` for specific agents +- Reviewing output from individual agents during concurrent execution +- Injecting human guidance for specific agents before aggregation Prerequisites: - Azure OpenAI configured for AzureOpenAIChatClient with required environment variables @@ -25,7 +25,8 @@ from typing import Any from agent_framework import ( - AgentInputRequest, + AgentExecutorResponse, + AgentRequestInfoResponse, ChatMessage, ConcurrentBuilder, RequestInfoEvent, @@ -131,12 +131,13 @@ ConcurrentBuilder() .participants([technical_analyst, business_analyst, user_experience_analyst]) .with_aggregator(aggregate_with_synthesis) - .with_request_info() + # Only enable request info for the technical analyst agent + .with_request_info(agents=["technical_analyst"]) .build() ) # Run the workflow with human-in-the-loop - pending_responses: dict[str, str] | None = None + pending_responses: dict[str, AgentRequestInfoResponse] | None = None workflow_complete = False print("Starting multi-perspective analysis workflow...") @@ -155,26 +156,34 @@ # Process events async for event in stream: if isinstance(event, RequestInfoEvent): - if isinstance(event.data, AgentInputRequest): - # Display pre-execution context for steering concurrent agents + if isinstance(event.data, AgentExecutorResponse): + # Display agent output for review and potential modification print("\n" + "-" * 40) - print("INPUT REQUESTED (BEFORE CONCURRENT AGENTS)") - print("-" * 40) - print(f"About to call agents: {event.data.target_agent_id}") - print("Conversation context:") - recent = ( - event.data.conversation[-2:] if len(event.data.conversation) > 2 else event.data.conversation + print("INPUT REQUESTED") + print( + f"Agent {event.source_executor_id} just responded with: '{event.data.agent_run_response.text}'. " + "Please provide your feedback." ) - for msg in recent: - role = msg.role.value if msg.role else "unknown" - text = (msg.text or "")[:150] - print(f" [{role}]: {text}...") print("-" * 40) - - # Get human input to steer all agents - user_input = input("Your guidance for the analysts (or 'skip' to continue): ") # noqa: ASYNC250 + if event.data.full_conversation: + print("Conversation context:") + recent = ( + event.data.full_conversation[-2:] + if len(event.data.full_conversation) > 2 + else event.data.full_conversation + ) + for msg in recent: + name = msg.author_name or msg.role.value + text = (msg.text or "")[:150] + print(f" [{name}]: {text}...") + print("-" * 40) + + # Get human input to steer this agent's contribution + user_input = input("Your guidance for the analyst (or 'skip' to approve): ") # noqa: ASYNC250 if user_input.lower() == "skip": - user_input = "Please analyze objectively from your unique perspective." 
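+ # 'skip' approves: an empty AgentRequestInfoResponse lets the agent's output pass through unchanged, while any other text is sent back to the agent as a new user message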
+ user_input = AgentRequestInfoResponse.approve() + else: + user_input = AgentRequestInfoResponse.from_strings([user_input]) pending_responses = {event.request_id: user_input} print("(Resuming workflow...)") @@ -189,9 +198,8 @@ async def main() -> None: print(event.data) workflow_complete = True - elif isinstance(event, WorkflowStatusEvent): - if event.state == WorkflowRunState.IDLE: - workflow_complete = True + elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + workflow_complete = True if __name__ == "__main__": diff --git a/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py b/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py index c3a193a6a8..d76308c657 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/group_chat_request_info.py @@ -25,7 +25,9 @@ import asyncio from agent_framework import ( - AgentInputRequest, + AgentExecutorResponse, + AgentRequestInfoResponse, + AgentRunResponse, AgentRunUpdateEvent, ChatMessage, GroupChatBuilder, @@ -69,18 +71,17 @@ async def main() -> None: ), ) - # Manager orchestrates the discussion - manager = chat_client.create_agent( - name="manager", + # Orchestrator coordinates the discussion + orchestrator = chat_client.create_agent( + name="orchestrator", instructions=( - "You are a discussion manager coordinating a team conversation between optimist, " - "pragmatist, and creative. Your job is to select who speaks next.\n\n" + "You are a discussion manager coordinating a team conversation between participants. " + "Your job is to select who speaks next.\n\n" "RULES:\n" "1. Rotate through ALL participants - do not favor any single participant\n" "2. Each participant should speak at least once before any participant speaks twice\n" - "3. If human feedback redirects the topic, acknowledge it and continue rotating\n" - "4. Continue for at least 5 participant turns before concluding\n" - "5. Do NOT select the same participant twice in a row" + "3. Continue for at least 5 rounds before ending the discussion\n" + "4. 
Do NOT select the same participant twice in a row" ), ) @@ -88,7 +89,7 @@ async def main() -> None: # Using agents= filter to only pause before pragmatist speaks (not every turn) workflow = ( GroupChatBuilder() - .set_manager(manager=manager, display_name="Discussion Manager") + .with_agent_orchestrator(orchestrator) .participants([optimist, pragmatist, creative]) .with_max_rounds(6) .with_request_info(agents=[pragmatist]) # Only pause before pragmatist speaks @@ -96,7 +97,7 @@ async def main() -> None: ) # Run the workflow with human-in-the-loop - pending_responses: dict[str, str] | None = None + pending_responses: dict[str, AgentRequestInfoResponse] | None = None workflow_complete = False current_agent: str | None = None # Track current streaming agent @@ -130,28 +131,28 @@ async def main() -> None: elif isinstance(event, RequestInfoEvent): current_agent = None # Reset for next agent - if isinstance(event.data, AgentInputRequest): + if isinstance(event.data, AgentExecutorResponse): # Display pre-agent context for human input print("\n" + "-" * 40) print("INPUT REQUESTED") - print(f"About to call agent: {event.data.target_agent_id}") + print(f"About to call agent: {event.source_executor_id}") print("-" * 40) print("Conversation context:") - recent = ( - event.data.conversation[-3:] if len(event.data.conversation) > 3 else event.data.conversation - ) + agent_run_response: AgentRunResponse = event.data.agent_run_response + messages: list[ChatMessage] = agent_run_response.messages + recent: list[ChatMessage] = messages[-3:] if len(messages) > 3 else messages # type: ignore for msg in recent: - role = msg.role.value if msg.role else "unknown" + name = msg.author_name or "unknown" text = (msg.text or "")[:100] - print(f" [{role}]: {text}...") + print(f" [{name}]: {text}...") print("-" * 40) # Get human input to steer the agent - user_input = input("Steer the discussion (or 'skip' to continue): ") # noqa: ASYNC250 + user_input = input(f"Feedback for {event.source_executor_id} (or 'skip' to approve): ") # noqa: ASYNC250 if user_input.lower() == "skip": - user_input = "Please continue the discussion naturally." 
- - pending_responses = {event.request_id: user_input} + pending_responses = {event.request_id: AgentRequestInfoResponse.approve()} + else: + pending_responses = {event.request_id: AgentRequestInfoResponse.from_strings([user_input])} print("(Resuming discussion...)") elif isinstance(event, WorkflowOutputEvent): @@ -160,11 +161,12 @@ async def main() -> None: print("=" * 60) print("Final conversation:") if event.data: - messages: list[ChatMessage] = event.data[-4:] + messages: list[ChatMessage] = event.data for msg in messages: - role = msg.role.value if msg.role else "unknown" + role = msg.role.value.capitalize() + name = msg.author_name or "unknown" text = (msg.text or "")[:200] - print(f"[{role}]: {text}...") + print(f"[{role}][{name}]: {text}...") workflow_complete = True elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: diff --git a/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py b/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py index 55c8652984..8e735bfb1a 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/sequential_request_info.py @@ -4,11 +4,11 @@ Sample: Request Info with SequentialBuilder This sample demonstrates using the `.with_request_info()` method to pause a -SequentialBuilder workflow BEFORE each agent runs, allowing external input -(e.g., human steering) before the agent responds. +SequentialBuilder workflow AFTER each agent runs, allowing external input +(e.g., human feedback) for review and optional iteration. Purpose: -Show how to use the request info API that pauses before every agent response, +Show how to use the request info API that pauses after every agent response, using the standard request_info pattern for consistency. Demonstrate: @@ -24,7 +24,8 @@ import asyncio from agent_framework import ( - AgentInputRequest, + AgentExecutorResponse, + AgentRequestInfoResponse, ChatMessage, RequestInfoEvent, SequentialBuilder, @@ -48,7 +49,7 @@ async def main() -> None: editor = chat_client.create_agent( name="editor", instructions=( - "You are an editor. Review the draft and suggest improvements. " + "You are an editor. Review the draft and make improvements. " "Incorporate any human feedback that was provided." 
), ) @@ -61,11 +62,17 @@ async def main() -> None: ), ) - # Build workflow with request info enabled (pauses before each agent) - workflow = SequentialBuilder().participants([drafter, editor, finalizer]).with_request_info().build() + # Build workflow with request info enabled (pauses after each agent responds) + workflow = ( + SequentialBuilder() + .participants([drafter, editor, finalizer]) + # Only enable request info for the editor agent + .with_request_info(agents=["editor"]) + .build() + ) # Run the workflow with request info handling - pending_responses: dict[str, str] | None = None + pending_responses: dict[str, AgentRequestInfoResponse] | None = None workflow_complete = False print("Starting document review workflow...") @@ -84,26 +91,34 @@ async def main() -> None: # Process events async for event in stream: if isinstance(event, RequestInfoEvent): - if isinstance(event.data, AgentInputRequest): - # Display pre-agent context for steering + if isinstance(event.data, AgentExecutorResponse): + # Display agent response and conversation context for review print("\n" + "-" * 40) print("REQUEST INFO: INPUT REQUESTED") - print(f"About to call agent: {event.data.target_agent_id}") - print("-" * 40) - print("Conversation context:") - recent = ( - event.data.conversation[-2:] if len(event.data.conversation) > 2 else event.data.conversation + print( + f"Agent {event.source_executor_id} just responded with: '{event.data.agent_run_response.text}'. " + "Please provide your feedback." ) - for msg in recent: - role = msg.role.value if msg.role else "unknown" - text = (msg.text or "")[:150] - print(f" [{role}]: {text}...") print("-" * 40) - - # Get input to steer the agent - user_input = input("Your guidance (or 'skip' to continue): ") # noqa: ASYNC250 + if event.data.full_conversation: + print("Conversation context:") + recent = ( + event.data.full_conversation[-2:] + if len(event.data.full_conversation) > 2 + else event.data.full_conversation + ) + for msg in recent: + name = msg.author_name or msg.role.value + text = (msg.text or "")[:150] + print(f" [{name}]: {text}...") + print("-" * 40) + + # Get feedback on the agent's response (approve or request iteration) + user_input = input("Your guidance (or 'skip' to approve): ") # noqa: ASYNC250 if user_input.lower() == "skip": - user_input = "Please continue naturally." + user_input = AgentRequestInfoResponse.approve() + else: + user_input = AgentRequestInfoResponse.from_strings([user_input]) pending_responses = {event.request_id: user_input} print("(Resuming workflow...)") diff --git a/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py b/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py index 3bc79fcddc..12475205d3 100644 --- a/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py +++ b/python/samples/getting_started/workflows/orchestration/group_chat_agent_manager.py @@ -1,8 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio -import logging -from typing import cast from agent_framework import ( AgentRunUpdateEvent, @@ -15,8 +13,6 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -logging.basicConfig(level=logging.INFO) - """ Sample: Group Chat with Agent-Based Manager @@ -29,50 +25,54 @@ - OpenAI environment variables configured for OpenAIChatClient """ - -def _get_chat_client() -> AzureOpenAIChatClient: - return AzureOpenAIChatClient(credential=AzureCliCredential()) - - -async def main() -> None: - # Create coordinator agent with structured output for speaker selection - # Note: response_format is enforced to ManagerSelectionResponse by set_manager() - coordinator = ChatAgent( - name="Coordinator", - description="Coordinates multi-agent collaboration by selecting speakers", - instructions=""" +ORCHESTRATOR_AGENT_INSTRUCTIONS = """ You coordinate a team conversation to solve the user's task. -Review the conversation history and select the next participant to speak. - Guidelines: - Start with Researcher to gather information - Then have Writer synthesize the final answer - Only finish after both have contributed meaningfully -- Allow for multiple rounds of information gathering if needed -""", - chat_client=_get_chat_client(), +""" + + +async def main() -> None: + # Create a chat client using Azure OpenAI and Azure CLI credentials for all agents + chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) + + # Orchestrator agent that manages the conversation + # Note: This agent (and the underlying chat client) must support structured outputs. + # The group chat workflow relies on this to parse the orchestrator's decisions. + # `response_format` is set internally by the GroupChat workflow when the agent is invoked. 
+    orchestrator_agent = ChatAgent(
+        name="Orchestrator",
+        description="Coordinates multi-agent collaboration by selecting speakers",
+        instructions=ORCHESTRATOR_AGENT_INSTRUCTIONS,
+        chat_client=chat_client,
     )
 
+    # Participant agents
     researcher = ChatAgent(
         name="Researcher",
         description="Collects relevant background information",
         instructions="Gather concise facts that help a teammate answer the question.",
-        chat_client=_get_chat_client(),
+        chat_client=chat_client,
     )
 
     writer = ChatAgent(
         name="Writer",
         description="Synthesizes polished answers from gathered information",
         instructions="Compose clear and structured answers using any notes provided.",
-        chat_client=_get_chat_client(),
+        chat_client=chat_client,
     )
 
+    # Build the group chat workflow
     workflow = (
         GroupChatBuilder()
-        .set_manager(coordinator, display_name="Orchestrator")
-        .with_termination_condition(lambda messages: sum(1 for msg in messages if msg.role == Role.ASSISTANT) >= 2)
+        .with_agent_orchestrator(orchestrator_agent)
         .participants([researcher, writer])
+        # Set a hard termination condition: stop after 4 assistant messages.
+        # The agent orchestrator should decide to end well before this limit; the condition is a safety net.
+        .with_termination_condition(lambda messages: sum(1 for msg in messages if msg.role == Role.ASSISTANT) >= 4)
         .build()
     )
@@ -82,30 +82,35 @@ async def main() -> None:
     print(f"TASK: {task}\n")
     print("=" * 80)
 
-    final_conversation: list[ChatMessage] = []
+    # Keep track of the last executor to format output nicely in streaming mode
     last_executor_id: str | None = None
+    output_event: WorkflowOutputEvent | None = None
     async for event in workflow.run_stream(task):
         if isinstance(event, AgentRunUpdateEvent):
             eid = event.executor_id
             if eid != last_executor_id:
                 if last_executor_id is not None:
-                    print()
+                    print("\n")
                 print(f"{eid}:", end=" ", flush=True)
                 last_executor_id = eid
             print(event.data, end="", flush=True)
         elif isinstance(event, WorkflowOutputEvent):
-            final_conversation = cast(list[ChatMessage], event.data)
-
-    if final_conversation and isinstance(final_conversation, list):
-        print("\n\n" + "=" * 80)
-        print("FINAL CONVERSATION")
-        print("=" * 80)
-        for msg in final_conversation:
-            author = getattr(msg, "author_name", "Unknown")
-            text = getattr(msg, "text", str(msg))
-            print(f"\n[{author}]")
-            print(text)
-            print("-" * 80)
+            output_event = event
+
+    # The output of the workflow is the full list of messages exchanged
+    if output_event:
+        if not isinstance(output_event.data, list) or not all(
+            isinstance(msg, ChatMessage)
+            for msg in output_event.data  # type: ignore
+        ):
+            raise RuntimeError("Unexpected output event data format.")
+        print("\n" + "=" * 80)
+        print("\nFINAL OUTPUT (The conversation history)\n")
+        for msg in output_event.data:  # type: ignore
+            assert isinstance(msg, ChatMessage)
+            print(f"{msg.author_name or msg.role}: {msg.text}\n")
+    else:
+        raise RuntimeError("Workflow did not produce a final output event.")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py b/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py
index 7059a84e32..a26b9df4d0 100644
--- a/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py
+++ b/python/samples/getting_started/workflows/orchestration/group_chat_philosophical_debate.py
@@ -211,7 +211,7 @@ async def main() -> None:
 
     workflow = (
         GroupChatBuilder()
-        .set_manager(moderator, display_name="Moderator")
+
.with_agent_orchestrator(moderator) .participants([farmer, developer, teacher, activist, spiritual_leader, artist, immigrant, doctor]) .with_termination_condition(lambda messages: sum(1 for msg in messages if msg.role == Role.ASSISTANT) >= 10) .build() @@ -241,13 +241,11 @@ async def main() -> None: async for event in workflow.run_stream(f"Please begin the discussion on: {topic}"): if isinstance(event, AgentRunUpdateEvent): - speaker_id = event.executor_id.replace("groupchat_agent:", "") - - if speaker_id != current_speaker: + if event.executor_id != current_speaker: if current_speaker is not None: print("\n") - print(f"[{speaker_id}]", flush=True) - current_speaker = speaker_id + print(f"[{event.executor_id}]", flush=True) + current_speaker = event.executor_id print(event.data, end="", flush=True) @@ -286,10 +284,6 @@ async def main() -> None: DISCUSSION BEGINS ================================================================================ - [Moderator] - {"selected_participant":"Farmer","instruction":"Please start by sharing what living a good life means to you, - especially from your perspective living in a rural area in Southeast Asia.","finish":false,"final_message":null} - [Farmer] To me, a good life is deeply intertwined with the rhythm of the land and the nurturing of relationships with my family and community. It means cultivating crops that respect our environment, ensuring sustainability for future @@ -298,11 +292,6 @@ async def main() -> None: wealth. It's the simple moments, like sharing stories with my children under the stars, that truly define a good life. What good is progress if it isolates us from those we love and the land that sustains us? - [Moderator] - {"selected_participant":"Developer","instruction":"Given the insights shared by the Farmer, please discuss what a - good life means to you as a software developer in an urban setting in the United States and how it might contrast - with or complement the Farmer's view.","finish":false,"final_message":null} - [Developer] As a software developer in an urban environment, a good life for me hinges on the intersection of innovation, creativity, and balance. It's about having the freedom to explore new technologies that can solve real-world @@ -312,11 +301,6 @@ async def main() -> None: rich personal experiences. The challenge is finding harmony between technological progress and preserving the intimate human connections that truly enrich our lives. - [Moderator] - {"selected_participant":"SpiritualLeader","instruction":"Reflect on both the Farmer's and Developer's perspectives - and share your view of what constitutes a good life, particularly from your spiritual and cultural standpoint in - the Middle East.","finish":false,"final_message":null} - [SpiritualLeader] From my spiritual perspective, a good life embodies a balance between personal fulfillment and service to others, rooted in compassion and community. In our teachings, we emphasize that true happiness comes from helping those in @@ -326,11 +310,6 @@ async def main() -> None: with those around us. Ultimately, as we align our personal beliefs with our communal responsibilities, we cultivate a richness that transcends material wealth. 
- [Moderator] - {"selected_participant":"Activist","instruction":"Add to the discussion by sharing your perspective on what a good - life entails, particularly from your background as a young activist in South America.","finish":false, - "final_message":null} - [Activist] As a young activist in South America, a good life for me is about advocating for social justice and environmental sustainability. It means living in a society where everyone's rights are respected and where marginalized voices, @@ -341,11 +320,6 @@ async def main() -> None: not just lived for oneself but is deeply tied to the well-being of our communities and the health of our environment. How can we, regardless of our backgrounds, collaborate to foster these essential changes? - [Moderator] - {"selected_participant":"Teacher","instruction":"Considering the views shared so far, tell us how your experience - as a retired history teacher from Eastern Europe shapes your understanding of a good life, perhaps reflecting on - lessons from the past and their impact on present-day life choices.","finish":false,"final_message":null} - [Teacher] As a retired history teacher from Eastern Europe, my understanding of a good life is deeply rooted in the lessons drawn from history and the struggle for freedom and dignity. Historical events, such as the fall of the Iron @@ -357,11 +331,6 @@ async def main() -> None: contributions to the rich tapestry of our shared humanity. How can we ensure that the lessons of history inform a more compassionate and just society moving forward? - [Moderator] - {"selected_participant":"Artist","instruction":"Expound on the themes and perspectives discussed so far by sharing - how, as an artist from Africa, you define a good life and how art plays a role in that vision.","finish":false, - "final_message":null} - [Artist] As an artist from Africa, I define a good life as one steeped in cultural expression, storytelling, and the celebration of our collective memories. Art is a powerful medium through which we capture our histories, struggles, @@ -373,19 +342,6 @@ async def main() -> None: collective good, fostering empathy and understanding among diverse communities. How can we harness art to bridge differences and amplify marginalized voices in our pursuit of a good life? - [Moderator] - {"selected_participant":null,"instruction":null,"finish":true,"final_message":"As our discussion unfolds, several - key themes have gracefully emerged, reflecting the richness of diverse perspectives on what constitutes a good life. - From the rural farmer's integration with the land to the developer's search for balance between technology and - personal connection, each viewpoint validates that fulfillment, at its core, transcends material wealth. The - spiritual leader and the activist highlight the importance of community and social justice, while the history - teacher and the artist remind us of the lessons and narratives that shape our cultural and personal identities. - - Ultimately, the good life seems to revolve around meaningful relationships, honoring our legacies while striving for - progress, and nurturing both our inner selves and external communities. 
This dialogue demonstrates that despite our - varied backgrounds and experiences, the quest for a good life binds us together, urging cooperation and empathy in - our shared human journey."} - ================================================================================ DISCUSSION SUMMARY ================================================================================ diff --git a/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py b/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py index 1fd074ca4d..517ae313f3 100644 --- a/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py +++ b/python/samples/getting_started/workflows/orchestration/group_chat_simple_selector.py @@ -1,113 +1,134 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -import logging -from typing import cast -from agent_framework import ChatAgent, ChatMessage, GroupChatBuilder, GroupChatStateSnapshot, WorkflowOutputEvent -from agent_framework.openai import OpenAIChatClient - -logging.basicConfig(level=logging.INFO) +from agent_framework import ( + AgentRunUpdateEvent, + ChatAgent, + ChatMessage, + GroupChatBuilder, + GroupChatState, + WorkflowOutputEvent, +) +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential """ -Sample: Group Chat with Simple Speaker Selector Function +Sample: Group Chat with a round-robin speaker selector What it does: -- Demonstrates the set_select_speakers_func() API for GroupChat orchestration +- Demonstrates the with_select_speaker_func() API for GroupChat orchestration - Uses a pure Python function to control speaker selection based on conversation state -- Alternates between researcher and writer agents in a simple round-robin pattern -- Shows how to access conversation history, round index, and participant metadata - -Key pattern: - def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: - # state contains: task, participants, conversation, history, round_index - # Return participant name to continue, or None to finish - ... Prerequisites: - OpenAI environment variables configured for OpenAIChatClient """ -def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: - """Simple speaker selector that alternates between researcher and writer. - - This function demonstrates the core pattern: - 1. Examine the current state of the group chat - 2. Decide who should speak next - 3. Return participant name or None to finish - - Args: - state: Immutable snapshot containing: - - task: ChatMessage - original user task - - participants: dict[str, str] - participant names → descriptions - - conversation: tuple[ChatMessage, ...] - full conversation history - - history: tuple[GroupChatTurn, ...] 
- turn-by-turn with speaker attribution
-        - round_index: int - number of selection rounds so far
-        - pending_agent: str | None - currently active agent (if any)
-
-    Returns:
-        Name of next speaker, or None to finish the conversation
-    """
-    round_idx = state["round_index"]
-    history = state["history"]
+def round_robin_selector(state: GroupChatState) -> str:
+    """A round-robin selector function that picks the next speaker based on the current round index."""
 
-    # Finish after 4 turns (researcher → writer → researcher → writer)
-    if round_idx >= 4:
-        return None
+    participant_names = list(state.participants.keys())
+    return participant_names[state.current_round % len(participant_names)]
 
-    # Get the last speaker from history
-    last_speaker = history[-1].speaker if history else None
 
-    # Simple alternation: researcher → writer → researcher → writer
-    if last_speaker == "Researcher":
-        return "Writer"
-    return "Researcher"
+async def main() -> None:
+    # Create a chat client using Azure OpenAI and Azure CLI credentials for all agents
+    chat_client = AzureOpenAIChatClient(credential=AzureCliCredential())
+
+    # Participant agents
+    expert = ChatAgent(
+        name="PythonExpert",
+        instructions=(
+            "You are an expert in Python in a workgroup. "
+            "Your job is to answer Python related questions and refine your answer "
+            "based on feedback from all the other participants."
+        ),
+        chat_client=chat_client,
+    )
 
+    verifier = ChatAgent(
+        name="AnswerVerifier",
+        instructions=(
+            "You are a programming expert in a workgroup. "
+            f"Your job is to review the answer provided by {expert.name} and point "
+            "out statements that are technically true but practically dangerous. "
+            "If there is nothing worth pointing out, respond with 'The answer looks good to me.'"
+        ),
+        chat_client=chat_client,
+    )
 
-async def main() -> None:
-    researcher = ChatAgent(
-        name="Researcher",
-        description="Collects relevant background information.",
-        instructions="Gather concise facts that help answer the question. Be brief.",
-        chat_client=OpenAIChatClient(model_id="gpt-4o-mini"),
+    clarifier = ChatAgent(
+        name="AnswerClarifier",
+        instructions=(
+            "You are an accessibility expert in a workgroup. "
+            f"Your job is to review the answer provided by {expert.name} and point "
+            "out jargon or complex terms that may be difficult for a beginner to understand. "
+            "If there is nothing worth pointing out, respond with 'The answer looks clear to me.'"
+        ),
+        chat_client=chat_client,
     )
 
-    writer = ChatAgent(
-        name="Writer",
-        description="Synthesizes a polished answer using the gathered notes.",
-        instructions="Compose a clear, structured answer using any notes provided.",
-        chat_client=OpenAIChatClient(model_id="gpt-4o-mini"),
+    skeptic = ChatAgent(
+        name="Skeptic",
+        instructions=(
+            "You are a devil's advocate in a workgroup. "
+            f"Your job is to review the answer provided by {expert.name} and point "
+            "out caveats, exceptions, and alternative perspectives. "
+            "If there is nothing worth pointing out, respond with 'I have no further questions.'"
+        ),
+        chat_client=chat_client,
     )
 
-    # Two ways to specify participants:
-    # 1. List form - uses agent.name attribute: .participants([researcher, writer])
-    # 2.
Dict form - explicit names: .participants(researcher=researcher, writer=writer)
 
+    # Build the group chat workflow
     workflow = (
         GroupChatBuilder()
-        .set_select_speakers_func(select_next_speaker, display_name="Orchestrator")
-        .participants([researcher, writer])  # Uses agent.name for participant names
+        .participants([expert, verifier, clarifier, skeptic])
+        .with_select_speaker_func(round_robin_selector)
+        # Set a hard termination condition: stop after 6 messages (user task + one full round + 1)
+        # One round is expert -> verifier -> clarifier -> skeptic, after which the expert gets to respond again.
+        # This ends the conversation after the expert has spoken twice (one full iteration loop).
+        # Note: it's possible that the expert gets it right the first time and the other participants
+        # have nothing to add, but for demo purposes we want to see at least one full round of interaction.
+        .with_termination_condition(lambda conversation: len(conversation) >= 6)
         .build()
     )
 
-    task = "What are the key benefits of using async/await in Python?"
+    task = "How does Python’s Protocol differ from abstract base classes?"
 
-    print("\nStarting Group Chat with Simple Speaker Selector...\n")
+    print("\nStarting Group Chat with round-robin speaker selector...\n")
     print(f"TASK: {task}\n")
     print("=" * 80)
 
+    # Keep track of the last executor to format output nicely in streaming mode
+    last_executor_id: str | None = None
+    output_event: WorkflowOutputEvent | None = None
     async for event in workflow.run_stream(task):
-        if isinstance(event, WorkflowOutputEvent):
-            conversation = cast(list[ChatMessage], event.data)
-            if isinstance(conversation, list):
-                print("\n===== Final Conversation =====\n")
-                for msg in conversation:
-                    author = getattr(msg, "author_name", "Unknown")
-                    text = getattr(msg, "text", str(msg))
-                    print(f"[{author}]\n{text}\n")
-                    print("-" * 80)
-
-    print("\nWorkflow completed.")
+        if isinstance(event, AgentRunUpdateEvent):
+            eid = event.executor_id
+            if eid != last_executor_id:
+                if last_executor_id is not None:
+                    print("\n")
+                print(f"{eid}:", end=" ", flush=True)
+                last_executor_id = eid
+            print(event.data, end="", flush=True)
+        elif isinstance(event, WorkflowOutputEvent):
+            output_event = event
+
+    # The output of the workflow is the full list of messages exchanged
+    if output_event:
+        if not isinstance(output_event.data, list) or not all(
+            isinstance(msg, ChatMessage)
+            for msg in output_event.data  # type: ignore
+        ):
+            raise RuntimeError("Unexpected output event data format.")
+        print("\n" + "=" * 80)
+        print("\nFINAL OUTPUT (The conversation history)\n")
+        for msg in output_event.data:  # type: ignore
+            assert isinstance(msg, ChatMessage)
+            print(f"{msg.author_name or msg.role}: {msg.text}\n")
+    else:
+        raise RuntimeError("Workflow did not produce a final output event.")
 
 
 if __name__ == "__main__":
diff --git a/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py b/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py
index 154f768d09..2ea327751d 100644
--- a/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py
+++ b/python/samples/getting_started/workflows/orchestration/handoff_autonomous.py
@@ -13,6 +13,7 @@
     HostedWebSearchTool,
     WorkflowEvent,
     WorkflowOutputEvent,
+    resolve_agent_id,
 )
 from agent_framework.azure import AzureOpenAIChatClient
 from azure.identity import AzureCliCredential
@@ -21,7 +22,7 @@
 
 """Sample: Autonomous handoff workflow with agent iteration.
-This sample demonstrates `with_interaction_mode("autonomous")`, where agents continue
+This sample demonstrates `.with_autonomous_mode()`, where agents continue
 iterating on their task until they explicitly invoke a handoff tool. This allows
 specialists to perform long-running autonomous work (research, coding, analysis)
 without prematurely returning control to the coordinator or user.
@@ -35,7 +36,7 @@
 
 Key Concepts:
     - Autonomous interaction mode: agents iterate until they hand off
-    - Turn limits: use `with_interaction_mode("autonomous", autonomous_turn_limit=N)` to cap total iterations
+    - Turn limits: use `.with_autonomous_mode(turn_limits={agent_name: N})` to cap iterations per agent
 """
@@ -53,7 +54,7 @@ def create_agents(
 
     research_agent = chat_client.create_agent(
         instructions=(
-            "You are a research specialist that explores topics thoroughly on the Microsoft Learn Site."
+            "You are a research specialist that explores topics thoroughly using web search. "
             "When given a research task, break it down into multiple aspects and explore each one. "
             "Continue your research across multiple responses - don't try to finish everything in one "
             "response. After each response, think about what else needs to be explored. When you have "
@@ -112,11 +113,21 @@ async def main() -> None:
             name="autonomous_iteration_handoff",
             participants=[coordinator, research_agent, summary_agent],
         )
-        .set_coordinator(coordinator)
+        .with_start_agent(coordinator)
         .add_handoff(coordinator, [research_agent, summary_agent])
-        .add_handoff(research_agent, coordinator)  # Research can hand back to coordinator
-        .add_handoff(summary_agent, coordinator)
-        .with_interaction_mode("autonomous", autonomous_turn_limit=15)
+        .add_handoff(research_agent, [coordinator])  # Research can hand back to coordinator
+        .add_handoff(summary_agent, [coordinator])
+        .with_autonomous_mode(
+            # You can set turn limits per agent to allow some agents to go longer.
+            # If a limit is not set, the agent gets a default limit of 50.
+            # Internally, handoff prefers agent names as the agent identifiers if set.
+            # Otherwise, it falls back to agent IDs.
+            turn_limits={
+                resolve_agent_id(coordinator): 5,
+                resolve_agent_id(research_agent): 10,
+                resolve_agent_id(summary_agent): 5,
+            }
+        )
        .with_termination_condition(
             # Terminate after coordinator provides 5 assistant responses
             lambda conv: sum(1 for msg in conv if msg.author_name == "coordinator" and msg.role.value == "assistant")
@@ -133,10 +144,10 @@ async def main() -> None:
     """
     Expected behavior:
     - Coordinator routes to research_agent.
-    - Research agent iterates multiple times, exploring different aspects of renewable energy.
+    - Research agent iterates multiple times, exploring different aspects of the Microsoft Agent Framework.
     - Each iteration adds to the conversation without returning to coordinator.
     - After thorough research, research_agent calls handoff to coordinator.
-    - Coordinator provides final summary.
+    - Coordinator routes to summary_agent for final summary.
 
     In autonomous mode, agents continue working until they invoke a handoff tool,
     allowing the research_agent to perform 3-4+ responses before handing off.
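
[Editor's note] For orientation, a minimal driver loop for a workflow built this way could look like the sketch below. The task string and the print formatting are illustrative assumptions; `workflow.run_stream`, `AgentRunEvent`, and `WorkflowOutputEvent` are the same APIs used by the other handoff samples in this change set.

    # Hedged sketch: drive an autonomous handoff workflow and print each iteration.
    from agent_framework import AgentRunEvent, Workflow, WorkflowOutputEvent

    async def run_autonomous_demo(workflow: Workflow) -> None:
        # Hypothetical task; any research prompt exercises the autonomous loop.
        task = "Research the Microsoft Agent Framework and summarize your findings."
        async for event in workflow.run_stream(task):
            if isinstance(event, AgentRunEvent):
                # Each autonomous iteration surfaces as one agent run event.
                for message in event.data.messages:
                    if message.text:
                        print(f"- {message.author_name or message.role.value}: {message.text}")
            elif isinstance(event, WorkflowOutputEvent):
                # The workflow output is the accumulated conversation.
                print(f"\n[Done: {len(event.data)} messages in the final conversation]")

Pass in a workflow built as shown above; each agent run prints as it completes, so the research agent's multiple autonomous iterations are visible before the handoff back to the coordinator.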
diff --git a/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py b/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py index 1b676c5ffd..58a562c4cf 100644 --- a/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py +++ b/python/samples/getting_started/workflows/orchestration/handoff_participant_factory.py @@ -2,28 +2,30 @@ import asyncio import logging -from collections.abc import AsyncIterable -from typing import cast +from typing import Annotated, cast from agent_framework import ( + AgentRunEvent, + AgentRunResponse, ChatAgent, ChatMessage, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, + HandoffSentEvent, RequestInfoEvent, - Role, Workflow, WorkflowEvent, WorkflowOutputEvent, + WorkflowRunState, + WorkflowStatusEvent, ai_function, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -from typing import Annotated logging.basicConfig(level=logging.ERROR) -"""Sample: Autonomous handoff workflow with agent factory. +"""Sample: Handoff workflow with participant factories for state isolation. This sample demonstrates how to use participant factories in HandoffBuilder to create agents dynamically. @@ -33,7 +35,7 @@ requests or tasks in parallel with stateful participants. Routing Pattern: - User -> Coordinator -> Specialist (iterates N times) -> Handoff -> Final Output + User -> Triage Agent -> Specialist (Refund/Order Status/Return) -> User Prerequisites: - `az login` (Azure CLI authentication) @@ -41,6 +43,7 @@ Key Concepts: - Participant factories: create agents via factory functions for isolation + - State isolation: each workflow instance gets its own agent instances """ @@ -103,21 +106,6 @@ def create_return_agent() -> ChatAgent: ) -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream into a list. - - This helper drains the workflow's event stream so we can process events - synchronously after each workflow step completes. - - Args: - stream: Async iterable of WorkflowEvent - - Returns: - List of all events from the stream - """ - return [event async for event in stream] - - def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: """Process workflow events and extract any pending user input requests. 
@@ -136,75 +124,98 @@ def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: requests: list[RequestInfoEvent] = [] for event in events: + # AgentRunEvent: Contains messages generated by agents during their turn + if isinstance(event, AgentRunEvent): + for message in event.data.messages: + if not message.text: + # Skip messages without text (e.g., tool calls) + continue + speaker = message.author_name or message.role.value + print(f"- {speaker}: {message.text}") + + # HandoffSentEvent: Indicates a handoff has been initiated + if isinstance(event, HandoffSentEvent): + print(f"\n[Handoff from {event.source} to {event.target} initiated.]") + + # WorkflowStatusEvent: Indicates workflow state changes + if isinstance(event, WorkflowStatusEvent) and event.state in { + WorkflowRunState.IDLE, + WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, + }: + print(f"\n[Workflow Status] {event.state.name}") + # WorkflowOutputEvent: Contains the final conversation when workflow terminates - if isinstance(event, WorkflowOutputEvent): + elif isinstance(event, WorkflowOutputEvent): conversation = cast(list[ChatMessage], event.data) if isinstance(conversation, list): print("\n=== Final Conversation Snapshot ===") for message in conversation: speaker = message.author_name or message.role.value - print(f"- {speaker}: {message.text}") + print(f"- {speaker}: {message.text or [content.type for content in message.contents]}") print("===================================") # RequestInfoEvent: Workflow is requesting user input elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - _print_agent_responses_since_last_user_message(event.data) + if isinstance(event.data, HandoffAgentUserRequest): + _print_handoff_agent_user_request(event.data.agent_response) requests.append(event) return requests -def _print_agent_responses_since_last_user_message(request: HandoffUserInputRequest) -> None: - """Display agent responses since the last user message in a handoff request. +def _print_handoff_agent_user_request(response: AgentRunResponse) -> None: + """Display the agent's response messages when requesting user input. - The HandoffUserInputRequest contains the full conversation history so far, - allowing the user to see what's been discussed before providing their next input. + This will happen when an agent generates a response that doesn't trigger + a handoff, i.e., the agent is asking the user for more information. 
    Args:
-        request: The user input request containing conversation and prompt
+        response: The AgentRunResponse from the agent requesting user input
     """
-    if not request.conversation:
-        raise RuntimeError("HandoffUserInputRequest missing conversation history.")
-
-    # Reverse iterate to collect agent responses since last user message
-    agent_responses: list[ChatMessage] = []
-    for message in request.conversation[::-1]:
-        if message.role == Role.USER:
-            break
-        agent_responses.append(message)
-
-    # Print agent responses in original order
-    agent_responses.reverse()
-    for message in agent_responses:
+    if not response.messages:
+        raise RuntimeError("Cannot print agent responses: response has no messages.")
+
+    print("\n[Agent is requesting your input...]")
+
+    # Print agent responses
+    for message in response.messages:
+        if not message.text:
+            # Skip messages without text (e.g., tool calls)
+            continue
         speaker = message.author_name or message.role.value
         print(f"- {speaker}: {message.text}")
 
 
-async def _run_Workflow(workflow: Workflow, user_inputs: list[str]) -> None:
+async def _run_workflow(workflow: Workflow, user_inputs: list[str]) -> None:
     """Run the workflow with the given scripted user inputs and display events."""
     print(f"- User: {user_inputs[0]}")
-    events = await _drain(workflow.run_stream(user_inputs[0]))
-    pending_requests = _handle_events(events)
+    workflow_result = await workflow.run(user_inputs[0])
+    pending_requests = _handle_events(workflow_result)
 
     # Process the request/response cycle
     # The workflow will continue requesting input until:
     # 1. The termination condition is met, OR
     # 2. We run out of scripted responses
-    while pending_requests and user_inputs[1:]:
-        # Get the next scripted response
-        user_response = user_inputs.pop(1)
-        print(f"\n- User: {user_response}")
-
-        # Send response(s) to all pending requests
-        # In this demo, there's typically one request per cycle, but the API supports multiple
-        responses = {req.request_id: user_response for req in pending_requests}
+    while pending_requests:
+        if user_inputs[1:]:
+            # Get the next scripted response
+            user_response = user_inputs.pop(1)
+            print(f"\n- User: {user_response}")
+
+            # Send response(s) to all pending requests
+            # In this demo, there's typically one request per cycle, but the API supports multiple
+            responses = {
+                req.request_id: HandoffAgentUserRequest.create_response(user_response) for req in pending_requests
+            }
+        else:
+            # No more scripted responses; terminate the workflow
+            responses = {req.request_id: HandoffAgentUserRequest.terminate() for req in pending_requests}
 
         # Send responses and get new events
         # We use send_responses() to get events from the workflow, allowing us to
         # display agent responses and handle new requests as they arrive
-        events = await _drain(workflow.send_responses_streaming(responses))
-        pending_requests = _handle_events(events)
+        workflow_result = await workflow.send_responses(responses)
+        pending_requests = _handle_events(workflow_result)
 
 
 async def main() -> None:
@@ -220,7 +231,7 @@ async def main() -> None:
             "return": create_return_agent,
         },
     )
-    .set_coordinator("triage")
+    .with_start_agent("triage")
     .with_termination_condition(
         # Custom termination: Check if the triage agent has provided a closing message.
# This looks for the last message being from triage_agent and containing "welcome", @@ -244,14 +255,14 @@ async def main() -> None: workflow_a = workflow_builder.build() print("=== Running workflow_a ===") - await _run_Workflow(workflow_a, list(user_inputs)) + await _run_workflow(workflow_a, list(user_inputs)) workflow_b = workflow_builder.build() print("=== Running workflow_b ===") # Only provide the last two inputs to workflow_b to demonstrate state isolation # The agents in this workflow have no prior context thus should not have knowledge of # order 1234 or previous interactions. - await _run_Workflow(workflow_b, user_inputs[2:]) + await _run_workflow(workflow_b, user_inputs[2:]) """ Expected behavior: - workflow_a and workflow_b maintain separate states for their participants. diff --git a/python/samples/getting_started/workflows/orchestration/handoff_return_to_previous.py b/python/samples/getting_started/workflows/orchestration/handoff_return_to_previous.py deleted file mode 100644 index 8f859bfb0f..0000000000 --- a/python/samples/getting_started/workflows/orchestration/handoff_return_to_previous.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -from collections.abc import AsyncIterable -from typing import cast - -from agent_framework import ( - ChatAgent, - HandoffBuilder, - HandoffUserInputRequest, - RequestInfoEvent, - WorkflowEvent, - WorkflowOutputEvent, -) -from agent_framework.azure import AzureOpenAIChatClient -from azure.identity import AzureCliCredential - -"""Sample: Handoff workflow with return-to-previous routing enabled. - -This interactive sample demonstrates the return-to-previous feature where user inputs -route directly back to the specialist currently handling their request, rather than -always going through the coordinator for re-evaluation. - -Routing Pattern (with return-to-previous enabled): - User -> Coordinator -> Technical Support -> User -> Technical Support -> ... - -Routing Pattern (default, without return-to-previous): - User -> Coordinator -> Technical Support -> User -> Coordinator -> Technical Support -> ... - -This is useful when a specialist needs multiple turns with the user to gather -information or resolve an issue, avoiding unnecessary coordinator involvement. - -Specialist-to-Specialist Handoff: - When a user's request changes to a topic outside the current specialist's domain, - the specialist can hand off DIRECTLY to another specialist without going back through - the coordinator: - - User -> Coordinator -> Technical Support -> User -> Technical Support (billing question) - -> Billing -> User -> Billing ... - -Example Interaction: - 1. User reports a technical issue - 2. Coordinator routes to technical support specialist - 3. Technical support asks clarifying questions - 4. User provides details (routes directly back to technical support) - 5. Technical support continues troubleshooting with full context - 6. Issue resolved, user asks about billing - 7. Technical support hands off DIRECTLY to billing specialist - 8. Billing specialist helps with payment - 9. User continues with billing (routes directly to billing) - -Prerequisites: - - `az login` (Azure CLI authentication) - - Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) - -Usage: - Run the script and interact with the support workflow by typing your requests. - Type 'exit' or 'quit' to end the conversation. 
- -Key Concepts: - - Return-to-previous: Direct routing to current agent handling the conversation - - Current agent tracking: Framework remembers which agent is actively helping the user - - Context preservation: Specialist maintains full conversation context - - Domain switching: Specialists can hand back to coordinator when topic changes -""" - - -def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]: - """Create and configure the coordinator and specialist agents. - - Returns: - Tuple of (coordinator, technical_support, account_specialist, billing_agent) - """ - coordinator = chat_client.create_agent( - instructions=( - "You are a customer support coordinator. Analyze the user's request and route to " - "the appropriate specialist:\n" - "- technical_support for technical issues, troubleshooting, repairs, hardware/software problems\n" - "- account_specialist for account changes, profile updates, settings, login issues\n" - "- billing_agent for payments, invoices, refunds, charges, billing questions\n" - "\n" - "When you receive a request, immediately call the matching handoff tool without explaining. " - "Read the most recent user message to determine the correct specialist." - ), - name="coordinator", - ) - - technical_support = chat_client.create_agent( - instructions=( - "You provide technical support. Help users troubleshoot technical issues, " - "arrange repairs, and answer technical questions. " - "Gather information through conversation. " - "If the user asks about billing, payments, invoices, or refunds, hand off to billing_agent. " - "If the user asks about account settings or profile changes, hand off to account_specialist." - ), - name="technical_support", - ) - - account_specialist = chat_client.create_agent( - instructions=( - "You handle account management. Help with profile updates, account settings, " - "and preferences. Gather information through conversation. " - "If the user asks about technical issues or troubleshooting, hand off to technical_support. " - "If the user asks about billing, payments, invoices, or refunds, hand off to billing_agent." - ), - name="account_specialist", - ) - - billing_agent = chat_client.create_agent( - instructions=( - "You handle billing only. Process payments, explain invoices, handle refunds. " - "If the user asks about technical issues or troubleshooting, hand off to technical_support. " - "If the user asks about account settings or profile changes, hand off to account_specialist." 
- ), - name="billing_agent", - ) - - return coordinator, technical_support, account_specialist, billing_agent - - -def handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: - """Process events and return pending input requests.""" - pending_requests: list[RequestInfoEvent] = [] - for event in events: - if isinstance(event, RequestInfoEvent): - pending_requests.append(event) - request_data = cast(HandoffUserInputRequest, event.data) - print(f"\n{'=' * 60}") - print(f"AWAITING INPUT FROM: {request_data.awaiting_agent_id.upper()}") - print(f"{'=' * 60}") - for msg in request_data.conversation[-3:]: - author = msg.author_name or msg.role.value - prefix = ">>> " if author == request_data.awaiting_agent_id else " " - print(f"{prefix}[{author}]: {msg.text}") - elif isinstance(event, WorkflowOutputEvent): - print(f"\n{'=' * 60}") - print("[WORKFLOW COMPLETE]") - print(f"{'=' * 60}") - return pending_requests - - -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Drain an async iterable into a list.""" - events: list[WorkflowEvent] = [] - async for event in stream: - events.append(event) - return events - - -async def main() -> None: - """Demonstrate return-to-previous routing in a handoff workflow.""" - chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - coordinator, technical, account, billing = create_agents(chat_client) - - print("Handoff Workflow with Return-to-Previous Routing") - print("=" * 60) - print("\nThis interactive demo shows how user inputs route directly") - print("to the specialist handling your request, avoiding unnecessary") - print("coordinator re-evaluation on each turn.") - print("\nSpecialists can hand off directly to other specialists when") - print("your request changes topics (e.g., from technical to billing).") - print("\nType 'exit' or 'quit' to end the conversation.\n") - - # Configure handoffs with return-to-previous enabled - # Specialists can hand off directly to other specialists when topic changes - workflow = ( - HandoffBuilder( - name="return_to_previous_demo", - participants=[coordinator, technical, account, billing], - ) - .set_coordinator(coordinator) - .add_handoff(coordinator, [technical, account, billing]) # Coordinator routes to all specialists - .add_handoff(technical, [billing, account]) # Technical can route to billing or account - .add_handoff(account, [technical, billing]) # Account can route to technical or billing - .add_handoff(billing, [technical, account]) # Billing can route to technical or account - .enable_return_to_previous(True) # Enable the `return to previous handoff` feature - .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 10) - .build() - ) - - # Get initial user request - initial_request = input("You: ").strip() # noqa: ASYNC250 - if not initial_request or initial_request.lower() in ["exit", "quit"]: - print("Goodbye!") - return - - # Start workflow with initial message - events = await _drain(workflow.run_stream(initial_request)) - pending_requests = handle_events(events) - - # Interactive loop: keep prompting for user input - while pending_requests: - user_input = input("\nYou: ").strip() # noqa: ASYNC250 - - if not user_input or user_input.lower() in ["exit", "quit"]: - print("\nEnding conversation. 
Goodbye!") - break - - responses = {req.request_id: user_input for req in pending_requests} - events = await _drain(workflow.send_responses_streaming(responses)) - pending_requests = handle_events(events) - - print("\n" + "=" * 60) - print("Conversation ended.") - - """ - Sample Output: - - Handoff Workflow with Return-to-Previous Routing - ============================================================ - - This interactive demo shows how user inputs route directly - to the specialist handling your request, avoiding unnecessary - coordinator re-evaluation on each turn. - - Specialists can hand off directly to other specialists when - your request changes topics (e.g., from technical to billing). - - Type 'exit' or 'quit' to end the conversation. - - You: I need help with my bill, I was charged twice by mistake. - - ============================================================ - AWAITING INPUT FROM: BILLING_AGENT - ============================================================ - [user]: I need help with my bill, I was charged twice by mistake. - [coordinator]: You will be connected to a billing agent who can assist you with the double charge on your bill. - >>> [billing_agent]: I'm here to help with billing concerns! I'm sorry you were charged twice. Could you - please provide the invoice number or your account email so I can look into this and begin processing a refund? - - You: Invoice 1234 - - ============================================================ - AWAITING INPUT FROM: BILLING_AGENT - ============================================================ - >>> [billing_agent]: I'm here to help with billing concerns! I'm sorry you were charged twice. - Could you please provide the invoice number or your account email so I can look into this and begin - processing a refund? - [user]: Invoice 1234 - >>> [billing_agent]: Thank you for providing the invoice number (1234). I will review the details and work - on processing a refund for the duplicate charge. - - Can you confirm which payment method you used for this bill (e.g., credit card, PayPal)? - This helps ensure your refund is processed to the correct account. - - You: I used my credit card, which is on autopay. - - ============================================================ - AWAITING INPUT FROM: BILLING_AGENT - ============================================================ - >>> [billing_agent]: Thank you for providing the invoice number (1234). I will review the details and work on - processing a refund for the duplicate charge. - - Can you confirm which payment method you used for this bill (e.g., credit card, PayPal)? This helps ensure - your refund is processed to the correct account. - [user]: I used my credit card, which is on autopay. - >>> [billing_agent]: Thank you for confirming your payment method. I will look into invoice 1234 and - process a refund for the duplicate charge to your credit card. - - You will receive a notification once the refund is completed. If you have any further questions about your billing - or need an update, please let me know! - - You: Actually I also can't turn on my modem. It reset and now won't turn on. - - ============================================================ - AWAITING INPUT FROM: TECHNICAL_SUPPORT - ============================================================ - [user]: Actually I also can't turn on my modem. It reset and now won't turn on. - [billing_agent]: I'm connecting you with technical support for assistance with your modem not turning on after - the reset. 
They'll be able to help troubleshoot and resolve this issue. - - At the same time, technical support will also handle your refund request for the duplicate charge on invoice 1234 - to your credit card on autopay. - - You will receive updates from the appropriate teams shortly. - >>> [technical_support]: Thanks for letting me know about your modem issue! To help you further, could you tell me: - - 1. Is there any light showing on the modem at all, or is it completely off? - 2. Have you tried unplugging the modem from power and plugging it back in? - 3. Do you hear or feel anything (like a slight hum or vibration) when the modem is plugged in? - - Let me know, and I'll guide you through troubleshooting or arrange a repair if needed. - - You: exit - - Ending conversation. Goodbye! - - ============================================================ - Conversation ended. - """ - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/handoff_simple.py b/python/samples/getting_started/workflows/orchestration/handoff_simple.py index 84b6e0f243..32fd3ba441 100644 --- a/python/samples/getting_started/workflows/orchestration/handoff_simple.py +++ b/python/samples/getting_started/workflows/orchestration/handoff_simple.py @@ -1,16 +1,17 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import AsyncIterable from typing import Annotated, cast from agent_framework import ( + AgentRunEvent, + AgentRunResponse, ChatAgent, ChatMessage, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, + HandoffSentEvent, RequestInfoEvent, - Role, WorkflowEvent, WorkflowOutputEvent, WorkflowRunState, @@ -20,27 +21,16 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -"""Sample: Simple handoff workflow with single-tier triage-to-specialist routing. +"""Sample: Simple handoff workflow. -This sample demonstrates the basic handoff pattern where only the triage agent can -route to specialists. Specialists cannot hand off to other specialists - after any -specialist responds, control returns to the user (via the triage agent) for the next input. - -Routing Pattern: - User → Triage Agent → Specialist → Triage Agent → User → Triage Agent → ... - -This is the simplest handoff configuration, suitable for straightforward support -scenarios where a triage agent dispatches to domain specialists, and each specialist -works independently. - -For multi-tier specialist-to-specialist handoffs, see handoff_specialist_to_specialist.py. +A handoff workflow defines a pattern that assembles agents in a mesh topology, allowing +them to transfer control to each other based on the conversation context. Prerequisites: - `az login` (Azure CLI authentication) - Environment variables configured for AzureOpenAIChatClient (AZURE_OPENAI_ENDPOINT, etc.) Key Concepts: - - Single-tier routing: Only triage agent has handoff capabilities - Auto-registered handoff tools: HandoffBuilder automatically creates handoff tools for each participant, allowing the coordinator to transfer control to specialists - Termination condition: Controls when the workflow stops requesting user input @@ -69,14 +59,8 @@ def process_return(order_number: Annotated[str, "Order number to process return def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAgent, ChatAgent, ChatAgent]: """Create and configure the triage and specialist agents. 
- The triage agent is responsible for: - - Receiving all user input first - - Deciding whether to handle the request directly or hand off to a specialist - - Signaling handoff by calling one of the explicit handoff tools exposed to it - - Specialist agents are invoked only when the triage agent explicitly hands off to them. - After a specialist responds, control returns to the triage agent, which then prompts - the user for their next message. + Args: + chat_client: The AzureOpenAIChatClient to use for creating agents. Returns: Tuple of (triage_agent, refund_agent, order_agent, return_agent) @@ -117,21 +101,6 @@ def create_agents(chat_client: AzureOpenAIChatClient) -> tuple[ChatAgent, ChatAg return triage_agent, refund_agent, order_agent, return_agent -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream into a list. - - This helper drains the workflow's event stream so we can process events - synchronously after each workflow step completes. - - Args: - stream: Async iterable of WorkflowEvent - - Returns: - List of all events from the stream - """ - return [event async for event in stream] - - def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: """Process workflow events and extract any pending user input requests. @@ -150,6 +119,19 @@ def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: requests: list[RequestInfoEvent] = [] for event in events: + # AgentRunEvent: Contains messages generated by agents during their turn + if isinstance(event, AgentRunEvent): + for message in event.data.messages: + if not message.text: + # Skip messages without text (e.g., tool calls) + continue + speaker = message.author_name or message.role.value + print(f"- {speaker}: {message.text}") + + # HandoffSentEvent: Indicates a handoff has been initiated + if isinstance(event, HandoffSentEvent): + print(f"\n[Handoff from {event.source} to {event.target} initiated.]") + # WorkflowStatusEvent: Indicates workflow state changes if isinstance(event, WorkflowStatusEvent) and event.state in { WorkflowRunState.IDLE, @@ -164,40 +146,37 @@ def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: print("\n=== Final Conversation Snapshot ===") for message in conversation: speaker = message.author_name or message.role.value - print(f"- {speaker}: {message.text}") + print(f"- {speaker}: {message.text or [content.type for content in message.contents]}") print("===================================") # RequestInfoEvent: Workflow is requesting user input elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - _print_agent_responses_since_last_user_message(event.data) + if isinstance(event.data, HandoffAgentUserRequest): + _print_handoff_agent_user_request(event.data.agent_response) requests.append(event) return requests -def _print_agent_responses_since_last_user_message(request: HandoffUserInputRequest) -> None: - """Display agent responses since the last user message in a handoff request. +def _print_handoff_agent_user_request(response: AgentRunResponse) -> None: + """Display the agent's response messages when requesting user input. - The HandoffUserInputRequest contains the full conversation history so far, - allowing the user to see what's been discussed before providing their next input. + This will happen when an agent generates a response that doesn't trigger + a handoff, i.e., the agent is asking the user for more information. 
Args: - request: The user input request containing conversation and prompt + response: The AgentRunResponse from the agent requesting user input """ - if not request.conversation: - raise RuntimeError("HandoffUserInputRequest missing conversation history.") - - # Reverse iterate to collect agent responses since last user message - agent_responses: list[ChatMessage] = [] - for message in request.conversation[::-1]: - if message.role == Role.USER: - break - agent_responses.append(message) - - # Print agent responses in original order - agent_responses.reverse() - for message in agent_responses: + if not response.messages: + raise RuntimeError("Cannot print agent responses: response has no messages.") + + print("\n[Agent is requesting your input...]") + + # Print agent responses + for message in response.messages: + if not message.text: + # Skip messages without text (e.g., tool calls) + continue speaker = message.author_name or message.role.value print(f"- {speaker}: {message.text}") @@ -223,25 +202,23 @@ async def main() -> None: # Build the handoff workflow # - participants: All agents that can participate in the workflow - # - set_coordinator: The triage agent is designated as the coordinator, which means + # - with_start_agent: The triage agent is designated as the start agent, which means # it receives all user input first and orchestrates handoffs to specialists # - with_termination_condition: Custom logic to stop the request/response loop. # Without this, the default behavior continues requesting user input until max_turns # is reached. Here we use a custom condition that checks if the conversation has ended - # naturally (when triage agent says something like "you're welcome"). + # naturally (when one of the agents says something like "you're welcome"). workflow = ( HandoffBuilder( name="customer_support_handoff", participants=[triage, refund, order, support], ) - .set_coordinator(triage) + .with_start_agent(triage) .with_termination_condition( - # Custom termination: Check if the triage agent has provided a closing message. - # This looks for the last message being from triage_agent and containing "welcome", - # which indicates the conversation has concluded naturally. - lambda conversation: len(conversation) > 0 - and conversation[-1].author_name == "triage_agent" - and "welcome" in conversation[-1].text.lower() + # Custom termination: Check if one of the agents has provided a closing message. + # This looks for the last message containing "welcome", which indicates the + # conversation has concluded naturally. + lambda conversation: len(conversation) > 0 and "welcome" in conversation[-1].text.lower() ) .build() ) @@ -252,6 +229,7 @@ async def main() -> None: # or integrate with a UI/chat interface scripted_responses = [ "My order 1234 arrived damaged and the packaging was destroyed. I'd like to return it.", + "Please also process a refund for order 1234.", "Thanks for resolving this.", ] @@ -260,26 +238,32 @@ async def main() -> None: print("[Starting workflow with initial user message...]\n") initial_message = "Hello, I need assistance with my recent purchase." print(f"- User: {initial_message}") - events = await _drain(workflow.run_stream(initial_message)) - pending_requests = _handle_events(events) + workflow_result = await workflow.run(initial_message) + pending_requests = _handle_events(workflow_result) # Process the request/response cycle # The workflow will continue requesting input until: - # 1. The termination condition is met (4 user messages in this case), OR + # 1. 
The termination condition is met, OR # 2. We run out of scripted responses - while pending_requests and scripted_responses: - # Get the next scripted response - user_response = scripted_responses.pop(0) - print(f"\n- User: {user_response}") - - # Send response(s) to all pending requests - # In this demo, there's typically one request per cycle, but the API supports multiple - responses = {req.request_id: user_response for req in pending_requests} + while pending_requests: + if not scripted_responses: + # No more scripted responses; terminate the workflow + responses = {req.request_id: HandoffAgentUserRequest.terminate() for req in pending_requests} + else: + # Get the next scripted response + user_response = scripted_responses.pop(0) + print(f"\n- User: {user_response}") + + # Send response(s) to all pending requests + # In this demo, there's typically one request per cycle, but the API supports multiple + responses = { + req.request_id: HandoffAgentUserRequest.create_response(user_response) for req in pending_requests + } # Send responses and get new events - # We use send_responses_streaming() to get events as they occur, allowing us to - # display agent responses in real-time and handle new requests as they arrive - events = await _drain(workflow.send_responses_streaming(responses)) + # We use send_responses() to get events from the workflow, allowing us to + # display agent responses and handle new requests as they arrive + events = await workflow.send_responses(responses) pending_requests = _handle_events(events) """ diff --git a/python/samples/getting_started/workflows/orchestration/handoff_specialist_to_specialist.py b/python/samples/getting_started/workflows/orchestration/handoff_specialist_to_specialist.py deleted file mode 100644 index dfc9f0f73b..0000000000 --- a/python/samples/getting_started/workflows/orchestration/handoff_specialist_to_specialist.py +++ /dev/null @@ -1,284 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Sample: Multi-tier handoff workflow with specialist-to-specialist routing. - -This sample demonstrates advanced handoff routing where specialist agents can hand off -to other specialists, enabling complex multi-tier workflows. Unlike the simple handoff -pattern (see handoff_simple.py), specialists here can delegate to other specialists -without returning control to the user until the specialist chain completes. - -Routing Pattern: - User → Triage → Specialist A → Specialist B → Back to User - -This pattern is useful for complex support scenarios where different specialists need -to collaborate or escalate to each other before returning to the user. For example: - - Replacement agent needs shipping info → hands off to delivery agent - - Technical support needs billing info → hands off to billing agent - - Level 1 support escalates to Level 2 → hands off to escalation agent - -Configuration uses `.add_handoff()` to explicitly define the routing graph. 
- -Prerequisites: - - `az login` (Azure CLI authentication) - - Environment variables configured for AzureOpenAIChatClient -""" - -import asyncio -from collections.abc import AsyncIterable -from typing import cast - -from agent_framework import ( - ChatMessage, - HandoffBuilder, - HandoffUserInputRequest, - RequestInfoEvent, - WorkflowEvent, - WorkflowOutputEvent, - WorkflowRunState, - WorkflowStatusEvent, -) -from agent_framework.azure import AzureOpenAIChatClient -from azure.identity import AzureCliCredential - - -def create_agents(chat_client: AzureOpenAIChatClient): - """Create triage and specialist agents with multi-tier handoff capabilities. - - Returns: - Tuple of (triage_agent, replacement_agent, delivery_agent, billing_agent) - """ - triage = chat_client.create_agent( - instructions=( - "You are a customer support triage agent. Assess the user's issue and route appropriately:\n" - "- For product replacement issues: call handoff_to_replacement_agent\n" - "- For delivery/shipping inquiries: call handoff_to_delivery_agent\n" - "- For billing/payment issues: call handoff_to_billing_agent\n" - "Be concise and friendly." - ), - name="triage_agent", - ) - - replacement = chat_client.create_agent( - instructions=( - "You handle product replacement requests. Ask for order number and reason for replacement.\n" - "If the user also needs shipping/delivery information, call handoff_to_delivery_agent to " - "get tracking details. Otherwise, process the replacement and confirm with the user.\n" - "Be concise and helpful." - ), - name="replacement_agent", - ) - - delivery = chat_client.create_agent( - instructions=( - "You handle shipping and delivery inquiries. Provide tracking information, estimated " - "delivery dates, and address any delivery concerns.\n" - "If billing issues come up, call handoff_to_billing_agent.\n" - "Be concise and clear." - ), - name="delivery_agent", - ) - - billing = chat_client.create_agent( - instructions=( - "You handle billing and payment questions. Help with refunds, payment methods, " - "and invoice inquiries. Be concise." 
- ), - name="billing_agent", - ) - - return triage, replacement, delivery, billing - - -async def _drain(stream: AsyncIterable[WorkflowEvent]) -> list[WorkflowEvent]: - """Collect all events from an async stream into a list.""" - return [event async for event in stream] - - -def _handle_events(events: list[WorkflowEvent]) -> list[RequestInfoEvent]: - """Process workflow events and extract pending user input requests.""" - requests: list[RequestInfoEvent] = [] - - for event in events: - if isinstance(event, WorkflowStatusEvent) and event.state in { - WorkflowRunState.IDLE, - WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, - }: - print(f"[status] {event.state.name}") - - elif isinstance(event, WorkflowOutputEvent): - conversation = cast(list[ChatMessage], event.data) - if isinstance(conversation, list): - print("\n=== Final Conversation ===") - for message in conversation: - # Filter out messages with no text (tool calls) - if not message.text.strip(): - continue - speaker = message.author_name or message.role.value - print(f"- {speaker}: {message.text}") - print("==========================") - - elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - _print_handoff_request(event.data) - requests.append(event) - - return requests - - -def _print_handoff_request(request: HandoffUserInputRequest) -> None: - """Display a user input request with conversation context.""" - print("\n=== User Input Requested ===") - # Filter out messages with no text for cleaner display - messages_with_text = [msg for msg in request.conversation if msg.text.strip()] - print(f"Last {len(messages_with_text)} messages in conversation:") - for message in messages_with_text[-5:]: # Show last 5 for brevity - speaker = message.author_name or message.role.value - text = message.text[:100] + "..." if len(message.text) > 100 else message.text - print(f" {speaker}: {text}") - print("============================") - - -async def main() -> None: - """Demonstrate specialist-to-specialist handoffs in a multi-tier support scenario. - - This sample shows: - 1. Triage agent routes to replacement specialist - 2. Replacement specialist hands off to delivery specialist - 3. Delivery specialist can hand off to billing if needed - 4. All transitions are seamless without returning to user until complete - - The workflow configuration explicitly defines which agents can hand off to which others: - - triage_agent → replacement_agent, delivery_agent, billing_agent - - replacement_agent → delivery_agent, billing_agent - - delivery_agent → billing_agent - """ - chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - triage, replacement, delivery, billing = create_agents(chat_client) - - # Configure multi-tier handoffs using fluent add_handoff() API - # This allows specialists to hand off to other specialists - workflow = ( - HandoffBuilder( - name="multi_tier_support", - participants=[triage, replacement, delivery, billing], - ) - .set_coordinator(triage) - .add_handoff(triage, [replacement, delivery, billing]) # Triage can route to any specialist - .add_handoff(replacement, [delivery, billing]) # Replacement can delegate to delivery or billing - .add_handoff(delivery, billing) # Delivery can escalate to billing - # Termination condition: Stop when more than 3 user messages exist. - # This allows agents to respond to the 3rd user message before the 4th triggers termination. - # In this sample: initial message + 3 scripted responses = 4 messages, then workflow ends. 
- .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") > 3) - .build() - ) - - # Scripted user responses simulating a multi-tier handoff scenario - # Note: The initial run_stream() call sends the first user message, - # then these scripted responses are sent in sequence (total: 4 user messages). - # A 5th response triggers termination after agents respond to the 4th message. - scripted_responses = [ - "I need help with order 12345. I want a replacement and need to know when it will arrive.", - "The item arrived damaged. I'd like a replacement shipped to the same address.", - "Great! Can you confirm the shipping cost won't be charged again?", - "Thank you!", # Final response to trigger termination after billing agent answers - ] - - print("\n" + "=" * 80) - print("SPECIALIST-TO-SPECIALIST HANDOFF DEMONSTRATION") - print("=" * 80) - print("\nScenario: Customer needs replacement + shipping info + billing confirmation") - print("Expected flow: User → Triage → Replacement → Delivery → Billing → User") - print("=" * 80 + "\n") - - # Start workflow with initial message - print(f"[User]: {scripted_responses[0]}\n") - events = await _drain(workflow.run_stream(scripted_responses[0])) - pending_requests = _handle_events(events) - - # Process scripted responses - response_index = 1 - while pending_requests and response_index < len(scripted_responses): - user_response = scripted_responses[response_index] - print(f"\n[User]: {user_response}\n") - - responses = {req.request_id: user_response for req in pending_requests} - events = await _drain(workflow.send_responses_streaming(responses)) - pending_requests = _handle_events(events) - - response_index += 1 - - """ - Sample Output: - - ================================================================================ - SPECIALIST-TO-SPECIALIST HANDOFF DEMONSTRATION - ================================================================================ - - Scenario: Customer needs replacement + shipping info + billing confirmation - Expected flow: User → Triage → Replacement → Delivery → Billing → User - ================================================================================ - - [User]: I need help with order 12345. I want a replacement and need to know when it will arrive. - - - === User Input Requested === - Last 5 messages in conversation: - user: I need help with order 12345. I want a replacement and need to know when it will arrive. - triage_agent: I am connecting you to our replacement agent to assist with your replacement request and to our deli... - replacement_agent: I have connected you to our agents who will assist with your replacement request for order 12345 and... - delivery_agent: For your replacement request and delivery details regarding order 12345, I'll connect you to the app... - billing_agent: I don’t have access to order details. Please contact the seller or customer service directly for rep... - ============================ - [status] IDLE_WITH_PENDING_REQUESTS - - [User]: The item arrived damaged. I'd like a replacement shipped to the same address. - - - === User Input Requested === - Last 8 messages in conversation: - delivery_agent: For your replacement request and delivery details regarding order 12345, I'll connect you to the app... - billing_agent: I don’t have access to order details. Please contact the seller or customer service directly for rep... - user: The item arrived damaged. I'd like a replacement shipped to the same address. 
- triage_agent: I'm connecting you to our replacement agent who will assist you with getting a replacement shipped t... - replacement_agent: Thank you for the info. I'll start the replacement process for your damaged item on order 12345 and ... - ============================ - [status] IDLE_WITH_PENDING_REQUESTS - - [User]: Great! Can you confirm the shipping cost won't be charged again? - - - === User Input Requested === - Last 11 messages in conversation: - triage_agent: I'm connecting you to our replacement agent who will assist you with getting a replacement shipped t... - replacement_agent: Thank you for the info. I'll start the replacement process for your damaged item on order 12345 and ... - user: Great! Can you confirm the shipping cost won't be charged again? - triage_agent: I'm connecting you to our billing agent who can confirm whether the shipping cost will be charged ag... - billing_agent: Replacements for damaged items are typically shipped at no extra shipping cost. I recommend confirmi... - ============================ - [status] IDLE_WITH_PENDING_REQUESTS - - [User]: Thank you! - - - === Final Conversation === - - user: I need help with order 12345. I want a replacement and need to know when it will arrive. - - triage_agent: I am connecting you to our replacement agent to assist with your replacement request and to our delivery agent who can provide information on the arrival of your replacement. They will help you shortly! - - replacement_agent: I have connected you to our agents who will assist with your replacement request for order 12345 and provide shipping and delivery information. They will help you with the details shortly. If you have any other questions, feel free to ask! - - delivery_agent: For your replacement request and delivery details regarding order 12345, I'll connect you to the appropriate agent who can assist with both billing and delivery concerns. They will provide you with the estimated arrival date and replacement process. Is there anything else you'd like to know meanwhile? - - billing_agent: I don’t have access to order details. Please contact the seller or customer service directly for replacement and delivery info on order 12345. If you need help with billing, refunds, or payments, I can assist! - - user: The item arrived damaged. I'd like a replacement shipped to the same address. - - triage_agent: I'm connecting you to our replacement agent who will assist you with getting a replacement shipped to the same address for the damaged item. They will help you shortly! - - replacement_agent: Thank you for the info. I'll start the replacement process for your damaged item on order 12345 and arrange to ship it to the same address. Please hold on a moment while I proceed. - Thank you for providing the reason. I will process a replacement for the damaged item from order 12345 and arrange the shipment to the same address. Please hold on a moment while I take care of this. - - user: Great! Can you confirm the shipping cost won't be charged again? - - triage_agent: I'm connecting you to our billing agent who can confirm whether the shipping cost will be charged again for the replacement of your damaged item. They will assist you shortly! - - billing_agent: Replacements for damaged items are typically shipped at no extra shipping cost. I recommend confirming with the replacements or billing department to be sure. Let me know if you’d like me to connect you! - - user: Thank you! 
- ========================== - [status] IDLE - """ # noqa: E501 - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py b/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py index b1fd37302a..eadfe634e3 100644 --- a/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py +++ b/python/samples/getting_started/workflows/orchestration/handoff_with_code_interpreter_file.py @@ -32,8 +32,8 @@ from agent_framework import ( AgentRunUpdateEvent, ChatAgent, + HandoffAgentUserRequest, HandoffBuilder, - HandoffUserInputRequest, HostedCodeInterpreterTool, HostedFileContent, RequestInfoEvent, @@ -68,21 +68,10 @@ def _handle_events(events: list[WorkflowEvent]) -> tuple[list[RequestInfoEvent], print(f"[status] {event.state.name}") elif isinstance(event, RequestInfoEvent): - if isinstance(event.data, HandoffUserInputRequest): - print("\n=== Conversation So Far ===") - for msg in event.data.conversation: - speaker = msg.author_name or msg.role.value - text = msg.text or "" - txt = text[:200] + "..." if len(text) > 200 else text - print(f"- {speaker}: {txt}") - print("===========================\n") requests.append(event) elif isinstance(event, AgentRunUpdateEvent): - update = event.data - if update is None: - continue - for content in update.contents: + for content in event.data.contents: if isinstance(content, HostedFileContent): file_ids.append(content.file_id) print(f"[Found HostedFileContent: file_id={content.file_id}]") @@ -137,11 +126,7 @@ async def create_agents_v2(credential: AzureCliCredential) -> AsyncIterator[tupl ): triage = triage_client.create_agent( name="TriageAgent", - instructions=( - "You are a triage agent. Your ONLY job is to route requests to the appropriate specialist. " - "For code or file creation requests, call handoff_to_CodeSpecialist immediately. " - "Do NOT try to complete tasks yourself. Just hand off." - ), + instructions="You are a triage agent. Your ONLY job is to route requests to the appropriate specialist.", ) code_specialist = code_client.create_agent( @@ -170,7 +155,7 @@ async def main() -> None: workflow = ( HandoffBuilder() .participants([triage, code_specialist]) - .set_coordinator(triage) + .with_start_agent(triage) .with_termination_condition(lambda conv: sum(1 for msg in conv if msg.role.value == "user") >= 2) .build() ) @@ -195,7 +180,7 @@ async def main() -> None: user_input = user_inputs[input_index] print(f"\nUser: {user_input}") - responses = {request.request_id: user_input} + responses = {request.request_id: HandoffAgentUserRequest.create_response(user_input)} events = await _drain(workflow.send_responses_streaming(responses)) requests, file_ids = _handle_events(events) all_file_ids.extend(file_ids) diff --git a/python/samples/getting_started/workflows/orchestration/magentic.py b/python/samples/getting_started/workflows/orchestration/magentic.py index 213486706a..8e71d09a42 100644 --- a/python/samples/getting_started/workflows/orchestration/magentic.py +++ b/python/samples/getting_started/workflows/orchestration/magentic.py @@ -1,17 +1,19 @@ # Copyright (c) Microsoft. All rights reserved. 
 import asyncio
+import json
 import logging
 from typing import cast

 from agent_framework import (
-    MAGENTIC_EVENT_TYPE_AGENT_DELTA,
-    MAGENTIC_EVENT_TYPE_ORCHESTRATOR,
     AgentRunUpdateEvent,
     ChatAgent,
     ChatMessage,
+    GroupChatRequestSentEvent,
     HostedCodeInterpreterTool,
     MagenticBuilder,
+    MagenticOrchestratorEvent,
+    MagenticProgressLedger,
     WorkflowOutputEvent,
 )
 from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient
@@ -75,13 +77,9 @@ async def main() -> None:

     print("\nBuilding Magentic Workflow...")

-    # State used by on_agent_stream callback
-    last_stream_agent_id: str | None = None
-    stream_line_open: bool = False
-
     workflow = (
         MagenticBuilder()
-        .participants(researcher=researcher_agent, coder=coder_agent)
+        .participants([researcher_agent, coder_agent])
         .with_standard_manager(
             agent=manager_agent,
             max_round_count=10,
@@ -103,43 +101,49 @@ async def main() -> None:
     print(f"\nTask: {task}")
     print("\nStarting workflow execution...")

-    try:
-        output: str | None = None
-        async for event in workflow.run_stream(task):
-            if isinstance(event, AgentRunUpdateEvent):
-                props = event.data.additional_properties if event.data else None
-                event_type = props.get("magentic_event_type") if props else None
-
-                if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR:
-                    kind = props.get("orchestrator_message_kind", "") if props else ""
-                    text = event.data.text if event.data else ""
-                    print(f"\n[ORCH:{kind}]\n\n{text}\n{'-' * 26}")
-                elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA:
-                    agent_id = props.get("agent_id", event.executor_id) if props else event.executor_id
-                    if last_stream_agent_id != agent_id or not stream_line_open:
-                        if stream_line_open:
-                            print()
-                        print(f"\n[STREAM:{agent_id}]: ", end="", flush=True)
-                        last_stream_agent_id = agent_id
-                        stream_line_open = True
-                    if event.data and event.data.text:
-                        print(event.data.text, end="", flush=True)
-                elif event.data and event.data.text:
-                    print(event.data.text, end="", flush=True)
-            elif isinstance(event, WorkflowOutputEvent):
-                output_messages = cast(list[ChatMessage], event.data)
-                if output_messages:
-                    output = output_messages[-1].text
-
-        if stream_line_open:
-            print()
-            stream_line_open = False
-
-        if output is not None:
-            print(f"Workflow completed with result:\n\n{output}")
-
-    except Exception as e:
-        print(f"Workflow execution failed: {e}")
+    # Track the last message id so each new message prints its executor name exactly once in streaming mode
+    last_message_id: str | None = None
+    output_event: WorkflowOutputEvent | None = None
+    async for event in workflow.run_stream(task):
+        if isinstance(event, AgentRunUpdateEvent):
+            message_id = event.data.message_id
+            if message_id != last_message_id:
+                if last_message_id is not None:
+                    print("\n")
+                print(f"- {event.executor_id}:", end=" ", flush=True)
+                last_message_id = message_id
+            print(event.data, end="", flush=True)
+
+        elif isinstance(event, MagenticOrchestratorEvent):
+            print(f"\n[Magentic Orchestrator Event] Type: {event.event_type.name}")
+            if isinstance(event.data, ChatMessage):
+                print(f"Please review the plan:\n{event.data.text}")
+            elif isinstance(event.data, MagenticProgressLedger):
+                print(f"Please review progress ledger:\n{json.dumps(event.data.to_dict(), indent=2)}")
+            else:
+                print(f"Unknown data type in MagenticOrchestratorEvent: {type(event.data)}")
+
+            # Block to allow the user to read the plan/progress before continuing
+            # Note: this is for demonstration only and is not the recommended way to handle human interaction.
+ # Please refer to `with_plan_review` for proper human interaction during planning phases. + await asyncio.get_event_loop().run_in_executor(None, input, "Press Enter to continue...") + + elif isinstance(event, GroupChatRequestSentEvent): + print(f"\n[REQUEST SENT ({event.round_index})] to agent: {event.participant_name}") + + elif isinstance(event, WorkflowOutputEvent): + output_event = event + + if not output_event: + raise RuntimeError("Workflow did not produce a final output event.") + print("\n\nWorkflow completed!") + print("Final Output:") + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. + output_messages = cast(list[ChatMessage], output_event.data) + if output_messages: + output = output_messages[-1].text + print(output) if __name__ == "__main__": diff --git a/python/samples/getting_started/workflows/orchestration/magentic_agent_clarification.py b/python/samples/getting_started/workflows/orchestration/magentic_agent_clarification.py deleted file mode 100644 index 44dea25acc..0000000000 --- a/python/samples/getting_started/workflows/orchestration/magentic_agent_clarification.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import logging -from typing import Annotated, cast - -from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunUpdateEvent, - ChatAgent, - ChatMessage, - MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, - RequestInfoEvent, - WorkflowOutputEvent, - ai_function, -) -from agent_framework.openai import OpenAIChatClient - -logging.basicConfig(level=logging.WARNING) -logger = logging.getLogger(__name__) - -""" -Sample: Agent Clarification via Tool Calls in Magentic Workflows - -This sample demonstrates how agents can ask clarifying questions to users during -execution via the HITL (Human-in-the-Loop) mechanism. - -Scenario: "Onboard Jessica Smith" -- User provides an ambiguous task: "Onboard Jessica Smith" -- The onboarding agent recognizes missing information and uses the ask_user tool -- The ask_user call surfaces as a TOOL_APPROVAL request via RequestInfoEvent -- User provides the answer (e.g., "Engineering, Software Engineer") -- The answer is fed back to the agent as a FunctionResultContent -- Agent continues execution with the clarified information - -How it works: -1. Agent has an `ask_user` tool decorated with `@ai_function(approval_mode="always_require")` -2. When agent calls `ask_user`, it surfaces as a FunctionApprovalRequestContent -3. MagenticAgentExecutor converts this to a MagenticHumanInterventionRequest(kind=TOOL_APPROVAL) -4. User provides answer via MagenticHumanInterventionReply with response_text -5. The response_text becomes the function result fed back to the agent -6. Agent receives the result and continues processing - -Prerequisites: -- OpenAI credentials configured for `OpenAIChatClient`. -""" - - -@ai_function(approval_mode="always_require") -def ask_user(question: Annotated[str, "The question to ask the user for clarification"]) -> str: - """Ask the user a clarifying question to gather missing information. - - Use this tool when you need additional information from the user to complete - your task effectively. The user's response will be returned so you can - continue with your work. 
- - Args: - question: The question to ask the user - - Returns: - The user's response to the question - """ - # This function body is a placeholder - the actual interaction happens via HITL. - # When the agent calls this tool: - # 1. The tool call surfaces as a FunctionApprovalRequestContent - # 2. MagenticAgentExecutor detects this and emits a HITL request - # 3. The user provides their answer - # 4. The answer is fed back as the function result - return f"User was asked: {question}" - - -async def main() -> None: - # Create an onboarding agent that asks clarifying questions - onboarding_agent = ChatAgent( - name="OnboardingAgent", - description="HR specialist who handles employee onboarding", - instructions=( - "You are an HR Onboarding Specialist. Your job is to onboard new employees.\n\n" - "IMPORTANT: When given an onboarding request, you MUST gather the following " - "information before proceeding:\n" - "1. Department (e.g., Engineering, Sales, Marketing)\n" - "2. Role/Title (e.g., Software Engineer, Account Executive)\n" - "3. Start date (if not specified)\n" - "4. Manager's name (if known)\n\n" - "Use the ask_user tool to request ANY missing information. " - "Do not proceed with onboarding until you have at least the department and role.\n\n" - "Once you have the information, create an onboarding plan." - ), - chat_client=OpenAIChatClient(model_id="gpt-4o"), - tools=[ask_user], # Tool decorated with @ai_function(approval_mode="always_require") - ) - - # Create a manager agent - manager_agent = ChatAgent( - name="MagenticManager", - description="Orchestrator that coordinates the onboarding workflow", - instructions="You coordinate a team to complete HR tasks efficiently.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - print("\nBuilding Magentic Workflow with Agent Clarification...") - - workflow = ( - MagenticBuilder() - .participants(onboarding=onboarding_agent) - .with_standard_manager( - agent=manager_agent, - max_round_count=10, - max_stall_count=3, - max_reset_count=2, - ) - .build() - ) - - # Ambiguous task - agent should ask for clarification - task = "Onboard Jessica Smith" - - print(f"\nTask: {task}") - print("(This is intentionally vague - the agent should ask for more details)") - print("\nStarting workflow execution...") - print("=" * 60) - - try: - pending_request: RequestInfoEvent | None = None - pending_responses: dict[str, object] | None = None - completed = False - workflow_output: str | None = None - - last_stream_agent_id: str | None = None - stream_line_open: bool = False - - while not completed: - if pending_responses is not None: - stream = workflow.send_responses_streaming(pending_responses) - else: - stream = workflow.run_stream(task) - - async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - if stream_line_open: - print() - stream_line_open = False - print(f"\n[ORCHESTRATOR: {kind}]\n{text}\n{'-' * 40}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", "unknown") if props else "unknown" - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True 
- if event.data and event.data.text: - print(event.data.text, end="", flush=True) - - elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - if stream_line_open: - print() - stream_line_open = False - pending_request = event - req = cast(MagenticHumanInterventionRequest, event.data) - - if req.kind == MagenticHumanInterventionKind.TOOL_APPROVAL: - print("\n" + "=" * 60) - print("AGENT ASKING FOR CLARIFICATION") - print("=" * 60) - print(f"\nAgent: {req.agent_id}") - print(f"Question: {req.prompt}") - if req.context: - print(f"Details: {req.context}") - print() - - elif isinstance(event, WorkflowOutputEvent): - if stream_line_open: - print() - stream_line_open = False - workflow_output = event.data if event.data else None - completed = True - - if stream_line_open: - print() - stream_line_open = False - pending_responses = None - - if pending_request is not None: - req = cast(MagenticHumanInterventionRequest, pending_request.data) - - if req.kind == MagenticHumanInterventionKind.TOOL_APPROVAL: - # Agent is asking for clarification - print("Please provide your answer:") - answer = input("> ").strip() # noqa: ASYNC250 - - if answer.lower() == "exit": - print("Exiting workflow...") - return - - # Send the answer back - it will be fed to the agent as the function result - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.APPROVE, - response_text=answer if answer else "No additional information provided.", - ) - pending_responses = {pending_request.request_id: reply} - pending_request = None - - print("\n" + "=" * 60) - print("WORKFLOW COMPLETED") - print("=" * 60) - if workflow_output: - messages = cast(list[ChatMessage], workflow_output) - if messages: - final_msg = messages[-1] - print(f"\nFinal Result:\n{final_msg.text}") - - except Exception as e: - print(f"Workflow execution failed: {e}") - logger.exception("Workflow exception", exc_info=e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py index 36e6ca4c01..6fc284a9ab 100644 --- a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py +++ b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py @@ -3,15 +3,14 @@ import asyncio import json from pathlib import Path +from typing import cast from agent_framework import ( ChatAgent, + ChatMessage, FileCheckpointStorage, MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, + MagenticPlanReviewRequest, RequestInfoEvent, WorkflowCheckpoint, WorkflowOutputEvent, @@ -82,7 +81,7 @@ def build_workflow(checkpoint_storage: FileCheckpointStorage): # stores the checkpoint backend so the runtime knows where to persist snapshots. return ( MagenticBuilder() - .participants(researcher=researcher, writer=writer) + .participants([researcher, writer]) .with_plan_review() .with_standard_manager( agent=manager_agent, @@ -110,19 +109,16 @@ async def main() -> None: # Run the workflow until the first RequestInfoEvent is surfaced. The event carries the # request_id we must reuse on resume. In a real system this is where the UI would present # the plan for human review. 
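+    # In sketch form, using the same names as the code below (nothing beyond this sample's API):
+    # capture `event.data` when `event.request_type is MagenticPlanReviewRequest`, let the run idle
+    # at the checkpoint, and later answer the re-emitted request by keying a reply off its id, e.g.
+    #     responses = {event.request_id: plan_review_request.approve()}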
- plan_review_request_id: str | None = None + plan_review_request: MagenticPlanReviewRequest | None = None async for event in workflow.run_stream(TASK): - if isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - request = event.data - if isinstance(request, MagenticHumanInterventionRequest): - if request.kind == MagenticHumanInterventionKind.PLAN_REVIEW: - plan_review_request_id = event.request_id - print(f"Captured plan review request: {plan_review_request_id}") + if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: + plan_review_request = event.data + print(f"Captured plan review request: {event.request_id}") if isinstance(event, WorkflowStatusEvent) and event.state is WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: break - if plan_review_request_id is None: + if plan_review_request is None: print("No plan review request emitted; nothing to resume.") return @@ -142,19 +138,19 @@ async def main() -> None: if checkpoint_path.exists(): with checkpoint_path.open() as f: snapshot = json.load(f) - request_map = snapshot.get("executor_states", {}).get("magentic_plan_review", {}).get("request_events", {}) + request_map = snapshot.get("pending_request_info_events", {}) print(f"Pending plan-review requests persisted in checkpoint: {list(request_map.keys())}") print("\n=== Stage 2: resume from checkpoint and approve plan ===") resumed_workflow = build_workflow(checkpoint_storage) # Construct an approval reply to supply when the plan review request is re-emitted. - approval = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) + approval = plan_review_request.approve() # Resume execution and capture the re-emitted plan review request. request_info_event: RequestInfoEvent | None = None async for event in resumed_workflow.run_stream(checkpoint_id=resume_checkpoint.checkpoint_id): - if isinstance(event, RequestInfoEvent) and isinstance(event.data, MagenticHumanInterventionRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, MagenticPlanReviewRequest): request_info_event = event if request_info_event is None: @@ -178,9 +174,11 @@ async def main() -> None: if not result: print("No result data from workflow.") return - text = getattr(result, "text", None) or str(result) + output_messages = cast(list[ChatMessage], result) print("\n=== Final Answer ===") - print(text) + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. + print(output_messages[-1].text) # ------------------------------------------------------------------ # Stage 3: demonstrate resuming from a later checkpoint (post-plan) @@ -233,7 +231,7 @@ def _pending_message_count(cp: WorkflowCheckpoint) -> int: if not post_emitted_events: print("No new events were emitted; checkpoint already captured a completed run.") print("\n=== Final Answer (post-plan resume) ===") - print(text) + print(output_messages[-1].text) return print("Workflow did not complete after post-plan resume.") return @@ -243,9 +241,11 @@ def _pending_message_count(cp: WorkflowCheckpoint) -> int: print("No result data from post-plan resume.") return - post_text = getattr(post_result, "text", None) or str(post_result) + output_messages = cast(list[ChatMessage], post_result) print("\n=== Final Answer (post-plan resume) ===") - print(post_text) + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. 
+ print(output_messages[-1].text) """ Sample Output: diff --git a/python/samples/getting_started/workflows/orchestration/magentic_human_plan_review.py b/python/samples/getting_started/workflows/orchestration/magentic_human_plan_review.py new file mode 100644 index 0000000000..37a53020e7 --- /dev/null +++ b/python/samples/getting_started/workflows/orchestration/magentic_human_plan_review.py @@ -0,0 +1,145 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +from typing import cast + +from agent_framework import ( + AgentRunUpdateEvent, + ChatAgent, + ChatMessage, + MagenticBuilder, + MagenticPlanReviewRequest, + RequestInfoEvent, + WorkflowOutputEvent, +) +from agent_framework.openai import OpenAIChatClient + +""" +Sample: Magentic Orchestration with Human Plan Review + +This sample demonstrates how humans can review and provide feedback on plans +generated by the Magentic workflow orchestrator. When plan review is enabled, +the workflow requests human approval or revision before executing each plan. + +Key concepts: +- with_plan_review(): Enables human review of generated plans +- MagenticPlanReviewRequest: The event type for plan review requests +- Human can choose to: approve the plan or provide revision feedback + +Plan review options: +- approve(): Accept the proposed plan and continue execution +- revise(feedback): Provide textual feedback to modify the plan + +Prerequisites: +- OpenAI credentials configured for `OpenAIChatClient`. +""" + + +async def main() -> None: + researcher_agent = ChatAgent( + name="ResearcherAgent", + description="Specialist in research and information gathering", + instructions="You are a Researcher. You find information and gather facts.", + chat_client=OpenAIChatClient(model_id="gpt-4o"), + ) + + analyst_agent = ChatAgent( + name="AnalystAgent", + description="Data analyst who processes and summarizes research findings", + instructions="You are an Analyst. You analyze findings and create summaries.", + chat_client=OpenAIChatClient(model_id="gpt-4o"), + ) + + manager_agent = ChatAgent( + name="MagenticManager", + description="Orchestrator that coordinates the workflow", + instructions="You coordinate a team to complete tasks efficiently.", + chat_client=OpenAIChatClient(model_id="gpt-4o"), + ) + + print("\nBuilding Magentic Workflow with Human Plan Review...") + + workflow = ( + MagenticBuilder() + .participants([researcher_agent, analyst_agent]) + .with_standard_manager( + agent=manager_agent, + max_round_count=10, + max_stall_count=1, + max_reset_count=2, + ) + .with_plan_review() # Request human input for plan review + .build() + ) + + task = "Research sustainable aviation fuel technology and summarize the findings." 
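+    # The review handshake implemented below, in brief: each MagenticPlanReviewRequest is answered
+    # by keying a reply off its request_id and feeding the mapping back via send_responses_streaming().
+    # The two possible replies (same names as the code below):
+    #     {pending_request.request_id: event_data.approve()}             # accept the plan as-is
+    #     {pending_request.request_id: event_data.revise("feedback")}    # ask the manager to replan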
+ + print(f"\nTask: {task}") + print("\nStarting workflow execution...") + print("=" * 60) + + pending_request: RequestInfoEvent | None = None + pending_responses: dict[str, object] | None = None + output_event: WorkflowOutputEvent | None = None + + while not output_event: + if pending_responses is not None: + stream = workflow.send_responses_streaming(pending_responses) + else: + stream = workflow.run_stream(task) + + last_message_id: str | None = None + async for event in stream: + if isinstance(event, AgentRunUpdateEvent): + message_id = event.data.message_id + if message_id != last_message_id: + if last_message_id is not None: + print("\n") + print(f"- {event.executor_id}:", end=" ", flush=True) + last_message_id = message_id + print(event.data, end="", flush=True) + + elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: + pending_request = event + + elif isinstance(event, WorkflowOutputEvent): + output_event = event + + pending_responses = None + + # Handle plan review request if any + if pending_request is not None: + event_data = cast(MagenticPlanReviewRequest, pending_request.data) + + print("\n\n[Magentic Plan Review Request]") + if event_data.current_progress is not None: + print("Current Progress Ledger:") + print(json.dumps(event_data.current_progress.to_dict(), indent=2)) + print() + print(f"Proposed Plan:\n{event_data.plan.text}\n") + print("Please provide your feedback (press Enter to approve):") + + reply = await asyncio.get_event_loop().run_in_executor(None, input, "> ") + if reply.strip() == "": + print("Plan approved.\n") + pending_responses = {pending_request.request_id: event_data.approve()} + else: + print("Plan revised by human.\n") + pending_responses = {pending_request.request_id: event_data.revise(reply)} + pending_request = None + + print("\n" + "=" * 60) + print("WORKFLOW COMPLETED") + print("=" * 60) + print("Final Output:") + # The output of the Magentic workflow is a list of ChatMessages with only one final message + # generated by the orchestrator. + output_messages = cast(list[ChatMessage], output_event.data) + if output_messages: + output = output_messages[-1].text + print(output) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/magentic_human_plan_update.py b/python/samples/getting_started/workflows/orchestration/magentic_human_plan_update.py deleted file mode 100644 index b96fac7e99..0000000000 --- a/python/samples/getting_started/workflows/orchestration/magentic_human_plan_update.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import logging -from typing import cast - -from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunUpdateEvent, - ChatAgent, - HostedCodeInterpreterTool, - MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, - RequestInfoEvent, - WorkflowOutputEvent, -) -from agent_framework.openai import OpenAIChatClient, OpenAIResponsesClient - -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - -""" -Sample: Magentic Orchestration + Human Plan Review - -What it does: -- Builds a Magentic workflow with two agents and enables human plan review. - A human approves or edits the plan via `RequestInfoEvent` before execution. 
- -- researcher: ChatAgent backed by OpenAIChatClient (web/search-capable model) -- coder: ChatAgent backed by OpenAIAssistantsClient with the Hosted Code Interpreter tool - -Key behaviors demonstrated: -- with_plan_review(): requests a PlanReviewRequest before coordination begins -- Event loop that waits for RequestInfoEvent[PlanReviewRequest], prints the plan, then - replies with PlanReviewReply (here we auto-approve, but you can edit/collect input) -- Callbacks: on_agent_stream (incremental chunks), on_agent_response (final messages), - on_result (final answer), and on_exception -- Workflow completion when idle - -Prerequisites: -- OpenAI credentials configured for `OpenAIChatClient` and `OpenAIResponsesClient`. -""" - - -async def main() -> None: - researcher_agent = ChatAgent( - name="ResearcherAgent", - description="Specialist in research and information gathering", - instructions=( - "You are a Researcher. You find information without additional computation or quantitative analysis." - ), - # This agent requires the gpt-4o-search-preview model to perform web searches. - # Feel free to explore with other agents that support web search, for example, - # the `OpenAIResponseAgent` or `AzureAgentProtocol` with bing grounding. - chat_client=OpenAIChatClient(model_id="gpt-4o-search-preview"), - ) - - coder_agent = ChatAgent( - name="CoderAgent", - description="A helpful assistant that writes and executes code to process and analyze data.", - instructions="You solve questions using code. Please provide detailed analysis and computation process.", - chat_client=OpenAIResponsesClient(), - tools=HostedCodeInterpreterTool(), - ) - - # Create a manager agent for the orchestration - manager_agent = ChatAgent( - name="MagenticManager", - description="Orchestrator that coordinates the research and coding workflow", - instructions="You coordinate a team to complete complex tasks efficiently.", - chat_client=OpenAIChatClient(), - ) - - # Callbacks - def on_exception(exception: Exception) -> None: - print(f"Exception occurred: {exception}") - logger.exception("Workflow exception", exc_info=exception) - - last_stream_agent_id: str | None = None - stream_line_open: bool = False - - print("\nBuilding Magentic Workflow...") - - workflow = ( - MagenticBuilder() - .participants(researcher=researcher_agent, coder=coder_agent) - .with_standard_manager( - agent=manager_agent, - max_round_count=10, - max_stall_count=3, - max_reset_count=2, - ) - .with_plan_review() - .build() - ) - - task = ( - "I am preparing a report on the energy efficiency of different machine learning model architectures. " - "Compare the estimated training and inference energy consumption of ResNet-50, BERT-base, and GPT-2 " - "on standard datasets (e.g., ImageNet for ResNet, GLUE for BERT, WebText for GPT-2). " - "Then, estimate the CO2 emissions associated with each, assuming training on an Azure Standard_NC6s_v3 " - "VM for 24 hours. Provide tables for clarity, and recommend the most energy-efficient model " - "per task type (image classification, text classification, and text generation)." 
- ) - - print(f"\nTask: {task}") - print("\nStarting workflow execution...") - - try: - pending_request: RequestInfoEvent | None = None - pending_responses: dict[str, MagenticHumanInterventionReply] | None = None - completed = False - workflow_output: str | None = None - - while not completed: - # Use streaming for both initial run and response sending - if pending_responses is not None: - stream = workflow.send_responses_streaming(pending_responses) - else: - stream = workflow.run_stream(task) - - # Collect events from the stream - async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - print(f"\n[ORCH:{kind}]\n\n{text}\n{'-' * 26}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", "unknown") if props else "unknown" - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[STREAM:{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True - if event.data and event.data.text: - print(event.data.text, end="", flush=True) - elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - request = cast(MagenticHumanInterventionRequest, event.data) - if request.kind == MagenticHumanInterventionKind.PLAN_REVIEW: - pending_request = event - if request.plan_text: - print(f"\n=== PLAN REVIEW REQUEST ===\n{request.plan_text}\n") - elif isinstance(event, WorkflowOutputEvent): - # Capture workflow output during streaming - workflow_output = str(event.data) if event.data else None - completed = True - - if stream_line_open: - print() - stream_line_open = False - pending_responses = None - - # Handle pending plan review request - if pending_request is not None: - # Get human input for plan review decision - print("Plan review options:") - print("1. approve - Approve the plan as-is") - print("2. approve with comments - Approve with feedback for the manager") - print("3. revise - Request revision with your feedback") - print("4. edit - Directly edit the plan text") - print("5. 
exit - Exit the workflow") - - while True: - choice = input("Enter your choice (1-5): ").strip().lower() # noqa: ASYNC250 - if choice in ["approve", "1"]: - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.APPROVE) - break - if choice in ["approve with comments", "2"]: - comments = input("Enter your comments for the manager: ").strip() # noqa: ASYNC250 - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.APPROVE, - comments=comments if comments else None, - ) - break - if choice in ["revise", "3"]: - comments = input("Enter feedback for revising the plan: ").strip() # noqa: ASYNC250 - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.REVISE, - comments=comments if comments else None, - ) - break - if choice in ["edit", "4"]: - print("Enter your edited plan (end with an empty line):") - lines = [] - while True: - line = input() # noqa: ASYNC250 - if line == "": - break - lines.append(line) - edited_plan = "\n".join(lines) - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.REVISE, - edited_plan_text=edited_plan if edited_plan else None, - ) - break - if choice in ["exit", "5"]: - print("Exiting workflow...") - return - print("Invalid choice. Please enter a number 1-5.") - - pending_responses = {pending_request.request_id: reply} - pending_request = None - - # Show final result from captured workflow output - if workflow_output: - print(f"Workflow completed with result:\n\n{workflow_output}") - - except Exception as e: - print(f"Workflow execution failed: {e}") - on_exception(e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/magentic_human_replan.py b/python/samples/getting_started/workflows/orchestration/magentic_human_replan.py deleted file mode 100644 index aaa9be66f8..0000000000 --- a/python/samples/getting_started/workflows/orchestration/magentic_human_replan.py +++ /dev/null @@ -1,213 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import asyncio -import logging -from typing import cast - -from agent_framework import ( - MAGENTIC_EVENT_TYPE_AGENT_DELTA, - MAGENTIC_EVENT_TYPE_ORCHESTRATOR, - AgentRunUpdateEvent, - ChatAgent, - ChatMessage, - MagenticBuilder, - MagenticHumanInterventionDecision, - MagenticHumanInterventionKind, - MagenticHumanInterventionReply, - MagenticHumanInterventionRequest, - RequestInfoEvent, - WorkflowOutputEvent, -) -from agent_framework.openai import OpenAIChatClient - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - -""" -Sample: Magentic Orchestration with Human Stall Intervention - -This sample demonstrates how humans can intervene when a Magentic workflow stalls. -When agents stop making progress, the workflow requests human input instead of -automatically replanning. - -Key concepts: -- with_human_input_on_stall(): Enables human intervention when workflow detects stalls -- MagenticHumanInterventionKind.STALL: The request kind for stall interventions -- Human can choose to: continue, trigger replan, or provide guidance - -Stall intervention options: -- CONTINUE: Reset stall counter and continue with current plan -- REPLAN: Trigger automatic replanning by the manager -- GUIDANCE: Provide text guidance to help agents get back on track - -Prerequisites: -- OpenAI credentials configured for `OpenAIChatClient`. - -NOTE: it is sometimes difficult to get the agents to actually stall depending on the task. 
-""" - - -async def main() -> None: - researcher_agent = ChatAgent( - name="ResearcherAgent", - description="Specialist in research and information gathering", - instructions="You are a Researcher. You find information and gather facts.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - analyst_agent = ChatAgent( - name="AnalystAgent", - description="Data analyst who processes and summarizes research findings", - instructions="You are an Analyst. You analyze findings and create summaries.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - manager_agent = ChatAgent( - name="MagenticManager", - description="Orchestrator that coordinates the workflow", - instructions="You coordinate a team to complete tasks efficiently.", - chat_client=OpenAIChatClient(model_id="gpt-4o"), - ) - - print("\nBuilding Magentic Workflow with Human Stall Intervention...") - - workflow = ( - MagenticBuilder() - .participants(researcher=researcher_agent, analyst=analyst_agent) - .with_standard_manager( - agent=manager_agent, - max_round_count=10, - max_stall_count=1, # Stall detection after 1 round without progress - max_reset_count=2, - ) - .with_human_input_on_stall() # Request human input when stalled (instead of auto-replan) - .build() - ) - - task = "Research sustainable aviation fuel technology and summarize the findings." - - print(f"\nTask: {task}") - print("\nStarting workflow execution...") - print("=" * 60) - - try: - pending_request: RequestInfoEvent | None = None - pending_responses: dict[str, object] | None = None - completed = False - workflow_output: str | None = None - - last_stream_agent_id: str | None = None - stream_line_open: bool = False - - while not completed: - if pending_responses is not None: - stream = workflow.send_responses_streaming(pending_responses) - else: - stream = workflow.run_stream(task) - - async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - props = event.data.additional_properties if event.data else None - event_type = props.get("magentic_event_type") if props else None - - if event_type == MAGENTIC_EVENT_TYPE_ORCHESTRATOR: - kind = props.get("orchestrator_message_kind", "") if props else "" - text = event.data.text if event.data else "" - if stream_line_open: - print() - stream_line_open = False - print(f"\n[ORCHESTRATOR: {kind}]\n{text}\n{'-' * 40}") - elif event_type == MAGENTIC_EVENT_TYPE_AGENT_DELTA: - agent_id = props.get("agent_id", "unknown") if props else "unknown" - if last_stream_agent_id != agent_id or not stream_line_open: - if stream_line_open: - print() - print(f"\n[{agent_id}]: ", end="", flush=True) - last_stream_agent_id = agent_id - stream_line_open = True - if event.data and event.data.text: - print(event.data.text, end="", flush=True) - - elif isinstance(event, RequestInfoEvent) and event.request_type is MagenticHumanInterventionRequest: - if stream_line_open: - print() - stream_line_open = False - pending_request = event - req = cast(MagenticHumanInterventionRequest, event.data) - - if req.kind == MagenticHumanInterventionKind.STALL: - print("\n" + "=" * 60) - print("STALL INTERVENTION REQUESTED") - print("=" * 60) - print(f"\nWorkflow appears stalled after {req.stall_count} rounds") - print(f"Reason: {req.stall_reason}") - if req.last_agent: - print(f"Last active agent: {req.last_agent}") - if req.plan_text: - print(f"\nCurrent plan:\n{req.plan_text}") - print() - - elif isinstance(event, WorkflowOutputEvent): - if stream_line_open: - print() - stream_line_open = False - workflow_output = event.data if event.data 
else None - completed = True - - if stream_line_open: - print() - stream_line_open = False - pending_responses = None - - # Handle stall intervention request - if pending_request is not None: - req = cast(MagenticHumanInterventionRequest, pending_request.data) - reply: MagenticHumanInterventionReply | None = None - - if req.kind == MagenticHumanInterventionKind.STALL: - print("Stall intervention options:") - print("1. continue - Continue with current plan (reset stall counter)") - print("2. replan - Trigger automatic replanning") - print("3. guidance - Provide guidance to help agents") - print("4. exit - Exit the workflow") - - while True: - choice = input("Enter your choice (1-4): ").strip().lower() # noqa: ASYNC250 - if choice in ["continue", "1"]: - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.CONTINUE) - break - if choice in ["replan", "2"]: - reply = MagenticHumanInterventionReply(decision=MagenticHumanInterventionDecision.REPLAN) - break - if choice in ["guidance", "3"]: - guidance = input("Enter your guidance: ").strip() # noqa: ASYNC250 - reply = MagenticHumanInterventionReply( - decision=MagenticHumanInterventionDecision.GUIDANCE, - comments=guidance if guidance else None, - ) - break - if choice in ["exit", "4"]: - print("Exiting workflow...") - return - print("Invalid choice. Please enter a number 1-4.") - - if reply is not None: - pending_responses = {pending_request.request_id: reply} - pending_request = None - - print("\n" + "=" * 60) - print("WORKFLOW COMPLETED") - print("=" * 60) - if workflow_output: - messages = cast(list[ChatMessage], workflow_output) - if messages: - final_msg = messages[-1] - print(f"\nFinal Result:\n{final_msg.text}") - - except Exception as e: - print(f"Workflow execution failed: {e}") - logger.exception("Workflow exception", exc_info=e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py b/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py index 104a833603..ecf605e73b 100644 --- a/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py +++ b/python/samples/getting_started/workflows/orchestration/sequential_custom_executors.py @@ -4,6 +4,7 @@ from typing import Any from agent_framework import ( + AgentExecutorResponse, ChatMessage, Executor, Role, @@ -20,18 +21,13 @@ This demonstrates how SequentialBuilder chains participants with a shared conversation context (list[ChatMessage]). An agent produces content; a custom executor appends a compact summary to the conversation. The workflow completes -when idle, and the final output contains the complete conversation. +after all participants have executed in sequence, and the final output contains +the complete conversation. Custom executor contract: -- Provide at least one @handler accepting list[ChatMessage] and a WorkflowContext[list[ChatMessage]] +- Provide at least one @handler accepting AgentExecutorResponse and a WorkflowContext[list[ChatMessage]] - Emit the updated conversation via ctx.send_message([...]) -Note on internal adapters: -- You may see adapter nodes in the event stream such as "input-conversation", - "to-conversation:", and "complete". These provide consistent typing, - conversion of agent responses into the shared conversation, and a single point - for completion—similar to concurrent's dispatcher/aggregator. 
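+
+A minimal handler satisfying this contract might look like the following sketch (the name
+`forward` is illustrative only; the Summarizer below is the executor this sample actually uses):
+
+    @handler
+    async def forward(self, response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None:
+        # Pass the accumulated conversation through unchanged.
+        await ctx.send_message(list(response.full_conversation or []))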
-
 Prerequisites:
 - Azure OpenAI access configured for AzureOpenAIChatClient (use az login + env vars)
 """
@@ -41,11 +37,23 @@
 class Summarizer(Executor):
     """Simple summarizer: consumes full conversation and appends an assistant summary."""

     @handler
-    async def summarize(self, conversation: list[ChatMessage], ctx: WorkflowContext[list[ChatMessage]]) -> None:
-        users = sum(1 for m in conversation if m.role == Role.USER)
-        assistants = sum(1 for m in conversation if m.role == Role.ASSISTANT)
+    async def summarize(self, agent_response: AgentExecutorResponse, ctx: WorkflowContext[list[ChatMessage]]) -> None:
+        """Append a summary message to a copy of the full conversation.
+
+        Note: A custom executor must accept the message type produced by the prior participant and emit
+        the message type expected by the next one. Here, the prior participant is an agent, so the input
+        is `AgentExecutorResponse` (agents are wrapped in an `AgentExecutor`, which produces that type).
+        If the next participant is also an agent, or this is the final participant, the output must be
+        `list[ChatMessage]`.
+        """
+        if not agent_response.full_conversation:
+            await ctx.send_message([ChatMessage(role=Role.ASSISTANT, text="No conversation to summarize.")])
+            return
+
+        users = sum(1 for m in agent_response.full_conversation if m.role == Role.USER)
+        assistants = sum(1 for m in agent_response.full_conversation if m.role == Role.ASSISTANT)
         summary = ChatMessage(role=Role.ASSISTANT, text=f"Summary -> users:{users} assistants:{assistants}")
-        final_conversation = list(conversation) + [summary]
+        final_conversation = list(agent_response.full_conversation) + [summary]
         await ctx.send_message(final_conversation)

@@ -61,7 +69,7 @@ async def main() -> None:
     summarizer = Summarizer(id="summarizer")
     workflow = SequentialBuilder().participants([content, summarizer]).build()

-    # 3) Run and print final conversation
+    # 3) Run workflow and extract final conversation
     events = await workflow.run("Explain the benefits of budget eBikes for commuters.")
     outputs = events.get_outputs()
diff --git a/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py b/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py
index e4092414fc..a858fe28ce 100644
--- a/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py
+++ b/python/samples/getting_started/workflows/tool-approval/concurrent_builder_tool_approval.py
@@ -10,8 +10,6 @@
     FunctionApprovalResponseContent,
     RequestInfoEvent,
     WorkflowOutputEvent,
-    WorkflowRunState,
-    WorkflowStatusEvent,
     ai_function,
 )
 from agent_framework.openai import OpenAIChatClient
@@ -25,19 +23,18 @@

 This sample works as follows:
 1. A ConcurrentBuilder workflow is created with two agents running in parallel.
-2. One agent has a tool requiring approval (financial transaction).
-3. The other agent has only non-approval tools (market data lookup).
-4. Both agents receive the same task and work concurrently.
-5. When the financial agent tries to execute a trade, it triggers an approval request.
-6. The sample simulates human approval and the workflow completes.
-7. Results from both agents are aggregated and output.
+2. Both agents have the same tools, including one requiring approval (execute_trade).
+3. Both agents receive the same task and work concurrently on their respective stocks.
+4. When either agent tries to execute a trade, it triggers an approval request.
+5. 
The sample simulates human approval and the workflow completes. +6. Results from both agents are aggregated and output. Purpose: -Show how tool call approvals work in parallel execution scenarios where only some -agents have sensitive tools. +Show how tool call approvals work in parallel execution scenarios where multiple +agents may independently trigger approval requests. Demonstrate: -- Combining agents with and without approval-required tools in concurrent workflows. +- Handling multiple approval requests from different agents in concurrent workflows. - Handling RequestInfoEvent during concurrent agent execution. - Understanding that approval pauses only the agent that triggered it, not all agents. @@ -47,7 +44,7 @@ """ -# 1. Define tools for the research agent (no approval required) +# 1. Define market data tools (no approval required) @ai_function def get_stock_price(symbol: Annotated[str, "The stock ticker symbol"]) -> str: """Get the current stock price for a given symbol.""" @@ -61,10 +58,16 @@ def get_stock_price(symbol: Annotated[str, "The stock ticker symbol"]) -> str: def get_market_sentiment(symbol: Annotated[str, "The stock ticker symbol"]) -> str: """Get market sentiment analysis for a stock.""" # Mock sentiment data - return f"Market sentiment for {symbol.upper()}: Bullish (72% positive mentions in last 24h)" + mock_data = { + "AAPL": "Market sentiment for AAPL: Bullish (68% positive mentions in last 24h)", + "GOOGL": "Market sentiment for GOOGL: Neutral (50% positive mentions in last 24h)", + "MSFT": "Market sentiment for MSFT: Bullish (72% positive mentions in last 24h)", + "AMZN": "Market sentiment for AMZN: Bearish (40% positive mentions in last 24h)", + } + return mock_data.get(symbol.upper(), f"Market sentiment for {symbol.upper()}: Unknown") -# 2. Define tools for the trading agent (approval required for trades) +# 2. Define trading tools (approval required) @ai_function(approval_mode="always_require") def execute_trade( symbol: Annotated[str, "The stock ticker symbol"], @@ -78,52 +81,68 @@ def execute_trade( @ai_function def get_portfolio_balance() -> str: """Get current portfolio balance and available funds.""" - return "Portfolio: $50,000 invested, $10,000 cash available" + return "Portfolio: $50,000 invested, $10,000 cash available. Holdings: AAPL, GOOGL, MSFT." + + +def _print_output(event: WorkflowOutputEvent) -> None: + if not event.data: + raise ValueError("WorkflowOutputEvent has no data") + + if not isinstance(event.data, list) and not all(isinstance(msg, ChatMessage) for msg in event.data): + raise ValueError("WorkflowOutputEvent data is not a list of ChatMessage") + + messages: list[ChatMessage] = event.data # type: ignore + + print("\n" + "-" * 60) + print("Workflow completed. Aggregated results from both agents:") + for msg in messages: + if msg.text: + print(f"- {msg.author_name or msg.role.value}: {msg.text}") async def main() -> None: - # 3. Create two agents with different tool sets + # 3. Create two agents focused on different stocks but with the same tool sets chat_client = OpenAIChatClient() - research_agent = chat_client.create_agent( - name="ResearchAgent", + microsoft_agent = chat_client.create_agent( + name="MicrosoftAgent", instructions=( - "You are a market research analyst. Analyze stock data and provide " - "recommendations based on price and sentiment. Do not execute trades." + "You are a personal trading assistant focused on Microsoft (MSFT). " + "You manage my portfolio and take actions based on market data." 
), - tools=[get_stock_price, get_market_sentiment], + tools=[get_stock_price, get_market_sentiment, get_portfolio_balance, execute_trade], ) - trading_agent = chat_client.create_agent( - name="TradingAgent", + google_agent = chat_client.create_agent( + name="GoogleAgent", instructions=( - "You are a trading assistant. When asked to buy or sell shares, you MUST " - "call the execute_trade function to complete the transaction. Check portfolio " - "balance first, then execute the requested trade." + "You are a personal trading assistant focused on Google (GOOGL). " + "You manage my trades and portfolio based on market conditions." ), - tools=[get_portfolio_balance, execute_trade], + tools=[get_stock_price, get_market_sentiment, get_portfolio_balance, execute_trade], ) # 4. Build a concurrent workflow with both agents # ConcurrentBuilder requires at least 2 participants for fan-out - workflow = ConcurrentBuilder().participants([research_agent, trading_agent]).build() + workflow = ConcurrentBuilder().participants([microsoft_agent, google_agent]).build() # 5. Start the workflow - both agents will process the same task in parallel print("Starting concurrent workflow with tool approval...") - print("Two agents will analyze MSFT - one for research, one for trading.") print("-" * 60) - # Phase 1: Run workflow and collect all events (stream ends at IDLE or IDLE_WITH_PENDING_REQUESTS) + # Phase 1: Run workflow and collect request info events request_info_events: list[RequestInfoEvent] = [] - workflow_completed_without_approvals = False - async for event in workflow.run_stream("Analyze MSFT stock and if sentiment is positive, buy 10 shares."): + async for event in workflow.run_stream( + "Manage my portfolio. Use a max of 5000 dollars to adjust my position using " + "your best judgment based on market sentiment. No need to confirm trades with me." + ): if isinstance(event, RequestInfoEvent): request_info_events.append(event) if isinstance(event.data, FunctionApprovalRequestContent): print(f"\nApproval requested for tool: {event.data.function_call.name}") print(f" Arguments: {event.data.function_call.arguments}") - elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: - workflow_completed_without_approvals = True + elif isinstance(event, WorkflowOutputEvent): + _print_output(event) # 6. Handle approval requests (if any) if request_info_events: @@ -136,46 +155,37 @@ async def main() -> None: if responses: # Phase 2: Send all approvals and continue workflow - output: list[ChatMessage] | None = None async for event in workflow.send_responses_streaming(responses): if isinstance(event, WorkflowOutputEvent): - output = event.data - - if output: - print("\n" + "-" * 60) - print("Workflow completed. Aggregated results from both agents:") - for msg in output: - if hasattr(msg, "author_name") and msg.author_name: - print(f"\n[{msg.author_name}]:") - text = msg.text[:300] + "..." if len(msg.text) > 300 else msg.text - if text: - print(f" {text}") - elif workflow_completed_without_approvals: + _print_output(event) + else: print("\nWorkflow completed without requiring approvals.") - print("(The trading agent may have only checked balance without executing a trade)") + print("(The agents may have only checked data without executing trades)") """ Sample Output: Starting concurrent workflow with tool approval... - Two agents will analyze MSFT - one for research, one for trading. 
------------------------------------------------------------ Approval requested for tool: execute_trade - Arguments: {"symbol": "MSFT", "action": "buy", "quantity": 10} + Arguments: {"symbol":"MSFT","action":"buy","quantity":13} + + Approval requested for tool: execute_trade + Arguments: {"symbol":"GOOGL","action":"buy","quantity":35} + + Simulating human approval for: execute_trade + Simulating human approval for: execute_trade ------------------------------------------------------------ Workflow completed. Aggregated results from both agents: - - [ResearchAgent]: - MSFT is currently trading at $175.50 with bullish market sentiment - (72% positive mentions). Based on the positive sentiment, this could - be a good opportunity to consider buying. - - [TradingAgent]: - I've checked your portfolio balance ($10,000 cash available) and - executed the trade: BUY 10 shares of MSFT at approximately $175.50 - per share, totaling ~$1,755. + - user: Manage my portfolio. Use a max of 5000 dollars to adjust my position using your best judgment based on + market sentiment. No need to confirm trades with me. + - MicrosoftAgent: I have successfully executed the trade, purchasing 13 shares of Microsoft (MSFT). This action + was based on the positive market sentiment and available funds within the specified limit. + Your portfolio has been adjusted accordingly. + - GoogleAgent: I have successfully executed the trade, purchasing 35 shares of GOOGL. If you need further + assistance or any adjustments, feel free to ask! """ diff --git a/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py b/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py index 565002c794..a8536afc7f 100644 --- a/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py +++ b/python/samples/getting_started/workflows/tool-approval/group_chat_builder_tool_approval.py @@ -4,9 +4,11 @@ from typing import Annotated from agent_framework import ( + AgentRunUpdateEvent, FunctionApprovalRequestContent, GroupChatBuilder, - GroupChatStateSnapshot, + GroupChatRequestSentEvent, + GroupChatState, RequestInfoEvent, ai_function, ) @@ -73,7 +75,7 @@ def create_rollback_plan(version: Annotated[str, "The version being deployed"]) # 2. Define the speaker selector function -def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: +def select_next_speaker(state: GroupChatState) -> str: """Select the next speaker based on the conversation flow. This simple selector follows a predefined flow: @@ -81,19 +83,13 @@ def select_next_speaker(state: GroupChatStateSnapshot) -> str | None: 2. DevOps Engineer checks staging and creates rollback plan 3. 
DevOps Engineer deploys to production (triggers approval) """ - round_index: int = state["round_index"] + if not state.conversation: + raise RuntimeError("Conversation is empty; cannot select next speaker.") - # Define the conversation flow - speaker_order: list[str] = [ - "QAEngineer", # Round 0: Run tests - "DevOpsEngineer", # Round 1: Check staging, create rollback - "DevOpsEngineer", # Round 2: Deploy to production (approval required) - ] + if len(state.conversation) == 1: + return "QAEngineer" # First speaker - if round_index >= len(speaker_order): - return None # End the conversation - - return speaker_order[round_index] + return "DevOpsEngineer" # Subsequent speakers async def main() -> None: @@ -123,28 +119,47 @@ async def main() -> None: workflow = ( GroupChatBuilder() # Optionally, use `.set_manager(...)` to customize the group chat manager - .set_select_speakers_func(select_next_speaker) + .with_select_speaker_func(select_next_speaker) .participants([qa_engineer, devops_engineer]) - .with_max_rounds(5) + # Set a hard limit to 4 rounds + # First round: QAEngineer speaks + # Second round: DevOpsEngineer speaks (check staging + create rollback) + # Third round: DevOpsEngineer speaks with an approval request (deploy to production) + # Fourth round: DevOpsEngineer speaks again after approval + .with_max_rounds(4) .build() ) # 5. Start the workflow print("Starting group chat workflow for software deployment...") - print("Agents: QA Engineer, DevOps Engineer") + print(f"Agents: {[qa_engineer.name, devops_engineer.name]}") print("-" * 60) # Phase 1: Run workflow and collect all events (stream ends at IDLE or IDLE_WITH_PENDING_REQUESTS) request_info_events: list[RequestInfoEvent] = [] + # Keep track of the last response to format output nicely in streaming mode + last_response_id: str | None = None async for event in workflow.run_stream( "We need to deploy version 2.4.0 to production. Please coordinate the deployment." ): if isinstance(event, RequestInfoEvent): request_info_events.append(event) if isinstance(event.data, FunctionApprovalRequestContent): - print("\n[APPROVAL REQUIRED]") + print("\n[APPROVAL REQUIRED] From agent:", event.source_executor_id) print(f" Tool: {event.data.function_call.name}") print(f" Arguments: {event.data.function_call.arguments}") + elif isinstance(event, AgentRunUpdateEvent): + if not event.data.text: + continue # Skip empty updates + response_id = event.data.response_id + if response_id != last_response_id: + if last_response_id is not None: + print("\n") + print(f"- {event.executor_id}:", end=" ", flush=True) + last_response_id = response_id + print(event.data, end="", flush=True) + elif isinstance(event, GroupChatRequestSentEvent): + print(f"\n[REQUEST SENT ({event.round_index})] to agent: {event.participant_name}") # 6. 
Handle approval requests
     if request_info_events:
@@ -160,8 +175,21 @@
             approval_response = request_event.data.create_response(approved=True)
 
             # Phase 2: Send approval and continue workflow
-            async for _ in workflow.send_responses_streaming({request_event.request_id: approval_response}):
-                pass  # Consume all events
+            # Reset response tracking so phase 2 streaming output is formatted like phase 1
+            last_response_id = None
+            async for event in workflow.send_responses_streaming({request_event.request_id: approval_response}):
+                if isinstance(event, AgentRunUpdateEvent):
+                    if not event.data.text:
+                        continue  # Skip empty updates
+                    response_id = event.data.response_id
+                    if response_id != last_response_id:
+                        if last_response_id is not None:
+                            print("\n")
+                        print(f"- {event.executor_id}:", end=" ", flush=True)
+                        last_response_id = response_id
+                    print(event.data, end="", flush=True)
+                elif isinstance(event, GroupChatRequestSentEvent):
+                    print(f"\n[REQUEST SENT ({event.round_index})] to agent: {event.participant_name}")
 
     print("\n" + "-" * 60)
     print("Deployment workflow completed successfully!")
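
Editor's note: after this change, the two streaming loops in
group_chat_builder_tool_approval.py (phase 1 and phase 2) repeat the same
AgentRunUpdateEvent formatting logic. A small stateful helper would remove the
duplication. A minimal sketch under the event shapes used above; the StreamPrinter
name is hypothetical, and the event parameter is left untyped rather than assuming
a particular base event class:

from agent_framework import AgentRunUpdateEvent, GroupChatRequestSentEvent


class StreamPrinter:
    """Groups streamed agent updates by response_id so each agent turn prints as one line."""

    def __init__(self) -> None:
        self._last_response_id: str | None = None

    def handle(self, event: object) -> None:
        if isinstance(event, AgentRunUpdateEvent):
            if not event.data.text:
                return  # Skip empty updates
            if event.data.response_id != self._last_response_id:
                if self._last_response_id is not None:
                    print("\n")
                print(f"- {event.executor_id}:", end=" ", flush=True)
                self._last_response_id = event.data.response_id
            print(event.data, end="", flush=True)
        elif isinstance(event, GroupChatRequestSentEvent):
            print(f"\n[REQUEST SENT ({event.round_index})] to agent: {event.participant_name}")

Each loop body then reduces to printer.handle(event) plus the RequestInfoEvent
bookkeeping, and the duplicated last_response_id state disappears from main().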
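Editor's note: the with_select_speaker_func contract introduced above (GroupChatState in,
participant name out) also admits policies other than the staged flow in the sample. A
minimal sketch, assuming only what the diff shows, namely that GroupChatState exposes a
conversation of ChatMessage objects and that the returned string names a participant; the
alternation policy and the function name are illustrative, not part of this change:

from agent_framework import GroupChatState, Role


def alternate_speakers(state: GroupChatState) -> str:
    """Alternate between the two participants based on completed assistant turns."""
    # Even number of assistant turns so far -> QAEngineer's turn; odd -> DevOpsEngineer's.
    assistant_turns = sum(1 for m in state.conversation if m.role == Role.ASSISTANT)
    return "QAEngineer" if assistant_turns % 2 == 0 else "DevOpsEngineer"

It is wired in the same way as the sample's selector, via
GroupChatBuilder().with_select_speaker_func(alternate_speakers), and .with_max_rounds(...)
still caps the conversation length.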
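Editor's note: both tool-approval samples in this change share the same two-phase shape:
stream until the run goes idle, collect RequestInfoEvent items, answer them, then resume.
Distilled into a sketch that uses only APIs appearing in the diff (run_stream, request_id,
create_response, send_responses_streaming); the helper name is hypothetical, and blanket
auto-approval is for demonstration only; real code should inspect the pending function
call before approving:

from agent_framework import FunctionApprovalRequestContent, RequestInfoEvent, WorkflowOutputEvent


async def run_with_auto_approval(workflow, task: str) -> None:
    """Run a workflow and approve every tool call that requests human approval."""
    pending: list[RequestInfoEvent] = []
    async for event in workflow.run_stream(task):  # Phase 1: runs until idle or pending requests
        if isinstance(event, RequestInfoEvent) and isinstance(event.data, FunctionApprovalRequestContent):
            pending.append(event)

    if not pending:
        return

    # Phase 2: answer all collected requests and resume the workflow.
    responses = {e.request_id: e.data.create_response(approved=True) for e in pending}
    async for event in workflow.send_responses_streaming(responses):
        if isinstance(event, WorkflowOutputEvent):
            print(event.data)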