From 3234c6e884037eb5e051a4f31d9b43f8924eb0c1 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 14 Oct 2025 14:50:29 -0700 Subject: [PATCH 01/26] Prototype: Add request_info API and @response_handler --- .../agent_framework/_workflows/__init__.py | 4 +- .../agent_framework/_workflows/__init__.pyi | 4 +- .../core/agent_framework/_workflows/_const.py | 7 ++ .../core/agent_framework/_workflows/_edge.py | 25 ++++ .../_workflows/_edge_runner.py | 28 +++-- .../agent_framework/_workflows/_events.py | 6 +- .../agent_framework/_workflows/_executor.py | 36 +++--- .../_workflows/_request_info_mixin.py | 115 ++++++++++++++++++ .../agent_framework/_workflows/_runner.py | 4 + .../_workflows/_runner_context.py | 91 +++++++++++++- .../agent_framework/_workflows/_validation.py | 54 +------- .../agent_framework/_workflows/_workflow.py | 96 ++++++--------- .../_workflows/_workflow_context.py | 28 ++++- .../_workflows/_workflow_executor.py | 5 +- .../guessing_game_with_human_input.py | 108 ++++++---------- 15 files changed, 400 insertions(+), 211 deletions(-) create mode 100644 python/packages/core/agent_framework/_workflows/_request_info_mixin.py diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index 94950e1948..bd8cdfe3ac 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -80,6 +80,7 @@ RequestInfoMessage, RequestResponse, ) +from ._request_info_mixin import response_handler from ._runner import Runner from ._runner_context import ( InProcRunnerContext, @@ -90,7 +91,6 @@ from ._shared_state import SharedState from ._validation import ( EdgeDuplicationError, - ExecutorDuplicationError, GraphConnectivityError, TypeCompatibilityError, ValidationTypeEnum, @@ -117,7 +117,6 @@ "EdgeDuplicationError", "Executor", "ExecutorCompletedEvent", - "ExecutorDuplicationError", "ExecutorEvent", "ExecutorFailedEvent", 
"ExecutorInvokedEvent", @@ -187,5 +186,6 @@ "executor", "get_checkpoint_summary", "handler", + "response_handler", "validate_workflow_graph", ] diff --git a/python/packages/core/agent_framework/_workflows/__init__.pyi b/python/packages/core/agent_framework/_workflows/__init__.pyi index d98829c56d..b6de88b5ba 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.pyi +++ b/python/packages/core/agent_framework/_workflows/__init__.pyi @@ -78,6 +78,7 @@ from ._request_info_executor import ( RequestInfoMessage, RequestResponse, ) +from ._request_info_mixin import response_handler from ._runner import Runner from ._runner_context import ( InProcRunnerContext, @@ -88,7 +89,6 @@ from ._sequential import SequentialBuilder from ._shared_state import SharedState from ._validation import ( EdgeDuplicationError, - ExecutorDuplicationError, GraphConnectivityError, TypeCompatibilityError, ValidationTypeEnum, @@ -115,7 +115,6 @@ __all__ = [ "EdgeDuplicationError", "Executor", "ExecutorCompletedEvent", - "ExecutorDuplicationError", "ExecutorEvent", "ExecutorFailedEvent", "ExecutorInvokedEvent", @@ -185,5 +184,6 @@ __all__ = [ "executor", "get_checkpoint_summary", "handler", + "response_handler", "validate_workflow_graph", ] diff --git a/python/packages/core/agent_framework/_workflows/_const.py b/python/packages/core/agent_framework/_workflows/_const.py index b426692725..80d70b0b50 100644 --- a/python/packages/core/agent_framework/_workflows/_const.py +++ b/python/packages/core/agent_framework/_workflows/_const.py @@ -1,3 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. DEFAULT_MAX_ITERATIONS = 100 # Default maximum iterations for workflow execution. + +INTERNAL_SOURCE_PREFIX = "internal" # Source identifier for internal workflow messages. 
+ + +def INTERNAL_SOURCE_ID(executor_id: str) -> str: +    """Generate an internal source ID for a given executor.""" +    return f"{INTERNAL_SOURCE_PREFIX}:{executor_id}" diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 6cc1aa31b2..22bd0255ff 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from typing import Any, ClassVar +from ._const import INTERNAL_SOURCE_ID from ._executor import Executor from ._model_utils import DictConvertible, encode_value @@ -865,3 +866,27 @@ def to_dict(self) -> dict[str, Any]: payload = super().to_dict() payload["cases"] = [encode_value(case) for case in self.cases] return payload + + +class InternalEdgeGroup(EdgeGroup): + """Special edge group used to route internal messages to executors. + + This group is not serialized and is only used at runtime to link internal + executors that should not be exposed as part of the public workflow graph. + """ + + def __init__(self, executor_id: str) -> None: + """Create an internal edge group for the given executor. + + Parameters + ---------- + executor_id: + Identifier of the internal executor that should receive messages. + + Examples: + .. 
code-block:: python + + edge_group = InternalEdgeGroup("executor_a") + """ + edge = Edge(source_id=INTERNAL_SOURCE_ID(executor_id), target_id=executor_id) + super().__init__([edge]) diff --git a/python/packages/core/agent_framework/_workflows/_edge_runner.py b/python/packages/core/agent_framework/_workflows/_edge_runner.py index bc6d5d85c4..d97236530a 100644 --- a/python/packages/core/agent_framework/_workflows/_edge_runner.py +++ b/python/packages/core/agent_framework/_workflows/_edge_runner.py @@ -8,7 +8,15 @@ from typing import Any, cast from ..observability import EdgeGroupDeliveryStatus, OtelAttr, create_edge_group_processing_span -from ._edge import Edge, EdgeGroup, FanInEdgeGroup, FanOutEdgeGroup, SingleEdgeGroup, SwitchCaseEdgeGroup +from ._edge import ( + Edge, + EdgeGroup, + FanInEdgeGroup, + FanOutEdgeGroup, + InternalEdgeGroup, + SingleEdgeGroup, + SwitchCaseEdgeGroup, +) from ._executor import Executor from ._runner_context import Message, RunnerContext from ._shared_state import SharedState @@ -44,11 +52,11 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R """ raise NotImplementedError - def _can_handle(self, executor_id: str, message_data: Any) -> bool: + def _can_handle(self, executor_id: str, message: Message) -> bool: """Check if an executor can handle the given message data.""" if executor_id not in self._executors: return False - return self._executors[executor_id].can_handle(message_data) + return self._executors[executor_id].can_handle(message) async def _execute_on_target( self, @@ -66,7 +74,7 @@ async def _execute_on_target( # Execute with trace context parameters await target_executor.execute( - message.data, + message, source_ids, # source_executor_ids shared_state, # shared_state ctx, # runner_context @@ -103,7 +111,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R }) return False - if self._can_handle(self._edge.target_id, message.data): + if 
self._can_handle(self._edge.target_id, message): if self._edge.should_route(message.data): span.set_attributes({ OtelAttr.EDGE_GROUP_DELIVERED: True, @@ -183,7 +191,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # If the target ID is specified and the selection result contains it, send the message to that edge if message.target_id in selection_results: edge = self._target_map.get(message.target_id) - if edge and self._can_handle(edge.target_id, message.data): + if edge and self._can_handle(edge.target_id, message): if edge.should_route(message.data): span.set_attributes({ OtelAttr.EDGE_GROUP_DELIVERED: True, @@ -215,7 +223,7 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R # If no target ID, send the message to the selected targets for target_id in selection_results: edge = self._target_map[target_id] - if self._can_handle(edge.target_id, message.data) and edge.should_route(message.data): + if self._can_handle(edge.target_id, message) and edge.should_route(message.data): deliverable_edges.append(edge) if len(deliverable_edges) > 0: @@ -291,7 +299,9 @@ async def send_message(self, message: Message, shared_state: SharedState, ctx: R return False # Check if target can handle list of message data (fan-in aggregates multiple messages) - if self._can_handle(self._edges[0].target_id, [message.data]): + if self._can_handle( + self._edges[0].target_id, Message(data=[message.data], source_id=message.source_id) + ): # If the edge can handle the data, buffer the message self._buffer[message.source_id].append(message) span.set_attributes({ @@ -374,7 +384,7 @@ def create_edge_runner(edge_group: EdgeGroup, executors: dict[str, Executor]) -> Returns: The appropriate EdgeRunner instance. 
""" - if isinstance(edge_group, SingleEdgeGroup): + if isinstance(edge_group, (SingleEdgeGroup, InternalEdgeGroup)): return SingleEdgeRunner(edge_group, executors) if isinstance(edge_group, SwitchCaseEdgeGroup): return SwitchCaseEdgeRunner(edge_group, executors) diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 58e699e2b4..a339ea9cb3 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -212,6 +212,7 @@ def __init__( source_executor_id: str, request_type: type, request_data: "RequestInfoMessage", + response_type: type, ): """Initialize the request info event. @@ -220,11 +221,13 @@ def __init__( source_executor_id: ID of the executor that made the request. request_type: Type of the request (e.g., a specific data type). request_data: The data associated with the request. + response_type: Expected type of the response. 
""" super().__init__(request_data) self.request_id = request_id self.source_executor_id = source_executor_id self.request_type = request_type + self.response_type = response_type def __repr__(self) -> str: """Return a string representation of the request info event.""" @@ -233,7 +236,8 @@ def __repr__(self) -> str: f"request_id={self.request_id}, " f"source_executor_id={self.source_executor_id}, " f"request_type={self.request_type.__name__}, " - f"data={self.data})" + f"data={self.data}, " + f"response_type={self.response_type.__name__})" ) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 1f822e870a..dc7261ab71 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -16,7 +16,8 @@ _framework_event_origin, # type: ignore[reportPrivateUsage] ) from ._model_utils import DictConvertible -from ._runner_context import Message, RunnerContext # type: ignore +from ._request_info_mixin import RequestInfoMixin +from ._runner_context import Message, MessageType, RunnerContext from ._shared_state import SharedState from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext, validate_function_signature @@ -25,7 +26,7 @@ # region Executor -class Executor(DictConvertible): +class Executor(RequestInfoMixin, DictConvertible): """Base class for all workflow executors that process messages and perform computations. ## Overview @@ -204,6 +205,9 @@ def __init__( "Please define at least one handler using the @handler decorator." ) + # Initialize RequestInfoMixin to discover response handlers + self._discover_response_handlers() + async def execute( self, message: Any, @@ -229,12 +233,17 @@ async def execute( Returns: An awaitable that resolves to the result of the execution. 
""" - # Create processing span for tracing (gracefully handles disabled tracing) + # Default to find handler in regular handlers + target_handlers = self._handlers # Handle case where Message wrapper is passed instead of raw data if isinstance(message, Message): + if message.message_type == MessageType.RESPONSE: + # Switch to response handlers if message is a response + target_handlers = self._response_handlers message = message.data + # Create processing span for tracing (gracefully handles disabled tracing) with create_processing_span( self.id, self.__class__.__name__, @@ -244,15 +253,9 @@ async def execute( ): # Find the handler and handler spec that matches the message type. handler: Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]] | None = None - ctx_annotation = None - for message_type in self._handlers: + for message_type in target_handlers: if is_instance_of(message, message_type): - handler = self._handlers[message_type] - # Find the corresponding handler spec for context annotation - for spec in self._handler_specs: - if spec.get("message_type") == message_type: - ctx_annotation = spec.get("ctx_annotation") - break + handler = target_handlers[message_type] break if handler is None: @@ -263,7 +266,6 @@ async def execute( source_executor_ids=source_executor_ids, shared_state=shared_state, runner_context=runner_context, - ctx_annotation=ctx_annotation, trace_contexts=trace_contexts, source_span_ids=source_span_ids, ) @@ -289,7 +291,6 @@ def _create_context_for_handler( source_executor_ids: list[str], shared_state: SharedState, runner_context: RunnerContext, - ctx_annotation: Any, trace_contexts: list[dict[str, str]] | None = None, source_span_ids: list[str] | None = None, ) -> WorkflowContext[Any]: @@ -299,7 +300,6 @@ def _create_context_for_handler( source_executor_ids: The IDs of the source executors that sent messages to this executor. shared_state: The shared state for the workflow. 
runner_context: The runner context that provides methods to send messages and events. - ctx_annotation: The context annotation from the handler spec to determine which context type to create. trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking. @@ -350,7 +350,7 @@ def _discover_handlers(self) -> None: # Skip attributes that may not be accessible continue - def can_handle(self, message: Any) -> bool: + def can_handle(self, message: Message) -> bool: """Check if the executor can handle a given message type. Args: @@ -359,7 +359,11 @@ def can_handle(self, message: Any) -> bool: Returns: True if the executor can handle the message type, False otherwise. """ - return any(is_instance_of(message, message_type) for message_type in self._handlers) + if message.message_type == MessageType.REGULAR: + return any(is_instance_of(message.data, message_type) for message_type in self._handlers) + if message.message_type == MessageType.RESPONSE: + return any(is_instance_of(message.data, message_type) for message_type in self._response_handlers) + return False def _register_instance_handler( self, diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py new file mode 100644 index 0000000000..1396382481 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import contextlib +import functools +import inspect +import logging +from builtins import type as builtin_type +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar + +from ._workflow_context import WorkflowContext, validate_function_signature + +if TYPE_CHECKING: + from ._executor import Executor + + +logger = logging.getLogger(__name__) + + +class RequestInfoMixin: + """Mixin providing common functionality for request info handling.""" + + _PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info" + + def _discover_response_handlers(self) -> None: + """Discover and register response handlers defined in the class.""" + # Initialize handler storage if not already present + if not hasattr(self, "_response_handlers"): + self._response_handlers: dict[ + builtin_type[Any], Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]] + ] = {} + if not hasattr(self, "_response_handler_specs"): + self._response_handler_specs: list[dict[str, Any]] = [] + + for attr_name in dir(self.__class__): + try: + attr = getattr(self.__class__, attr_name) + if callable(attr) and hasattr(attr, "_response_handler_spec"): + handler_spec = attr._response_handler_spec # type: ignore + message_type = handler_spec["message_type"] + + if self._response_handlers.get(message_type): + raise ValueError( + f"Duplicate response handler for message type {message_type} in {self.__class__.__name__}" + ) + + self._response_handlers[message_type] = getattr(self, attr_name) + self._response_handler_specs.append({ + "name": handler_spec["name"], + "message_type": message_type, + "output_types": handler_spec.get("output_types", []), + "workflow_output_types": handler_spec.get("workflow_output_types", []), + "ctx_annotation": handler_spec.get("ctx_annotation"), + "source": "class_method", # Distinguish from instance handlers if needed + }) + except AttributeError: + continue # Skip non-callable attributes or those without handler spec + + 
+ExecutorT = TypeVar("ExecutorT", bound="Executor") +ContextT = TypeVar("ContextT", bound="WorkflowContext[Any, Any]") + + +def response_handler( + func: Callable[[ExecutorT, Any, ContextT], Awaitable[None]], +) -> Callable[[ExecutorT, Any, ContextT], Awaitable[None]]: + """Decorator to register a handler to handle responses for a request. + + Args: + func: The function to decorate. + + Returns: + The decorated function with handler metadata. + + Example: + @response_handler + async def handle_response(self, response: str, context: WorkflowContext[str]) -> None: + ... + + @response_handler + async def handle_response(self, response: dict, context: WorkflowContext[int]) -> None: + ... + """ + + def decorator( + func: Callable[[ExecutorT, Any, ContextT], Awaitable[None]], + ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[None]]: + message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( + validate_function_signature(func, "Handler method") + ) + + # Get signature for preservation + sig = inspect.signature(func) + + @functools.wraps(func) + async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: + """Wrapper function to call the handler.""" + return await func(self, message, ctx) + + # Preserve the original function signature for introspection during validation + with contextlib.suppress(AttributeError, TypeError): + wrapper.__signature__ = sig # type: ignore[attr-defined] + + wrapper._response_handler_spec = { # type: ignore + "name": func.__name__, + "message_type": message_type, + # Keep output_types and workflow_output_types in spec for validators + "output_types": inferred_output_types, + "workflow_output_types": inferred_workflow_output_types, + "ctx_annotation": ctx_annotation, + } + + return wrapper + + return decorator(func) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 9a0d0a7790..c3c11537f5 100644 --- 
a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -179,6 +179,10 @@ def _normalize_message_payload(message: Message) -> None: # Route all messages through normal workflow edges associated_edge_runners = self._edge_runner_map.get(source_executor_id, []) + if not associated_edge_runners: + logger.warning(f"No outgoing edges found for executor {source_executor_id}; dropping messages.") + return + for message in messages: _normalize_message_payload(message) # Deliver a message through all edge runners associated with the source executor concurrently. diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 4de70cb591..9aaa479d2f 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -8,11 +8,12 @@ import uuid from copy import copy from dataclasses import dataclass, fields, is_dataclass +from enum import Enum from typing import Any, Protocol, TypedDict, TypeVar, cast, runtime_checkable from ._checkpoint import CheckpointStorage, WorkflowCheckpoint -from ._const import DEFAULT_MAX_ITERATIONS -from ._events import WorkflowEvent +from ._const import DEFAULT_MAX_ITERATIONS, INTERNAL_SOURCE_ID +from ._events import RequestInfoEvent, WorkflowEvent from ._shared_state import SharedState logger = logging.getLogger(__name__) @@ -20,6 +21,16 @@ T = TypeVar("T") +class MessageType(Enum): + """Enum representing different types of messages in the workflow.""" + + RESPONSE = "response" + """A response message for a pending request.""" + + REGULAR = "regular" + """A regular message between executors.""" + + @dataclass class Message: """A class representing a message in the workflow.""" @@ -27,6 +38,7 @@ class Message: data: Any source_id: str target_id: str | None = None + message_type: MessageType = 
MessageType.REGULAR # OpenTelemetry trace context fields for message propagation # These are plural to support fan-in scenarios where multiple messages are aggregated @@ -444,6 +456,31 @@ async def set_checkpoint_state(self, state: CheckpointState) -> None: """ ... + async def add_request_info_event(self, event: RequestInfoEvent) -> None: + """Add a RequestInfoEvent to the context and track it for correlation. + + Args: + event: The RequestInfoEvent to be added. + """ + ... + + async def send_request_info_response(self, request_id: str, response: Any) -> None: + """Send a response correlated to a pending request. + + Args: + request_id: The ID of the original request. + response: The response data to be sent. + """ + ... + + async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: + """Get the mapping of request IDs to their corresponding RequestInfoEvent. + + Returns: + A dictionary mapping request IDs to their corresponding RequestInfoEvent. + """ + ... + class InProcRunnerContext: """In-process execution context for local execution and optional checkpointing.""" @@ -458,6 +495,9 @@ def __init__(self, checkpoint_storage: CheckpointStorage | None = None): # Event queue for immediate streaming of events (e.g., AgentRunUpdateEvent) self._event_queue: asyncio.Queue[WorkflowEvent] = asyncio.Queue() + # An additional storage for pending request info events + self._pending_request_info_events: dict[str, RequestInfoEvent] = {} + # Checkpointing configuration/state self._checkpoint_storage = checkpoint_storage self._workflow_id: str | None = None @@ -655,3 +695,50 @@ async def set_checkpoint_state(self, state: CheckpointState) -> None: self._iteration_count = state.get("iteration_count", 0) self._max_iterations = state.get("max_iterations", 100) + + async def add_request_info_event(self, event: RequestInfoEvent) -> None: + """Add a RequestInfoEvent to the context and track it for correlation. + + Args: + event: The RequestInfoEvent to be added. 
+ """ + self._pending_request_info_events[event.request_id] = event + await self.add_event(event) + + async def send_request_info_response(self, request_id: str, response: Any) -> None: + """Send a response correlated to a pending request. + + Args: + request_id: The ID of the original request. + response: The response data to be sent. + """ + event = self._pending_request_info_events.pop(request_id, None) + if not event: + raise ValueError(f"No pending request found for request_id: {request_id}") + + # Validate response type if specified + if event.response_type and not isinstance(response, event.response_type): + raise TypeError( + f"Response type mismatch for request_id {request_id}: " + f"expected {event.response_type.__name__}, got {type(response).__name__}" + ) + + await self.send_message( + Message( + data=response, + source_id=INTERNAL_SOURCE_ID(event.source_executor_id), + target_id=event.source_executor_id, + message_type=MessageType.RESPONSE, + ) + ) + + # Clear the event from pending requests + self._pending_request_info_events.pop(request_id, None) + + async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: + """Get the mapping of request IDs to their corresponding RequestInfoEvent. + + Returns: + A dictionary mapping request IDs to their corresponding RequestInfoEvent. 
+ """ + return dict(self._pending_request_info_events) diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index 5cd7940ff3..3b93e7ad84 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -8,7 +8,7 @@ from types import UnionType from typing import Any, Union, get_args, get_origin -from ._edge import Edge, EdgeGroup, FanInEdgeGroup +from ._edge import Edge, EdgeGroup, FanInEdgeGroup, InternalEdgeGroup from ._executor import Executor from ._request_info_executor import RequestInfoExecutor @@ -54,20 +54,6 @@ def __init__(self, edge_id: str): self.edge_id = edge_id -class ExecutorDuplicationError(WorkflowValidationError): - """Exception raised when duplicate executor identifiers are detected.""" - - def __init__(self, executor_id: str): - super().__init__( - message=( - f"Duplicate executor id detected: '{executor_id}'. Executor ids must be globally unique within a " - "workflow." - ), - validation_type=ValidationTypeEnum.EXECUTOR_DUPLICATION, - ) - self.executor_id = executor_id - - class TypeCompatibilityError(WorkflowValidationError): """Exception raised when type incompatibility is detected between connected executors.""" @@ -121,7 +107,6 @@ class WorkflowGraphValidator: def __init__(self) -> None: self._edges: list[Edge] = [] self._executors: dict[str, Executor] = {} - self._duplicate_executor_ids: set[str] = set() self._start_executor_ref: Executor | str | None = None # region Core Validation Methods @@ -130,8 +115,6 @@ def validate_workflow( edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str, - *, - duplicate_executor_ids: Sequence[str] | None = None, ) -> None: """Validate the entire workflow graph. 
@@ -140,16 +123,12 @@ def validate_workflow( executors: Map of executor IDs to executor instances start_executor: The starting executor (can be instance or ID) - Keyword Args: - duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate - Raises: WorkflowValidationError: If any validation fails """ self._executors = executors self._edges = [edge for group in edge_groups for edge in group.edges] self._edge_groups = edge_groups - self._duplicate_executor_ids = set(duplicate_executor_ids or []) self._start_executor_ref = start_executor # If only the start executor exists, add it to the executor map @@ -185,7 +164,6 @@ def validate_workflow( ) # Run all checks - self._validate_executor_id_uniqueness(start_executor_id) self._validate_edge_duplication() self._validate_handler_output_annotations() self._validate_type_compatibility() @@ -212,26 +190,6 @@ def _validate_handler_output_annotations(self) -> None: # endregion - def _validate_executor_id_uniqueness(self, start_executor_id: str) -> None: - """Ensure executor identifiers are unique throughout the workflow graph.""" - duplicates: set[str] = set(self._duplicate_executor_ids) - - id_counts: defaultdict[str, int] = defaultdict(int) - for key, executor in self._executors.items(): - id_counts[executor.id] += 1 - if key != executor.id: - duplicates.add(executor.id) - - duplicates.update({executor_id for executor_id, count in id_counts.items() if count > 1}) - - if isinstance(self._start_executor_ref, Executor): - mapped = self._executors.get(start_executor_id) - if mapped is not None and mapped is not self._start_executor_ref: - duplicates.add(start_executor_id) - - if duplicates: - raise ExecutorDuplicationError(sorted(duplicates)[0]) - # region Edge and Type Validation def _validate_edge_duplication(self) -> None: """Validate that there are no duplicate edges in the workflow. 
@@ -273,6 +231,10 @@ def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) - Raises: TypeCompatibilityError: If type incompatibility is detected """ + if isinstance(edge_group, InternalEdgeGroup): + # Skip type compatibility validation for internal edges + return + source_executor = self._executors[edge.source_id] target_executor = self._executors[edge.target_id] @@ -582,8 +544,6 @@ def validate_workflow_graph( edge_groups: Sequence[EdgeGroup], executors: dict[str, Executor], start_executor: Executor | str, - *, - duplicate_executor_ids: Sequence[str] | None = None, ) -> None: """Convenience function to validate a workflow graph. @@ -592,9 +552,6 @@ def validate_workflow_graph( executors: Map of executor IDs to executor instances start_executor: The starting executor (can be instance or ID) - Keyword Args: - duplicate_executor_ids: Optional list of known duplicate executor IDs to pre-populate - Raises: WorkflowValidationError: If any validation fails """ @@ -603,5 +560,4 @@ def validate_workflow_graph( edge_groups, executors, start_executor, - duplicate_executor_ids=duplicate_executor_ids, ) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index d9270bfe02..ecfcf92788 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio +import functools import hashlib import json import logging @@ -21,6 +22,7 @@ EdgeGroup, FanInEdgeGroup, FanOutEdgeGroup, + InternalEdgeGroup, SingleEdgeGroup, SwitchCaseEdgeGroup, SwitchCaseEdgeGroupCase, @@ -503,33 +505,8 @@ async def send_responses_streaming(self, responses: dict[str, Any]) -> AsyncIter """ self._ensure_not_running() try: - - async def send_responses() -> None: - request_info_executor = self._find_request_info_executor() - if not request_info_executor: - raise ValueError("No RequestInfoExecutor found in workflow.") - - async def _handle_response(response: Any, request_id: str) -> None: - """Handle the response from the RequestInfoExecutor.""" - await request_info_executor.handle_response( - response, - request_id, - WorkflowContext( - request_info_executor.id, - [self.__class__.__name__], - self._shared_state, - self._runner.context, - trace_contexts=None, # No parent trace context for new workflow span - source_span_ids=None, # No source span for response handling - ), - ) - - await asyncio.gather(*[ - _handle_response(response, request_id) for request_id, response in responses.items() - ]) - async for event in self._run_workflow_with_tracing( - initial_executor_fn=send_responses, + initial_executor_fn=functools.partial(self._send_responses_internal, responses), reset_context=False, # Don't reset context when sending responses streaming=True, ): @@ -686,35 +663,10 @@ async def send_responses(self, responses: dict[str, Any]) -> WorkflowRunResult: """ self._ensure_not_running() try: - - async def send_responses_internal() -> None: - request_info_executor = self._find_request_info_executor() - if not request_info_executor: - raise ValueError("No RequestInfoExecutor found in workflow.") - - async def _handle_response(response: Any, request_id: str) -> None: - """Handle the response from the RequestInfoExecutor.""" - await request_info_executor.handle_response( - response, - request_id, - WorkflowContext( - request_info_executor.id, - 
[self.__class__.__name__], - self._shared_state, - self._runner.context, - trace_contexts=None, # No parent trace context for new workflow span - source_span_ids=None, # No source span for response handling - ), - ) - - await asyncio.gather(*[ - _handle_response(response, request_id) for request_id, response in responses.items() - ]) - events = [ event async for event in self._run_workflow_with_tracing( - initial_executor_fn=send_responses_internal, + initial_executor_fn=functools.partial(self._send_responses_internal, responses), reset_context=False, # Don't reset context when sending responses ) ] @@ -724,6 +676,28 @@ async def _handle_response(response: Any, request_id: str) -> None: finally: self._reset_running_flag() + async def _send_responses_internal(self, responses: dict[str, Any]) -> None: + """Internal method to validate and send responses to the executors.""" + pending_requests = await self._runner_context.get_pending_request_info_events() + if not pending_requests: + raise RuntimeError("No pending requests found in workflow context.") + + # Validate responses against pending requests + for request_id, response in responses.items(): + if request_id not in pending_requests: + raise ValueError(f"Response provided for unknown request ID: {request_id}") + pending_request = pending_requests[request_id] + if not isinstance(response, pending_request.response_type): + raise ValueError( + f"Response type mismatch for request ID {request_id}: " + f"expected {pending_request.response_type}, got {type(response)}" + ) + + await asyncio.gather(*[ + self._runner_context.send_request_info_response(request_id, response) + for request_id, response in responses.items() + ]) + def _get_executor_by_id(self, executor_id: str) -> Executor: """Get an executor by its ID. 
@@ -884,7 +858,6 @@ def __init__( """ self._edge_groups: list[EdgeGroup] = [] self._executors: dict[str, Executor] = {} - self._duplicate_executor_ids: set[str] = set() self._start_executor: Executor | str | None = None self._checkpoint_storage: CheckpointStorage | None = None self._max_iterations: int = max_iterations @@ -901,10 +874,18 @@ def __init__( def _add_executor(self, executor: Executor) -> str: """Add an executor to the map and return its ID.""" existing = self._executors.get(executor.id) - if existing is not None and existing is not executor: - self._duplicate_executor_ids.add(executor.id) - else: - self._executors[executor.id] = executor + if existing is not None: + if existing is executor: + # Already added + return executor.id + # ID conflict + raise ValueError(f"Duplicate executor ID '{executor.id}' detected in workflow.") + + # New executor + self._executors[executor.id] = executor + # Add an internal edge group for each unique executor + self._edge_groups.append(InternalEdgeGroup(executor.id)) + return executor.id def _maybe_wrap_agent( @@ -1236,7 +1217,6 @@ def build(self) -> Workflow: self._edge_groups, self._executors, self._start_executor, - duplicate_executor_ids=tuple(self._duplicate_executor_ids), ) # Add validation completed event diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index b7e077424e..3ac60b7570 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -2,6 +2,7 @@ import inspect import logging +import uuid from collections.abc import Callable from types import UnionType from typing import Any, Generic, Union, cast, get_args, get_origin @@ -12,6 +13,7 @@ from ..observability import OtelAttr, create_workflow_span from ._events import ( + RequestInfoEvent, WorkflowEvent, WorkflowEventSource, WorkflowFailedEvent, @@ -20,7 +22,7 
@@ WorkflowStartedEvent, WorkflowStatusEvent, WorkflowWarningEvent, - _framework_event_origin, + _framework_event_origin, # type: ignore ) from ._runner_context import Message, RunnerContext from ._shared_state import SharedState @@ -404,6 +406,30 @@ async def add_event(self, event: WorkflowEvent) -> None: return await self._runner_context.add_event(event) + async def request_info(self, request_data: Any, request_type: type, response_type: type) -> None: + """Request information from outside of the workflow. + + Calling this method will cause the workflow to emit a RequestInfoEvent, carrying the + provided request_data and request_type. External systems listening for such events + can then process the request and respond accordingly. + + Executors must have the corresponding response handlers defined using the + @response_handler decorator to handle the incoming responses. + + Args: + request_data: The data associated with the information request. + request_type: The type of the request, used to match with response handlers. + response_type: The expected type of the response, used for validation. 
+ """ + request_info_event = RequestInfoEvent( + request_id=str(uuid.uuid4()), + source_executor_id=self._executor_id, + request_type=request_type, + request_data=request_data, + response_type=response_type, + ) + await self._runner_context.add_request_info_event(request_info_event) + async def get_shared_state(self, key: str) -> Any: """Get a value from the shared state.""" return await self._shared_state.get(key) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 501ce0d8f1..dc69ca992b 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -26,6 +26,7 @@ RequestInfoMessage, RequestResponse, ) +from ._runner_context import Message from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext @@ -276,7 +277,7 @@ def to_dict(self) -> dict[str, Any]: data["workflow"] = self.workflow.to_dict() return data - def can_handle(self, message: Any) -> bool: + def can_handle(self, message: Message) -> bool: """Override can_handle to only accept messages that the wrapped workflow can handle. 
This prevents the WorkflowExecutor from accepting messages that should go to other @@ -287,7 +288,7 @@ def can_handle(self, message: Any) -> bool: return True # For other messages, only handle if the wrapped workflow can accept them as input - return any(is_instance_of(message, input_type) for input_type in self.workflow.input_types) + return any(is_instance_of(message.data, input_type) for input_type in self.workflow.input_types) @handler # No output_types - can send any completion data type async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) -> None: diff --git a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py index 02aa758bea..761d09cd80 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py @@ -10,16 +10,14 @@ ChatMessage, # Chat message structure Executor, # Base class for workflow executors RequestInfoEvent, # Event emitted when human input is requested - RequestInfoExecutor, # Special executor that collects human input out of band - RequestInfoMessage, # Base class for request payloads sent to RequestInfoExecutor - RequestResponse, # Correlates a human response with the original request Role, # Enum of chat roles (user, assistant, system) WorkflowBuilder, # Fluent builder for assembling the graph WorkflowContext, # Per run context and event bus WorkflowOutputEvent, # Event emitted when workflow yields output WorkflowRunState, # Enum of workflow run states WorkflowStatusEvent, # Event emitted on run state changes - handler, # Decorator to expose an Executor method as a step + handler, + response_handler, # Decorator to expose an Executor method as a step ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import 
AzureCliCredential @@ -29,12 +27,12 @@ Sample: Human in the loop guessing game An agent guesses a number, then a human guides it with higher, lower, or -correct via RequestInfoExecutor. The loop continues until the human confirms -correct, at which point the workflow completes when idle with no pending work. +correct. The loop continues until the human confirms correct, at which point +the workflow completes when idle with no pending work. Purpose: -Show how to integrate a human step in the middle of an LLM workflow using RequestInfoExecutor and correlated -RequestResponse objects. +Show how to integrate a human step in the middle of an LLM workflow by using +`request_info` and `send_responses_streaming`. Demonstrate: - Alternating turns between an AgentExecutor and a human, driven by events. @@ -47,27 +45,20 @@ - Basic familiarity with WorkflowBuilder, executors, edges, events, and streaming runs. """ -# What RequestInfoExecutor does: -# RequestInfoExecutor is a workflow-native bridge that pauses the graph at a request for information, -# emits a RequestInfoEvent with a typed payload, and then resumes the graph only after your application -# supplies a matching RequestResponse keyed by the emitted request_id. It does not gather input by itself. -# Your application is responsible for collecting the human reply from any UI or CLI and then calling -# send_responses_streaming with a dict mapping request_id to the human's answer. The executor exists to -# standardize pause-and-resume human gating, to carry typed request payloads, and to preserve correlation. - - -# Request type sent to the RequestInfoExecutor for human feedback. -# Including the agent's last guess allows the UI or CLI to display context and helps -# the turn manager avoid extra state reads. -# Why subclass RequestInfoMessage: -# Subclassing RequestInfoMessage defines the exact schema of the request that the human will see. 
-# This gives you strong typing, forward-compatible validation, and clear correlation semantics. -# It also lets you attach contextual fields (such as the previous guess) so the UI can render a rich prompt -# without fetching extra state from elsewhere. +# How human-in-the-loops is achieved via `request_info` and `send_responses_streaming`: +# - An executor (TurnManager) calls `ctx.request_info` with a payload (HumanFeedbackRequest). +# - The workflow run pauses and emits a RequestInfoEvent with the payload and the request_id. +# - The application captures the event, prompts the user, and collects replies. +# - The application calls `send_responses_streaming` with a map of request_ids to replies. +# - The workflow resumes, and the response is delivered to the executor method decorated with @response_handler. +# - The executor can then continue the workflow, e.g., by sending a new message to the agent. + + @dataclass -class HumanFeedbackRequest(RequestInfoMessage): - prompt: str = "" - guess: int | None = None +class HumanFeedbackRequest: + """Request sent to the human for feedback on the agent's guess.""" + + prompt: str class GuessOutput(BaseModel): @@ -103,47 +94,38 @@ async def start(self, _: str, ctx: WorkflowContext[AgentExecutorRequest]) -> Non async def on_agent_response( self, result: AgentExecutorResponse, - ctx: WorkflowContext[HumanFeedbackRequest], + ctx: WorkflowContext, ) -> None: """Handle the agent's guess and request human guidance. Steps: 1) Parse the agent's JSON into GuessOutput for robustness. - 2) Send a HumanFeedbackRequest to the RequestInfoExecutor with a clear instruction: - - higher means the human's secret number is higher than the agent's guess. - - lower means the human's secret number is lower than the agent's guess. - - correct confirms the guess is exactly right. - - exit quits the demo. + 2) Request info with a HumanFeedbackRequest as the payload. """ - # Parse structured model output (defensive default if the agent did not reply). 
- text = result.agent_run_response.text or "" - last_guess = GuessOutput.model_validate_json(text).guess if text else None + # Parse structured model output + text = result.agent_run_response.text + last_guess = GuessOutput.model_validate_json(text).guess # Craft a precise human prompt that defines higher and lower relative to the agent's guess. prompt = ( - f"The agent guessed: {last_guess if last_guess is not None else text}. " + f"The agent guessed: {last_guess}. " "Type one of: higher (your number is higher than this guess), " "lower (your number is lower than this guess), correct, or exit." ) - await ctx.send_message(HumanFeedbackRequest(prompt=prompt, guess=last_guess)) + # Send a request with a prompt as the payload and expect a string reply. + await ctx.request_info(HumanFeedbackRequest(prompt=prompt), HumanFeedbackRequest, str) - @handler + @response_handler async def on_human_feedback( self, - feedback: RequestResponse[HumanFeedbackRequest, str], + feedback: str, ctx: WorkflowContext[AgentExecutorRequest, str], ) -> None: - """Continue the game or finish based on human feedback. - - The RequestResponse contains both the human's string reply and the correlated HumanFeedbackRequest, - which carries the prior guess for convenience. - """ - reply = (feedback.data or "").strip().lower() - # Prefer the correlated request's guess to avoid extra shared state reads. - last_guess = getattr(feedback.original_request, "guess", None) + """Continue the game or finish based on human feedback.""" + reply = feedback.strip().lower() if reply == "correct": - await ctx.yield_output(f"Guessed correctly: {last_guess}") + await ctx.yield_output("Guessed correctly!") return # Provide feedback to the agent to try again. @@ -166,35 +148,24 @@ async def main() -> None: 'You MUST return ONLY a JSON object exactly matching this schema: {"guess": }. ' "No explanations or additional text." ), + # Structured output enforced via Pydantic model. 
response_format=GuessOutput, ) - # Build a simple loop: TurnManager <-> AgentExecutor <-> RequestInfoExecutor. - # TurnManager coordinates, AgentExecutor runs the model, RequestInfoExecutor gathers human replies. + # Build a simple loop: TurnManager <-> AgentExecutor. + # TurnManager coordinates and gathers human replies while AgentExecutor runs the model. turn_manager = TurnManager(id="turn_manager") agent_exec = AgentExecutor(agent=agent, id="agent") - # Naming note: - # This variable is currently named hitl for historical reasons. The name can feel ambiguous or magical. - # Consider renaming to request_info_executor in your own code for clarity, since it directly represents - # the RequestInfoExecutor node that gathers human replies out of band. - hitl = RequestInfoExecutor(id="request_info") - - top_builder = ( + workflow = ( WorkflowBuilder() .set_start_executor(turn_manager) .add_edge(turn_manager, agent_exec) # Ask agent to make/adjust a guess .add_edge(agent_exec, turn_manager) # Agent's response comes back to coordinator - .add_edge(turn_manager, hitl) # Ask human for guidance - .add_edge(hitl, turn_manager) # Feed human guidance back to coordinator - ) - - # Build the workflow (no checkpointing in this minimal sample). - workflow = top_builder.build() + ).build() # Human in the loop run: alternate between invoking the workflow and supplying collected responses. pending_responses: dict[str, str] | None = None - completed = False workflow_output: str | None = None # User guidance printing: @@ -206,7 +177,7 @@ async def main() -> None: # flush=True, # ) - while not completed: + while not workflow_output: # First iteration uses run_stream("start"). # Subsequent iterations use send_responses_streaming with pending_responses from the console. 
stream = ( @@ -228,7 +199,6 @@ async def main() -> None: elif isinstance(event, WorkflowOutputEvent): # Capture workflow output as they're yielded workflow_output = str(event.data) - completed = True # In this sample, we finish after one output. # Detect run state transitions for a better developer experience. pending_status = any( @@ -245,7 +215,7 @@ async def main() -> None: print("State: IDLE_WITH_PENDING_REQUESTS (awaiting human input)") # If we have any human requests, prompt the user and prepare responses. - if requests and not completed: + if requests: responses: dict[str, str] = {} for req_id, prompt in requests: # Simple console prompt for the sample. From 309ceb4f33a3f2d4686af01f47459beab2e9bf8f Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 15 Oct 2025 10:34:53 -0700 Subject: [PATCH 02/26] Add original_request as a parameter to the response handler --- .../agent_framework/_workflows/_executor.py | 64 +++++++++++-- .../_workflows/_function_executor.py | 75 +++++++++++---- .../_workflows/_request_info_mixin.py | 96 ++++++++++++++++--- .../_workflows/_runner_context.py | 23 ++--- .../agent_framework/_workflows/_validation.py | 4 +- .../_workflows/_workflow_context.py | 63 ------------ .../guessing_game_with_human_input.py | 3 + 7 files changed, 211 insertions(+), 117 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index dc7261ab71..7e07956a35 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -17,10 +17,10 @@ ) from ._model_utils import DictConvertible from ._request_info_mixin import RequestInfoMixin -from ._runner_context import Message, MessageType, RunnerContext +from ._runner_context import Message, ResponseMessage, RunnerContext from ._shared_state import SharedState from ._typing_utils import is_instance_of -from ._workflow_context import WorkflowContext, 
validate_function_signature +from ._workflow_context import WorkflowContext, validate_workflow_context_annotation logger = logging.getLogger(__name__) @@ -236,11 +236,14 @@ async def execute( # Default to find handler in regular handlers target_handlers = self._handlers + if isinstance(message, ResponseMessage): + # Wrap the response handlers to include original_request parameter + target_handlers = { + message_type: functools.partial(handler, message.original_request) + for message_type, handler in self._response_handlers.items() + } # Handle case where Message wrapper is passed instead of raw data if isinstance(message, Message): - if message.message_type == MessageType.RESPONSE: - # Switch to response handlers if message is a response - target_handlers = self._response_handlers message = message.data # Create processing span for tracing (gracefully handles disabled tracing) @@ -359,11 +362,10 @@ def can_handle(self, message: Message) -> bool: Returns: True if the executor can handle the message type, False otherwise. 
""" - if message.message_type == MessageType.REGULAR: - return any(is_instance_of(message.data, message_type) for message_type in self._handlers) - if message.message_type == MessageType.RESPONSE: + if isinstance(message, ResponseMessage): return any(is_instance_of(message.data, message_type) for message_type in self._response_handlers) - return False + + return any(is_instance_of(message.data, message_type) for message_type in self._handlers) def _register_instance_handler( self, @@ -484,7 +486,7 @@ def decorator( ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[Any]]: # Extract the message type and validate using unified validation message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - validate_function_signature(func, "Handler method") + _validate_handler_signature(func) ) # Get signature for preservation @@ -514,3 +516,45 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: # endregion: Handler Decorator + +# region Handler Validation + + +def _validate_handler_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: + """Validate function signature for executor functions. + + Args: + func: The function to validate + + Returns: + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + + Raises: + ValueError: If the function signature is invalid + """ + signature = inspect.signature(func) + params = list(signature.parameters.values()) + + expected_counts = 3 # self, message, ctx + param_description = "(self, message: T, ctx: WorkflowContext[U, V])" + if len(params) != expected_counts: + raise ValueError(f"Handler {func.__name__} must have {param_description}. 
Got {len(params)} parameters.") + + # Check message parameter has type annotation + message_param = params[1] + if message_param.annotation == inspect.Parameter.empty: + raise ValueError(f"Handler {func.__name__} must have a type annotation for the message parameter") + + # Validate ctx parameter is WorkflowContext and extract type args + ctx_param = params[2] + output_types, workflow_output_types = validate_workflow_context_annotation( + ctx_param.annotation, f"parameter '{ctx_param.name}'", "Handler" + ) + + message_type = message_param.annotation + ctx_annotation = ctx_param.annotation + + return message_type, ctx_annotation, output_types, workflow_output_types + + +# endregion: Handler Validation diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index 223dd9c05b..eaeaf15f63 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -11,11 +11,12 @@ """ import asyncio +import inspect from collections.abc import Awaitable, Callable from typing import Any, overload from ._executor import Executor -from ._workflow_context import WorkflowContext, validate_function_signature +from ._workflow_context import WorkflowContext, validate_workflow_context_annotation class FunctionExecutor(Executor): @@ -28,21 +29,6 @@ class FunctionExecutor(Executor): blocking the event loop. """ - @staticmethod - def _validate_function(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: - """Validate that the function has the correct signature for an executor. 
- - Args: - func: The function to validate (can be sync or async) - - Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) - - Raises: - ValueError: If the function signature is incorrect - """ - return validate_function_signature(func, "Function") - def __init__(self, func: Callable[..., Any], id: str | None = None): """Initialize the FunctionExecutor with a user-defined function. @@ -51,7 +37,7 @@ def __init__(self, func: Callable[..., Any], id: str | None = None): id: Optional executor ID. If None, uses the function name. """ # Validate function signature and extract types - message_type, ctx_annotation, output_types, workflow_output_types = self._validate_function(func) + message_type, ctx_annotation, output_types, workflow_output_types = _validate_function_signature(func) # Determine if function has WorkflowContext parameter has_context = ctx_annotation is not None @@ -112,6 +98,9 @@ async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: ) +# region Decorator + + @overload def executor(func: Callable[..., Any]) -> FunctionExecutor: ... @@ -164,3 +153,55 @@ def wrapper(func: Callable[..., Any]) -> FunctionExecutor: # Otherwise, return the wrapper for @executor() or @executor(id="...") return wrapper + + +# endregion: Decorator + +# region Function Validation + + +def _validate_function_signature(func: Callable[..., Any]) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: + """Validate function signature for executor functions. 
+ + Args: + func: The function to validate + + Returns: + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + + Raises: + ValueError: If the function signature is invalid + """ + signature = inspect.signature(func) + params = list(signature.parameters.values()) + + expected_counts = (1, 2) # Function executor: (message) or (message, ctx) + param_description = "(message: T) or (message: T, ctx: WorkflowContext[U])" + if len(params) not in expected_counts: + raise ValueError( + f"Function instance {func.__name__} must have {param_description}. Got {len(params)} parameters." + ) + + # Check message parameter has type annotation + message_param = params[0] + if message_param.annotation == inspect.Parameter.empty: + raise ValueError(f"Function instance {func.__name__} must have a type annotation for the message parameter") + + message_type = message_param.annotation + + # Check if there's a context parameter + if len(params) == 2: + ctx_param = params[1] + output_types, workflow_output_types = validate_workflow_context_annotation( + ctx_param.annotation, f"parameter '{ctx_param.name}'", "Function instance" + ) + ctx_annotation = ctx_param.annotation + else: + # No context parameter (only valid for function executors) + output_types, workflow_output_types = [], [] + ctx_annotation = None + + return message_type, ctx_annotation, output_types, workflow_output_types + + +# endregion: Function Validation diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 1396382481..56146222e7 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -8,7 +8,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any, ClassVar, TypeVar -from ._workflow_context import WorkflowContext, validate_function_signature +from 
._workflow_context import WorkflowContext, validate_workflow_context_annotation if TYPE_CHECKING: from ._executor import Executor @@ -27,7 +27,7 @@ def _discover_response_handlers(self) -> None: # Initialize handler storage if not already present if not hasattr(self, "_response_handlers"): self._response_handlers: dict[ - builtin_type[Any], Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]] + builtin_type[Any], Callable[[Any, Any, WorkflowContext[Any, Any]], Awaitable[None]] ] = {} if not hasattr(self, "_response_handler_specs"): self._response_handler_specs: list[dict[str, Any]] = [] @@ -60,10 +60,12 @@ def _discover_response_handlers(self) -> None: ExecutorT = TypeVar("ExecutorT", bound="Executor") ContextT = TypeVar("ContextT", bound="WorkflowContext[Any, Any]") +# region Handler Decorator + def response_handler( - func: Callable[[ExecutorT, Any, ContextT], Awaitable[None]], -) -> Callable[[ExecutorT, Any, ContextT], Awaitable[None]]: + func: Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], +) -> Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]: """Decorator to register a handler to handle responses for a request. Args: @@ -73,29 +75,48 @@ def response_handler( The decorated function with handler metadata. Example: + @handler + async def run(self, message: int, context: WorkflowContext[str]) -> None: + # Example of a handler that sends a request + ... + # Send a request with a `CustomRequest` payload and expect a `str` response. + await context.request_info(CustomRequest(...), CustomRequest, str) + @response_handler - async def handle_response(self, response: str, context: WorkflowContext[str]) -> None: + async def handle_response( + self, + original_request: CustomRequest, + response: str, + context: WorkflowContext[str], + ) -> None: + # Example of a response handler for the above request ... 
@response_handler - async def handle_response(self, response: dict, context: WorkflowContext[int]) -> None: + async def handle_response( + self, + original_request: CustomRequest, + response: dict, + context: WorkflowContext[int], + ) -> None: + # Example of a response handler for a request expecting a dict response ... """ def decorator( - func: Callable[[ExecutorT, Any, ContextT], Awaitable[None]], - ) -> Callable[[ExecutorT, Any, ContextT], Awaitable[None]]: + func: Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], + ) -> Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]: message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( - validate_function_signature(func, "Handler method") + _validate_response_handler_signature(func) ) # Get signature for preservation sig = inspect.signature(func) @functools.wraps(func) - async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: + async def wrapper(self: ExecutorT, original_request: Any, message: Any, ctx: ContextT) -> Any: """Wrapper function to call the handler.""" - return await func(self, message, ctx) + return await func(self, original_request, message, ctx) # Preserve the original function signature for introspection during validation with contextlib.suppress(AttributeError, TypeError): @@ -113,3 +134,56 @@ async def wrapper(self: ExecutorT, message: Any, ctx: ContextT) -> Any: return wrapper return decorator(func) + + +# endregion: Handler Decorator + +# region Response Handler Validation + + +def _validate_response_handler_signature( + func: Callable[..., Any], +) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: + """Validate function signature for executor functions. 
+ + Args: + func: The function to validate + + Returns: + Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + + Raises: + ValueError: If the function signature is invalid + """ + signature = inspect.signature(func) + params = list(signature.parameters.values()) + + # Note that the original_request parameter must be the second parameter + # such that we can wrap the handler with functools.partial to bind it + # to the original request when registering the handler, while maintaining + # the order of parameters as if the response handler is a normal handler. + expected_counts = 4 # self, original_request, message, ctx + param_description = "(self, original_request: Any, message: T, ctx: WorkflowContext[U, V])" + if len(params) != expected_counts: + raise ValueError( + f"Response handler {func.__name__} must have {param_description}. Got {len(params)} parameters." + ) + + # Check message parameter has type annotation + message_param = params[2] + if message_param.annotation == inspect.Parameter.empty: + raise ValueError(f"Response handler {func.__name__} must have a type annotation for the message parameter") + + # Validate ctx parameter is WorkflowContext and extract type args + ctx_param = params[3] + output_types, workflow_output_types = validate_workflow_context_annotation( + ctx_param.annotation, f"parameter '{ctx_param.name}'", "Response handler" + ) + + message_type = message_param.annotation + ctx_annotation = ctx_param.annotation + + return message_type, ctx_annotation, output_types, workflow_output_types + + +# endregion: Response Handler Validation diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 9aaa479d2f..c02245c123 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -8,7 +8,6 @@ import uuid from copy import copy from 
dataclasses import dataclass, fields, is_dataclass -from enum import Enum from typing import Any, Protocol, TypedDict, TypeVar, cast, runtime_checkable from ._checkpoint import CheckpointStorage, WorkflowCheckpoint @@ -21,16 +20,6 @@ T = TypeVar("T") -class MessageType(Enum): - """Enum representing different types of messages in the workflow.""" - - RESPONSE = "response" - """A response message for a pending request.""" - - REGULAR = "regular" - """A regular message between executors.""" - - @dataclass class Message: """A class representing a message in the workflow.""" @@ -38,7 +27,6 @@ class Message: data: Any source_id: str target_id: str | None = None - message_type: MessageType = MessageType.REGULAR # OpenTelemetry trace context fields for message propagation # These are plural to support fan-in scenarios where multiple messages are aggregated @@ -57,6 +45,13 @@ def source_span_id(self) -> str | None: return self.source_span_ids[0] if self.source_span_ids else None +@dataclass +class ResponseMessage(Message): + """A message representing a response to a pending request.""" + + original_request: Any = None + + class CheckpointState(TypedDict): messages: dict[str, list[dict[str, Any]]] shared_state: dict[str, Any] @@ -724,11 +719,11 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No ) await self.send_message( - Message( + ResponseMessage( data=response, source_id=INTERNAL_SOURCE_ID(event.source_executor_id), target_id=event.source_executor_id, - message_type=MessageType.RESPONSE, + original_request=event.data, ) ) diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index 3b93e7ad84..6c2ef35681 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -180,8 +180,8 @@ def _validate_handler_output_annotations(self) -> None: decorator is applied. 
This method is kept minimal for any edge cases. """ # The comprehensive validation is already done during handler registration: - # 1. @handler decorator calls validate_function_signature() - # 2. FunctionExecutor constructor calls validate_function_signature() + # 1. @handler and @response_handler decorators already have validation logic + # 2. FunctionExecutor constructor also has validation logic # 3. Both use validate_workflow_context_annotation() for WorkflowContext validation # # All executors in the workflow must have gone through one of these paths, diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index 3ac60b7570..735dd4f4ad 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -3,7 +3,6 @@ import inspect import logging import uuid -from collections.abc import Callable from types import UnionType from typing import Any, Generic, Union, cast, get_args, get_origin @@ -199,68 +198,6 @@ def _is_type_like(x: Any) -> bool: return infer_output_types_from_ctx_annotation(annotation) -def validate_function_signature( - func: Callable[..., Any], context_description: str -) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: - """Validate function signature for executor functions. - - Args: - func: The function to validate - context_description: Description for error messages (e.g., "Function", "Handler method") - - Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) - - Raises: - ValueError: If the function signature is invalid - """ - signature = inspect.signature(func) - params = list(signature.parameters.values()) - - # Determine expected parameter count based on context - expected_counts: tuple[int, ...] 
- if context_description.startswith("Function"): - # Function executor: (message) or (message, ctx) - expected_counts = (1, 2) - param_description = "(message: T) or (message: T, ctx: WorkflowContext[U])" - else: - # Handler method: (self, message, ctx) - expected_counts = (3,) - param_description = "(self, message: T, ctx: WorkflowContext[U])" - - if len(params) not in expected_counts: - raise ValueError( - f"{context_description} {func.__name__} must have {param_description}. Got {len(params)} parameters." - ) - - # Extract message parameter (index 0 for functions, index 1 for methods) - message_param_idx = 0 if context_description.startswith("Function") else 1 - message_param = params[message_param_idx] - - # Check message parameter has type annotation - if message_param.annotation == inspect.Parameter.empty: - raise ValueError(f"{context_description} {func.__name__} must have a type annotation for the message parameter") - - message_type = message_param.annotation - - # Check if there's a context parameter - ctx_param_idx = message_param_idx + 1 - if len(params) > ctx_param_idx: - ctx_param = params[ctx_param_idx] - output_types, workflow_output_types = validate_workflow_context_annotation( - ctx_param.annotation, f"parameter '{ctx_param.name}'", context_description - ) - ctx_annotation = ctx_param.annotation - else: - # No context parameter (only valid for function executors) - if not context_description.startswith("Function"): - raise ValueError(f"{context_description} {func.__name__} must have a WorkflowContext parameter") - output_types, workflow_output_types = [], [] - ctx_annotation = None - - return message_type, ctx_annotation, output_types, workflow_output_types - - _FRAMEWORK_LIFECYCLE_EVENT_TYPES: tuple[type[WorkflowEvent], ...] 
= cast( tuple[type[WorkflowEvent], ...], tuple(get_args(WorkflowLifecycleEvent)) diff --git a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py index 761d09cd80..24de2c8ffb 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py @@ -118,10 +118,13 @@ async def on_agent_response( @response_handler async def on_human_feedback( self, + original_request: HumanFeedbackRequest, feedback: str, ctx: WorkflowContext[AgentExecutorRequest, str], ) -> None: """Continue the game or finish based on human feedback.""" + print(f"Feedback for prompt '{original_request.prompt}' received: {feedback}") + reply = feedback.strip().lower() if reply == "correct": From 334016f7b116ffd71cd85b50b4e37755cdf3013a Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 15 Oct 2025 14:49:05 -0700 Subject: [PATCH 03/26] Prototype: request interception in sub workflows --- .../agent_framework/_workflows/__init__.py | 4 +- .../agent_framework/_workflows/__init__.pyi | 4 +- .../_workflows/_workflow_executor.py | 96 +++- .../sub_workflow_request_interception.py | 482 +++++++++--------- 4 files changed, 326 insertions(+), 260 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/__init__.py b/python/packages/core/agent_framework/_workflows/__init__.py index bd8cdfe3ac..654e9ed56e 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -100,7 +100,7 @@ from ._viz import WorkflowViz from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult from ._workflow_context import WorkflowContext -from ._workflow_executor import WorkflowExecutor +from ._workflow_executor import SubWorkflowRequestMessage, 
SubWorkflowResponseMessage, WorkflowExecutor __all__ = [ "DEFAULT_MAX_ITERATIONS", @@ -158,6 +158,8 @@ "SharedState", "SingleEdgeGroup", "StandardMagenticManager", + "SubWorkflowRequestMessage", + "SubWorkflowResponseMessage", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", diff --git a/python/packages/core/agent_framework/_workflows/__init__.pyi b/python/packages/core/agent_framework/_workflows/__init__.pyi index b6de88b5ba..926af9c3d0 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.pyi +++ b/python/packages/core/agent_framework/_workflows/__init__.pyi @@ -98,7 +98,7 @@ from ._validation import ( from ._viz import WorkflowViz from ._workflow import Workflow, WorkflowBuilder, WorkflowRunResult from ._workflow_context import WorkflowContext -from ._workflow_executor import WorkflowExecutor +from ._workflow_executor import SubWorkflowRequestMessage, SubWorkflowResponseMessage, WorkflowExecutor __all__ = [ "DEFAULT_MAX_ITERATIONS", @@ -156,6 +156,8 @@ __all__ = [ "SharedState", "SingleEdgeGroup", "StandardMagenticManager", + "SubWorkflowRequestMessage", + "SubWorkflowResponseMessage", "SwitchCaseEdgeGroup", "SwitchCaseEdgeGroupCase", "SwitchCaseEdgeGroupDefault", diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index dc69ca992b..12b800d3f6 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -40,7 +40,52 @@ class ExecutionContext: execution_id: str collected_responses: dict[str, Any] # request_id -> response_data expected_response_count: int - pending_requests: dict[str, Any] # request_id -> original request data + pending_requests: dict[str, RequestInfoEvent] # request_id -> request_info_event + + +@dataclass +class SubWorkflowResponseMessage: + """Message sent from a parent workflow to a sub-workflow 
via WorkflowExecutor to provide requested information. + + This message wraps a RequestResponse emitted by the parent workflow. + + Attributes: + response: The response to the original request. + """ + + data: Any + source_event: RequestInfoEvent + + +@dataclass +class SubWorkflowRequestMessage: + """Message sent from a sub-workflow to an executor in the parent workflow to request information. + + This message wraps a RequestInfoEvent emitted by the executor in the sub-workflow. + + Attributes: + source_event: The original RequestInfoEvent emitted by the sub-workflow executor. + executor_id: The ID of the WorkflowExecutor in the parent workflow that is + responsible for this sub-workflow. This can be used to ensure that the response + is sent back to the correct sub-workflow instance. + """ + + source_event: RequestInfoEvent + executor_id: str + + def create_response(self, data: Any) -> SubWorkflowResponseMessage: + """Validate and wrap response data into a SubWorkflowResponseMessage. + + Validation ensures the response data type matches the expected type from the original request. 
+ """ + expected_data_type = self.source_event.response_type + if not is_instance_of(data, expected_data_type): + raise TypeError( + f"Response data type {type(data)} does not match expected type {expected_data_type} " + f"for request_id {self.source_event.request_id}" + ) + + return SubWorkflowResponseMessage(data=data, source_event=self.source_event) class WorkflowExecutor(Executor): @@ -236,9 +281,9 @@ def input_types(self) -> list[type[Any]]: """ input_types = list(self.workflow.input_types) - # WorkflowExecutor can also handle RequestResponse for sub-workflow responses - if RequestResponse not in input_types: - input_types.append(RequestResponse) + # WorkflowExecutor can also handle SubWorkflowResponseMessage for sub-workflow responses + if SubWorkflowResponseMessage not in input_types: + input_types.append(SubWorkflowResponseMessage) return input_types @@ -281,10 +326,10 @@ def can_handle(self, message: Message) -> bool: """Override can_handle to only accept messages that the wrapped workflow can handle. This prevents the WorkflowExecutor from accepting messages that should go to other - executors (like RequestInfoExecutor). + executors because the handler `process_workflow` has no type restrictions. 
""" - # Always handle RequestResponse (for the handle_response handler) - if isinstance(message, RequestResponse): + # Always handle SubWorkflowResponseMessage + if isinstance(message.data, SubWorkflowResponseMessage): return True # For other messages, only handle if the wrapped workflow can accept them as input @@ -372,15 +417,16 @@ async def _process_workflow_result( # Process request info events for event in request_info_events: # Track the pending request in execution context - execution_context.pending_requests[event.request_id] = event.data + execution_context.pending_requests[event.request_id] = event # Map request to execution for response routing self._request_to_execution[event.request_id] = execution_context.execution_id - # Set source_executor_id for response routing and send to parent - if not isinstance(event.data, RequestInfoMessage): - raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") - # Set the source_executor_id to this WorkflowExecutor's ID for response routing - event.data.source_executor_id = self.id - await ctx.send_message(event.data) + # TODO(@taochen): There should be two ways a sub-workflow can make a request: + # 1. In a workflow where the parent workflow has an executor that may intercept the + # request and handle it directly, a message should be sent. + # 2. In a workflow where the parent workflow does not handle the request, the request + # should be propagated via the `request_info` mechanism to an external source. And + # a @response_handler would be required in the WorkflowExecutor to handle the response. 
+ await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) # Update expected response count for this execution execution_context.expected_response_count = len(request_info_events) @@ -429,11 +475,7 @@ async def _process_workflow_result( await self._persist_execution_state(ctx) @handler - async def handle_response( - self, - response: RequestResponse[RequestInfoMessage, Any], - ctx: WorkflowContext[Any], - ) -> None: + async def handle_response(self, response: SubWorkflowResponseMessage, ctx: WorkflowContext[Any]) -> None: """Handle response from parent for a forwarded request. This handler accumulates responses and only resumes the sub-workflow @@ -446,29 +488,31 @@ async def handle_response( await self._ensure_state_loaded(ctx) # Find the execution context for this request - execution_id = self._request_to_execution.get(response.request_id) + original_request = response.source_event + execution_id = self._request_to_execution.get(original_request.request_id) if not execution_id or execution_id not in self._execution_contexts: logger.warning( - f"WorkflowExecutor {self.id} received response for unknown request_id: {response.request_id}, ignoring" + f"WorkflowExecutor {self.id} received response for unknown request_id: {original_request.request_id}. " + "This response will be ignored." 
) return execution_context = self._execution_contexts[execution_id] # Check if we have this pending request in the execution context - if response.request_id not in execution_context.pending_requests: + if original_request.request_id not in execution_context.pending_requests: logger.warning( f"WorkflowExecutor {self.id} received response for unknown request_id: " - f"{response.request_id} in execution {execution_id}, ignoring" + f"{original_request.request_id} in execution {execution_id}, ignoring" ) return # Remove the request from pending list and request mapping - execution_context.pending_requests.pop(response.request_id, None) - self._request_to_execution.pop(response.request_id, None) + execution_context.pending_requests.pop(original_request.request_id, None) + self._request_to_execution.pop(original_request.request_id, None) # Accumulate the response in this execution's context - execution_context.collected_responses[response.request_id] = response.data + execution_context.collected_responses[original_request.request_id] = response.data await self._persist_execution_state(ctx) diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py index fae0f80e5f..d3fcf75f89 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py @@ -2,292 +2,310 @@ import asyncio from dataclasses import dataclass +from typing import Never from agent_framework import ( Executor, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, + SubWorkflowRequestMessage, + SubWorkflowResponseMessage, + Workflow, WorkflowBuilder, WorkflowContext, WorkflowExecutor, + WorkflowOutputEvent, handler, + response_handler, ) """ -Sample: Sub-Workflows with Request Interception - -This sample shows how to: -1. 
Create workflows that execute other workflows as sub-workflows -2. Intercept requests from sub-workflows using an executor with @handler for RequestInfoMessage subclasses -3. Conditionally handle or forward requests using RequestResponse messages -4. Handle external requests that are forwarded by the parent workflow -5. Proper request/response correlation for concurrent processing - -The example simulates an email validation system where: -- Sub-workflows validate multiple email addresses concurrently -- Parent workflows can intercept domain check requests for optimization -- Known domains (example.com, company.com) are approved locally -- Unknown domains (unknown.org) are forwarded to external services -- Request correlation ensures each email gets the correct domain check response -- External domain check requests are processed and responses routed back correctly - -Key concepts demonstrated: -- WorkflowExecutor: Wraps a workflow to make it behave as an executor -- RequestInfoMessage handler: @handler method to intercept sub-workflow requests -- Request correlation: Using request_id and source_executor_id to match responses with original requests -- Concurrent processing: Multiple emails processed simultaneously without interference -- External request routing: RequestInfoExecutor handles forwarded external requests -- Sub-workflow isolation: Sub-workflows work normally without knowing they're nested -- Sub-workflows complete by yielding outputs when validation is finished - -Prerequisites: -- No external services required (external calls are simulated via `RequestInfoExecutor`). - -Simple flow visualization: - - Parent Orchestrator (handles DomainCheckRequest) - | - | EmailValidationRequest(email) x3 (concurrent) - v - [ Sub-workflow: WorkflowExecutor(EmailValidator) ] - | - | DomainCheckRequest(domain) with request_id and source_executor_id - v - Interception? 
yes -> handled locally with RequestResponse(data=True)
-            no  -> forwarded to RequestInfoExecutor -> external service
-                       |
-                       v
-         Response routed back to sub-workflow using source_executor_id
+This sample demonstrates how to handle requests from the sub-workflow in the main workflow.
+
+Prerequisite:
+- Understanding of sub-workflows.
+- Understanding of requests and responses.
+
+This pattern is useful when you want to reuse a workflow that makes requests to an external system,
+but you want to intercept those requests in the main workflow and handle them without further propagation
+to the external system.
+
+This sample implements a smart email delivery system that validates email addresses before sending emails.
+1. We will start by creating a workflow that validates email addresses in a sequential manner. The validation
+   consists of three steps: sanitization, format validation, and domain validation. The domain validation
+   step will involve checking if the email domain is valid by making a request to an external system.
+2. Then we will create a main workflow that uses the email validation workflow as a sub-workflow. The main
+   workflow will intercept the domain validation requests from the sub-workflow and handle them internally
+   without propagating them to an external system.
+3. Once the email address is validated, the main workflow will proceed to send the email if the address is valid,
+   or block the email if the address is invalid.
 """
 
 
-# 1. Define domain-specific message types
 @dataclass
-class EmailValidationRequest:
-    """Request to validate an email address."""
+class SanitizedEmailResult:
+    """Result of email sanitization and validation.
 
-    email: str
+    The properties get built up as the email address goes through
+    the validation steps in the workflow.
+ """ + original: str + sanitized: str + is_valid: bool -@dataclass -class DomainCheckRequest(RequestInfoMessage): - """Request to check if a domain is approved.""" - domain: str = "" +def build_email_address_validation_workflow() -> Workflow: + """Build a email address validation workflow. + + This workflow consists of three steps (exach is represented by an executor): + 1. Sanitize the email address, such as removing leading/trailing spaces. + 2. Validate the email address format, such as checking for "@" and domain. + 3. Extract the domain from the email address and request domain validation, + after which it completes with the final result. + """ + + class EmailSanitizer(Executor): + """Sanitize email address by trimming spaces.""" + + @handler + async def handle(self, email_address: str, ctx: WorkflowContext[SanitizedEmailResult]) -> None: + """Trim leading and trailing spaces from the email address. + + This executor doesn't produce any workflow output, but sends the sanitized + email address to the next executor in the workflow. + """ + sanitized = email_address.strip() + print(f"✂️ Sanitized email address: '{sanitized}'") + await ctx.send_message(SanitizedEmailResult(original=email_address, sanitized=sanitized, is_valid=False)) + + class EmailFormatValidator(Executor): + """Validate email address format.""" + + @handler + async def handle( + self, + partial_result: SanitizedEmailResult, + ctx: WorkflowContext[SanitizedEmailResult, SanitizedEmailResult], + ) -> None: + """Validate the email address format. + + This executor can potentially produce a workflow output (False if the format is invalid). + When the format is valid, it sends the validated email address to the next executor in the workflow. + """ + if "@" not in partial_result.sanitized or "." 
not in partial_result.sanitized.split("@")[-1]: + print(f"❌ Invalid email format: '{partial_result.sanitized}'") + await ctx.yield_output( + SanitizedEmailResult( + original=partial_result.original, sanitized=partial_result.sanitized, is_valid=False + ) + ) + return + print(f"✅ Validated email format: '{partial_result.sanitized}'") + await ctx.send_message( + SanitizedEmailResult( + original=partial_result.original, sanitized=partial_result.sanitized, is_valid=False + ) + ) + + class DomainValidator(Executor): + """Validate email domain.""" + + def __init__(self, id: str): + super().__init__(id=id) + self._pending_domains: dict[str, SanitizedEmailResult] = {} + + @handler + async def handle(self, partial_result: SanitizedEmailResult, ctx: WorkflowContext) -> None: + """Extract the domain from the email address and request domain validation. + + This executor doesn't produce any workflow output, but sends a domain validation request + to an external system to user for validation. + """ + domain = partial_result.sanitized.split("@")[-1] + print(f"🔍 Validating domain: '{domain}'") + self._pending_domains[domain] = partial_result + # Send a request to the external system via the request_info mechanism + await ctx.request_info(domain, str, bool) + + @response_handler + async def handle_domain_validation_response( + self, original_request: str, is_valid: bool, ctx: WorkflowContext[Never, SanitizedEmailResult] + ) -> None: + """Handle the domain validation response. + + This method receives the response from the external system and yields the final + validation result (True if both format and domain are valid, False otherwise). 
+ """ + if original_request not in self._pending_domains: + raise ValueError(f"Received response for unknown domain: '{original_request}'") + partial_result = self._pending_domains.pop(original_request) + if is_valid: + print(f"✅ Domain '{original_request}' is valid.") + await ctx.yield_output( + SanitizedEmailResult( + original=partial_result.original, sanitized=partial_result.sanitized, is_valid=True + ) + ) + else: + print(f"❌ Domain '{original_request}' is invalid.") + await ctx.yield_output( + SanitizedEmailResult( + original=partial_result.original, sanitized=partial_result.sanitized, is_valid=False + ) + ) + + # Build the workflow + sanitizer = EmailSanitizer(id="email_sanitizer") + format_validator = EmailFormatValidator(id="email_format_validator") + domain_validator = DomainValidator(id="domain_validator") + + return ( + WorkflowBuilder() + .set_start_executor(sanitizer) + .add_edge(sanitizer, format_validator) + .add_edge(format_validator, domain_validator) + .build() + ) @dataclass -class ValidationResult: - """Result of email validation.""" +class Email: + recipient: str + subject: str + body: str - email: str - is_valid: bool - reason: str +class SmartEmailOrchestrator(Executor): + """Orchestrates email address validation using a sub-workflow.""" -# 2. Implement the sub-workflow executor (completely standard) -class EmailValidator(Executor): - """Validates email addresses - doesn't know it's in a sub-workflow.""" + def __init__(self, id: str, approved_domains: set[str]): + """Initialize the orchestrator with a set of approved domains. - def __init__(self) -> None: - """Initialize the EmailValidator executor.""" - super().__init__(id="email_validator") - # Use a dict to track multiple pending emails by request_id - self._pending_emails: dict[str, str] = {} + Args: + id: The executor ID. + approved_domains: A set of domains that are considered valid. 
+ """ + super().__init__(id=id) + self._approved_domains = approved_domains + # Keep track of previously approved and disapproved recipients + self._approved_recipients: set[str] = set() + self._disapproved_recipients: set[str] = set() + # Record pending emails waiting for validation results + self._pending_emails: dict[str, Email] = {} @handler - async def validate_request( - self, - request: EmailValidationRequest, - ctx: WorkflowContext[DomainCheckRequest | ValidationResult, ValidationResult], - ) -> None: - """Validate an email address.""" - print(f"🔍 Sub-workflow validating email: {request.email}") + async def run(self, email: Email, ctx: WorkflowContext[Email | str, bool]) -> None: + """Start the email delivery process. - # Extract domain - domain = request.email.split("@")[1] if "@" in request.email else "" - - if not domain: - print(f"❌ Invalid email format: {request.email}") - result = ValidationResult(email=request.email, is_valid=False, reason="Invalid email format") - await ctx.yield_output(result) + This handler receives an Email object. If the recipient has been previously approved, + it sends the email object to the next executor to handle delivery. If the recipient + has been previously disapproved, it yields False as the final result. Otherwise, + it sends the recipient email address to the sub-workflow for validation. 
+ """ + recipient = email.recipient + if recipient in self._approved_recipients: + print(f"📧 Recipient '{recipient}' has been previously approved.") + await ctx.send_message(email) + return + if recipient in self._disapproved_recipients: + print(f"🚫 Blocking email to previously disapproved recipient: '{recipient}'") + await ctx.yield_output(False) return - print(f"🌐 Sub-workflow requesting domain check for: {domain}") - # Request domain check - domain_check = DomainCheckRequest(domain=domain) - # Store the pending email with the request_id for correlation - self._pending_emails[domain_check.request_id] = request.email - await ctx.send_message(domain_check, target_id="email_request_info") + print(f"🔍 Validating new recipient email address: '{recipient}'") + self._pending_emails[recipient] = email + await ctx.send_message(recipient) @handler - async def handle_domain_response( - self, - response: RequestResponse[DomainCheckRequest, bool], - ctx: WorkflowContext[ValidationResult, ValidationResult], + async def handler_domain_validation_request( + self, request: SubWorkflowRequestMessage, ctx: WorkflowContext[SubWorkflowResponseMessage] ) -> None: - """Handle domain check response from RequestInfo with correlation.""" - approved = bool(response.data) - domain = ( - response.original_request.domain - if (hasattr(response, "original_request") and response.original_request) - else "unknown" - ) - print(f"📬 Sub-workflow received domain response for '{domain}': {approved}") - - # Find the corresponding email using the request_id - request_id = ( - response.original_request.request_id - if (hasattr(response, "original_request") and response.original_request) - else None - ) - if request_id and request_id in self._pending_emails: - email = self._pending_emails.pop(request_id) # Remove from pending - result = ValidationResult( - email=email, - is_valid=approved, - reason="Domain approved" if approved else "Domain not approved", - ) - print(f"✅ Sub-workflow completing 
validation for: {email}") - await ctx.yield_output(result) + """Handle requests from the sub-workflow for domain validation. - -# 3. Implement the parent workflow with request interception -class SmartEmailOrchestrator(Executor): - """Parent orchestrator that can intercept domain checks.""" - - approved_domains: set[str] = set() - - def __init__(self, approved_domains: set[str] | None = None): - """Initialize the SmartEmailOrchestrator with approved domains. - - Args: - approved_domains: Set of pre-approved domains, defaults to example.com, test.org, company.com + Note that the message type must be SubWorkflowRequestMessage to intercept the request. And + the response must be sent back using SubWorkflowResponseMessage to route the response + back to the sub-workflow. """ - super().__init__(id="email_orchestrator", approved_domains=approved_domains) - self._results: list[ValidationResult] = [] - - @handler - async def start_validation(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None: - """Start validating a batch of emails.""" - print(f"📧 Starting validation of {len(emails)} email addresses") - print("=" * 60) - for email in emails: - print(f"📤 Sending '{email}' to sub-workflow for validation") - request = EmailValidationRequest(email=email) - await ctx.send_message(request, target_id="email_validator_workflow") + if not isinstance(request.source_event.data, str): + raise TypeError(f"Expected domain string, got {type(request.source_event.data)}") + domain = request.source_event.data + is_valid = domain in self._approved_domains + print(f"🌐 External domain validation for '{domain}': {'valid' if is_valid else 'invalid'}") + await ctx.send_message(request.create_response(is_valid), target_id=request.executor_id) @handler - async def handle_domain_request( - self, - request: DomainCheckRequest, - ctx: WorkflowContext[RequestResponse[DomainCheckRequest, bool] | DomainCheckRequest], - ) -> None: - """Handle requests from sub-workflows.""" - 
print(f"🔍 Parent intercepting domain check for: {request.domain}") - - if request.domain in self.approved_domains: - print(f"✅ Domain '{request.domain}' is pre-approved locally!") - # Send response back to sub-workflow - response = RequestResponse(data=True, original_request=request, request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) + async def handle_validation_result(self, result: SanitizedEmailResult, ctx: WorkflowContext[Email, bool]) -> None: + """Handle the email address validation result. + + This handler receives the validation result from the sub-workflow. + If the email address is valid, it adds the recipient to the approved list + and sends the email object to the next executor to handle delivery. + If the email address is invalid, it adds the recipient to the disapproved list + and yields False as the final result. + """ + email = self._pending_emails.pop(result.original) + email.recipient = result.sanitized # Use the sanitized email address + if result.is_valid: + print(f"✅ Email address '{result.original}' is valid.") + self._approved_recipients.add(result.original) + await ctx.send_message(email) else: - print(f"❓ Domain '{request.domain}' unknown, forwarding to external service...") - # Forward to external handler - await ctx.send_message(request) + print(f"🚫 Email address '{result.original}' is invalid. Blocking email.") + self._disapproved_recipients.add(result.original) + await ctx.yield_output(False) - @handler - async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) -> None: - """Collect validation results. 
It comes from the sub-workflow yielded output.""" - status_icon = "✅" if result.is_valid else "❌" - print(f"📥 {status_icon} Validation result: {result.email} -> {result.reason}") - self._results.append(result) - @property - def results(self) -> list[ValidationResult]: - """Get the collected validation results.""" - return self._results +class EmailDelivery(Executor): + """Simulates email delivery.""" + @handler + async def handle(self, email: Email, ctx: WorkflowContext[Never, bool]) -> None: + """Simulate sending the email and yield True as the final result.""" + print(f"📤 Sending email to '{email.recipient}' with subject '{email.subject}'") + await asyncio.sleep(1) # Simulate network delay + print(f"✅ Email sent to '{email.recipient}' successfully.") + await ctx.yield_output(True) -async def run_example() -> None: - """Run the sub-workflow example.""" - print("🚀 Setting up sub-workflow with request interception...") - print() - # 4. Build the sub-workflow - email_validator = EmailValidator() - # Match the target_id used in EmailValidator ("email_request_info") - request_info = RequestInfoExecutor(id="email_request_info") +async def main() -> None: + # A list of approved domains + approved_domains = {"example.com", "company.com"} - validation_workflow = ( - WorkflowBuilder() - .set_start_executor(email_validator) - .add_edge(email_validator, request_info) - .add_edge(request_info, email_validator) - .build() - ) + # Create executors in the main workflow + orchestrator = SmartEmailOrchestrator(id="smart_email_orchestrator", approved_domains=approved_domains) + email_delivery = EmailDelivery(id="email_delivery") - # 5. 
Build the parent workflow with interception - orchestrator = SmartEmailOrchestrator(approved_domains={"example.com", "company.com"}) - workflow_executor = WorkflowExecutor(validation_workflow, id="email_validator_workflow") - # Add a RequestInfoExecutor to handle forwarded external requests - main_request_info = RequestInfoExecutor(id="main_request_info") + # Create the sub-workflow for email address validation + validation_workflow = build_email_address_validation_workflow() + validation_workflow_executor = WorkflowExecutor(validation_workflow, id="email_validation_workflow") - main_workflow = ( + # Build the main workflow + workflow = ( WorkflowBuilder() .set_start_executor(orchestrator) - .add_edge(orchestrator, workflow_executor) - .add_edge(workflow_executor, orchestrator) # For ValidationResult collection and request interception - # Add edges for external request handling - .add_edge(orchestrator, main_request_info) - .add_edge(main_request_info, workflow_executor) # Route external responses to sub-workflow + .add_edge(orchestrator, validation_workflow_executor) + .add_edge(validation_workflow_executor, orchestrator) + .add_edge(orchestrator, email_delivery) .build() ) - # 6. 
Prepare test inputs: known domain, unknown domain test_emails = [ - "user@example.com", # Should be intercepted and approved - "admin@company.com", # Should be intercepted and approved - "guest@unknown.org", # Should be forwarded externally + Email(recipient="user1@example.com", subject="Hello User1", body="This is a test email."), + Email(recipient=" user2@invalid", subject="Hello User2", body="This is a test email."), + Email(recipient=" user3@company.com ", subject="Hello User3", body="This is a test email."), + Email(recipient="user4@unknown.com", subject="Hello User4", body="This is a test email."), + # Re-send to an approved recipient + Email(recipient="user1@example.com", subject="Hello User1", body="This is a test email."), + # Re-send to a disapproved recipient + Email(recipient=" user2@invalid", subject="Hello User2", body="This is a test email."), ] - # 7. Run the workflow - result = await main_workflow.run(test_emails) - - # 8. Handle any external requests - request_events = result.get_request_info_events() - if request_events: - print(f"\n🌐 Handling {len(request_events)} external request(s)...") - for event in request_events: - if event.data and hasattr(event.data, "domain"): - print(f"🔍 External domain check needed for: {event.data.domain}") - - # Simulate external responses - external_responses: dict[str, bool] = {} - for event in request_events: - # Simulate external domain checking - if event.data and hasattr(event.data, "domain"): - domain = event.data.domain - # Let's say unknown.org is actually approved externally - approved = domain == "unknown.org" - print(f"🌐 External service response for '{domain}': {'APPROVED' if approved else 'REJECTED'}") - external_responses[event.request_id] = approved - - # 9. Send external responses - await main_workflow.send_responses(external_responses) - else: - print("\n🎯 All requests were intercepted and handled locally!") - - # 10. 
Display final summary - print("\n📊 Final Results Summary:") - print("=" * 60) - for result in orchestrator.results: - status = "✅ VALID" if result.is_valid else "❌ INVALID" - print(f"{status} {result.email}: {result.reason}") - - print(f"\n🏁 Processed {len(orchestrator.results)} emails total") + # Execute the workflow + for email in test_emails: + print(f"\n🚀 Processing email to '{email.recipient}'") + async for event in workflow.run_stream(email): + if isinstance(event, WorkflowOutputEvent): + print(f"🎉 Final result for '{email.recipient}': {'Delivered' if event.data else 'Blocked'}") if __name__ == "__main__": - asyncio.run(run_example()) + asyncio.run(main()) From a4b928d8df620a034be55b6ab5cd13018e79b2ee Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Thu, 16 Oct 2025 13:30:59 -0700 Subject: [PATCH 04/26] Prototype: request interception in sub workflows 2 --- .../agent_framework/_workflows/_executor.py | 2 +- .../_workflows/_request_info_mixin.py | 5 + .../_workflows/_workflow_executor.py | 24 +- .../sub_workflow_parallel_requests.py | 630 ++++++++---------- 4 files changed, 284 insertions(+), 377 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 7e07956a35..d9f9f65fa0 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -418,7 +418,7 @@ def output_types(self) -> list[type[Any]]: output_types: set[type[Any]] = set() # Collect output types from all handlers - for handler_spec in self._handler_specs: + for handler_spec in self._handler_specs + self._response_handler_specs: handler_output_types = handler_spec.get("output_types", []) output_types.update(handler_output_types) diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 56146222e7..cbd86de0de 100644 --- 
a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -56,6 +56,11 @@ def _discover_response_handlers(self) -> None: except AttributeError: continue # Skip non-callable attributes or those without handler spec + # A request sent via `request_info` must be handled by a response handler inside the same executor. + # It is safe to assume that an executor is request-response capable if it has at least one response + # handler, and that the executor could send a request. + self.is_request_response_capable = bool(self._response_handlers) + ExecutorT = TypeVar("ExecutorT", bound="Executor") ContextT = TypeVar("ContextT", bound="WorkflowContext[Any, Any]") diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 12b800d3f6..48f4a25508 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. import contextlib -import inspect import logging import uuid from collections.abc import Mapping @@ -50,7 +49,8 @@ class SubWorkflowResponseMessage: This message wraps a RequestResponse emitted by the parent workflow. Attributes: - response: The response to the original request. + data: The response data to the original request. + source_event: The original RequestInfoEvent emitted by the sub-workflow executor. 
""" data: Any @@ -297,23 +297,12 @@ def output_types(self) -> list[type[Any]]: """ output_types = list(self.workflow.output_types) - # Check if the sub-workflow contains a RequestInfoExecutor - # If so, collect the specific RequestInfoMessage subtypes from all executors - has_request_info_executor = any( - isinstance(executor, RequestInfoExecutor) for executor in self.workflow.executors.values() + is_request_response_capable = any( + executor.is_request_response_capable for executor in self.workflow.executors.values() ) - if has_request_info_executor: - # Collect all RequestInfoMessage subtypes from executor output types - for executor in self.workflow.executors.values(): - for output_type in executor.output_types: - # Check if this is a RequestInfoMessage subclass - if ( - inspect.isclass(output_type) - and issubclass(output_type, RequestInfoMessage) - and output_type not in output_types - ): - output_types.append(output_type) + if is_request_response_capable: + output_types.append(SubWorkflowRequestMessage) return output_types @@ -412,6 +401,7 @@ async def _process_workflow_result( # Process outputs for output in outputs: + # TODO(@taochen): Allow the sub-workflow to output directly await ctx.send_message(output) # Process request info events diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py index e3c3652df0..fc58f06315 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py @@ -1,87 +1,58 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio +import uuid from dataclasses import dataclass -from typing import Any +from typing import Literal, Never from agent_framework import ( Executor, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, + RequestInfoEvent, + SubWorkflowRequestMessage, + SubWorkflowResponseMessage, + Workflow, WorkflowBuilder, WorkflowContext, WorkflowExecutor, handler, + response_handler, ) -from typing_extensions import Never """ -Sample: Sub-workflow with parallel request handling by specialized interceptors - -This sample demonstrates how different parent executors can handle different types of requests -from the same sub-workflow using regular @handler methods for RequestInfoMessage subclasses. - -Prerequisites: -- No external services required (external handling simulated via `RequestInfoExecutor`). - -Key architectural principles: -1. Specialized interceptors: Each parent executor handles only specific request types -2. Type-based routing: ResourceCache handles ResourceRequest, PolicyEngine handles PolicyCheckRequest -3. Automatic type filtering: Each interceptor only receives requests with matching types -4. 
Fallback forwarding: Unhandled requests are forwarded to external services - -The example simulates a resource allocation system where: -- Sub-workflow makes mixed requests for resources (CPU, memory) and policy checks -- ResourceCache executor intercepts ResourceRequest messages, serves from cache or forwards -- PolicyEngine executor intercepts PolicyCheckRequest messages, applies rules or forwards -- Each interceptor uses typed @handler methods for automatic filtering - -Flow visualization: - - Coordinator - | - | Mixed list[resource + policy requests] - v - [ Sub-workflow: WorkflowExecutor(ResourceRequester) ] - | - | Emits different RequestInfoMessage types: - | - ResourceRequest - | - PolicyCheckRequest - v - Parent workflow routes to specialized handlers: - | | - | ResourceCache.handle_resource_request | PolicyEngine.handle_policy_request - | (@handler ResourceRequest) | (@handler PolicyCheckRequest) - v v - Cache hit/miss decision Policy allow/deny decision - | | - | RequestResponse OR forward | RequestResponse OR forward - v v - Back to sub-workflow <----------> External RequestInfoExecutor - | - v - External responses route back -""" +This sample demonstrates how to handle multiple parallel requests from a sub-workflow to +different executors in the main workflow. +Prerequisite: +- Understanding of sub-workflows. +- Understanding of requests and responses. -# 1. Define domain-specific request/response types -@dataclass -class ResourceRequest(RequestInfoMessage): - """Request for computing resources.""" +This pattern is useful when a sub-workflow needs to interact with multiple external systems +or services. - resource_type: str = "cpu" # cpu, memory, disk, etc. - amount: int = 1 - priority: str = "normal" # low, normal, high +This sample implements a resource request distribution system where: +1. A sub-workflow generates requests for computing resources and policy checks. +2. 
The main workflow has executors that handle resource allocation and policy checking. +3. Responses are routed back to the sub-workflow, which collects and processes them. + +The sub-workflow sends two types of requests: +- ResourceRequest: Requests for computing resources (e.g., CPU, memory). +- PolicyRequest: Requests to check resource allocation policies. + +The main workflow contains: +- ResourceAllocator: Simulates a system that allocates computing resources. +- PolicyEngine: Simulates a policy engine that approves or denies resource requests. +""" @dataclass -class PolicyCheckRequest(RequestInfoMessage): - """Request to check resource allocation policy.""" +class ComputingResourceRequest: + """Request for computing resources.""" - resource_type: str = "" - amount: int = 0 - policy_type: str = "quota" # quota, compliance, security + request_type: Literal["resource", "policy"] + resource_type: Literal["cpu", "memory", "disk", "gpu"] + amount: int + priority: Literal["low", "normal", "high"] | None = None + policy_type: Literal["quota", "security"] | None = None @dataclass @@ -102,340 +73,281 @@ class PolicyResponse: @dataclass -class RequestFinished: - pass +class ResourceRequest: + """Request for computing resources.""" + resource_type: Literal["cpu", "memory", "disk", "gpu"] + amount: int + priority: Literal["low", "normal", "high"] + id: str = str(uuid.uuid4()) -# 2. 
Implement the sub-workflow executor - makes resource and policy requests -class ResourceRequester(Executor): - """Simple executor that requests resources and checks policies.""" - def __init__(self): - super().__init__(id="resource_requester") - self._request_count = 0 +@dataclass +class PolicyRequest: + """Request to check resource allocation policy.""" - @handler - async def request_resources( - self, - requests: list[dict[str, Any]], - ctx: WorkflowContext[ResourceRequest | PolicyCheckRequest], - ) -> None: - """Process a list of resource requests.""" - print(f"🏭 Sub-workflow processing {len(requests)} requests") - self._request_count += len(requests) - - for req_data in requests: - req_type = req_data.get("request_type", "resource") - - request: ResourceRequest | PolicyCheckRequest - if req_type == "resource": - print(f" 📦 Requesting resource: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)}") - request = ResourceRequest( - resource_type=req_data.get("type", "cpu"), - amount=req_data.get("amount", 1), - priority=req_data.get("priority", "normal"), - ) - # Send to parent workflow for interception - not to target_id - await ctx.send_message(request) - elif req_type == "policy": - print( - f" 🛡️ Checking policy: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)} " - f"({req_data.get('policy_type', 'quota')})" - ) - request = PolicyCheckRequest( - resource_type=req_data.get("type", "cpu"), - amount=req_data.get("amount", 1), - policy_type=req_data.get("policy_type", "quota"), - ) - # Send to parent workflow for interception - not to target_id - await ctx.send_message(request) + policy_type: Literal["quota", "security"] + resource_type: Literal["cpu", "memory", "disk", "gpu"] + amount: int + id: str = str(uuid.uuid4()) + + +def build_resource_request_distribution_workflow() -> Workflow: + class RequestDistribution(Executor): + """Distributes computing resource requests to appropriate executors.""" + + @handler + async def distribute_requests( + 
self, + requests: list[ComputingResourceRequest], + ctx: WorkflowContext[ResourceRequest | PolicyRequest | int], + ) -> None: + for req in requests: + if req.request_type == "resource": + if req.priority is None: + raise ValueError("Priority must be set for resource requests") + await ctx.send_message(ResourceRequest(req.resource_type, req.amount, req.priority)) + elif req.request_type == "policy": + if req.policy_type is None: + raise ValueError("Policy type must be set for policy requests") + await ctx.send_message(PolicyRequest(req.policy_type, req.resource_type, req.amount)) + else: + raise ValueError(f"Unknown request type: {req.request_type}") + # Notify the collector about the number of requests sent + await ctx.send_message(len(requests)) + + class ResourceRequester(Executor): + """Handles resource allocation requests.""" + + @handler + async def run(self, request: ResourceRequest, ctx: WorkflowContext) -> None: + await ctx.request_info(request, ResourceRequest, ResourceResponse) + + @response_handler + async def handle_response( + self, original_request: ResourceRequest, response: ResourceResponse, ctx: WorkflowContext[ResourceResponse] + ) -> None: + print(f"Resource allocated: {response.allocated} {response.resource_type} from {response.source}") + await ctx.send_message(response) + + class PolicyChecker(Executor): + """Handles policy check requests.""" + + @handler + async def run(self, request: PolicyRequest, ctx: WorkflowContext) -> None: + await ctx.request_info(request, PolicyRequest, PolicyResponse) + + @response_handler + async def handle_response( + self, original_request: PolicyRequest, response: PolicyResponse, ctx: WorkflowContext[PolicyResponse] + ) -> None: + print(f"Policy check result: {response.approved} - {response.reason}") + await ctx.send_message(response) + + class ResultCollector(Executor): + """Collects and processes all responses.""" + + def __init__(self, id: str) -> None: + super().__init__(id) + self._request_count = 0 + 
self._responses: list[ResourceResponse | PolicyResponse] = [] + + @handler + async def set_request_count(self, count: int, ctx: WorkflowContext) -> None: + if count <= 0: + raise ValueError("Request count must be positive") + self._request_count = count + + @handler + async def collect(self, response: ResourceResponse | PolicyResponse, ctx: WorkflowContext[Never, str]) -> None: + self._responses.append(response) + print(f"Collected {len(self._responses)}/{self._request_count} responses") + if len(self._responses) == self._request_count: + # All responses received, process them + await ctx.yield_output(f"All {self._request_count} requests processed.") + elif len(self._responses) > self._request_count: + raise ValueError("Received more responses than expected") + + orchestrator = RequestDistribution("orchestrator") + resource_requester = ResourceRequester("resource_requester") + policy_checker = PolicyChecker("policy_checker") + result_collector = ResultCollector("result_collector") + + return ( + WorkflowBuilder() + .set_start_executor(orchestrator) + .add_edge(orchestrator, resource_requester) + .add_edge(orchestrator, policy_checker) + .add_edge(resource_requester, result_collector) + .add_edge(policy_checker, result_collector) + .add_edge(orchestrator, result_collector) # For request count + .build() + ) - @handler - async def handle_resource_response( - self, - response: RequestResponse[ResourceRequest, ResourceResponse], - ctx: WorkflowContext[Never, RequestFinished], - ) -> None: - """Handle resource allocation response.""" - if response.data: - source_icon = "🏪" if response.data.source == "cache" else "🌐" - print( - f"📦 {source_icon} Sub-workflow received: {response.data.allocated} {response.data.resource_type} " - f"from {response.data.source}" - ) - if self._collect_results(): - # Yield completion result to the parent workflow. 
- await ctx.yield_output(RequestFinished()) - @handler - async def handle_policy_response( - self, - response: RequestResponse[PolicyCheckRequest, PolicyResponse], - ctx: WorkflowContext[Never, RequestFinished], - ) -> None: - """Handle policy check response.""" - if response.data: - status_icon = "✅" if response.data.approved else "❌" - print( - f"🛡️ {status_icon} Sub-workflow received policy response: " - f"{response.data.approved} - {response.data.reason}" - ) - if self._collect_results(): - # Yield completion result to the parent workflow. - await ctx.yield_output(RequestFinished()) - - def _collect_results(self) -> bool: - """Collect and summarize results.""" - self._request_count -= 1 - print(f"📊 Sub-workflow completed request ({self._request_count} remaining)") - return self._request_count == 0 - - -# 3. Implement the Resource Cache - Uses typed handler for ResourceRequest -class ResourceCache(Executor): - """Interceptor that handles RESOURCE requests from cache using typed routing.""" - - # Use class attributes to avoid Pydantic assignment restrictions - cache: dict[str, int] = {"cpu": 10, "memory": 50, "disk": 100} - results: list[ResourceResponse] = [] - - def __init__(self): - super().__init__(id="resource_cache") - # Instance initialization only; state kept in class attributes as above +class ResourceAllocator(Executor): + """Simulates a system that allocates computing resources.""" + + def __init__(self, id: str) -> None: + super().__init__(id) + self._cache: dict[str, int] = {"cpu": 10, "memory": 50, "disk": 100} + # Record pending requests to match responses + self._pending_requests: dict[str, RequestInfoEvent] = {} + + async def _handle_resource_request(self, request: ResourceRequest) -> ResourceResponse | None: + """Allocates resources based on request and available cache.""" + available = self._cache.get(request.resource_type, 0) + if available >= request.amount: + self._cache[request.resource_type] -= request.amount + return 
ResourceResponse(request.resource_type, request.amount, "cache") + return None @handler - async def handle_resource_request( - self, request: ResourceRequest, ctx: WorkflowContext[RequestResponse[ResourceRequest, Any] | ResourceRequest] + async def handle_subworkflow_request( + self, request: SubWorkflowRequestMessage, ctx: WorkflowContext[SubWorkflowResponseMessage] ) -> None: - """Handle RESOURCE requests from sub-workflows and check cache first.""" - resource_request = request - print(f"🏪 CACHE interceptor checking: {resource_request.amount} {resource_request.resource_type}") - - available = self.cache.get(resource_request.resource_type, 0) - - if available >= resource_request.amount: - # We can satisfy from cache - self.cache[resource_request.resource_type] -= resource_request.amount - response_data = ResourceResponse( - resource_type=resource_request.resource_type, allocated=resource_request.amount, source="cache" - ) - print(f" ✅ Cache satisfied: {resource_request.amount} {resource_request.resource_type}") - self.results.append(response_data) - - # Send response back to sub-workflow - response = RequestResponse(data=response_data, original_request=request, request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) + """Handles requests from sub-workflows.""" + source_event: RequestInfoEvent = request.source_event + if not isinstance(source_event.data, ResourceRequest): + return + + request_payload: ResourceRequest = source_event.data + response = await self._handle_resource_request(request_payload) + if response: + await ctx.send_message(request.create_response(response)) else: - # Cache miss - forward to external - print(f" ❌ Cache miss: need {resource_request.amount}, have {available} {resource_request.resource_type}") - await ctx.send_message(request) + # Request cannot be fulfilled via cache, forward the request to external + self._pending_requests[request_payload.id] = source_event + await 
ctx.request_info(request_payload, ResourceRequest, ResourceResponse) - @handler - async def collect_result( - self, response: RequestResponse[ResourceRequest, ResourceResponse], ctx: WorkflowContext + @response_handler + async def handle_external_response( + self, + original_request: ResourceRequest, + response: ResourceResponse, + ctx: WorkflowContext[SubWorkflowResponseMessage], ) -> None: - """Collect results from external requests that were forwarded.""" - if response.data and response.data.source != "cache": # Don't double-count our own results - self.results.append(response.data) - print( - f"🏪 🌐 Cache received external response: {response.data.allocated} {response.data.resource_type} " - f"from {response.data.source}" - ) + """Handles responses from external systems and routes them to the sub-workflow.""" + print(f"External resource allocated: {response.allocated} {response.resource_type} from {response.source}") + source_event = self._pending_requests.pop(original_request.id, None) + if source_event is None: + raise ValueError("No matching pending request found for the resource response") + await ctx.send_message(SubWorkflowResponseMessage(data=response, source_event=source_event)) -# 4. 
Implement the Policy Engine - Uses typed handler for PolicyCheckRequest class PolicyEngine(Executor): - """Interceptor that handles POLICY requests using typed routing.""" - - # Use class attributes for simple sample state - quota: dict[str, int] = { - "cpu": 5, # Only allow up to 5 CPU units - "memory": 20, # Only allow up to 20 memory units - "disk": 1000, # Liberal disk policy - } - results: list[PolicyResponse] = [] - - def __init__(self): - super().__init__(id="policy_engine") - # Instance initialization only; state kept in class attributes as above + """Simulates a policy engine that approves or denies resource requests.""" + + def __init__(self, id: str) -> None: + super().__init__(id) + self._quota: dict[str, int] = { + "cpu": 5, # Only allow up to 5 CPU units + "memory": 20, # Only allow up to 20 memory units + "disk": 1000, # Liberal disk policy + } + # Record pending requests to match responses + self._pending_requests: dict[str, RequestInfoEvent] = {} @handler - async def handle_policy_request( - self, - request: PolicyCheckRequest, - ctx: WorkflowContext[RequestResponse[PolicyCheckRequest, Any] | PolicyCheckRequest], + async def handle_subworkflow_request( + self, request: SubWorkflowRequestMessage, ctx: WorkflowContext[SubWorkflowResponseMessage] ) -> None: - """Handle POLICY requests from sub-workflows and apply rules.""" - policy_request = request - print( - f"🛡️ POLICY interceptor checking: {policy_request.amount} {policy_request.resource_type}, policy={policy_request.policy_type}" - ) - - quota_limit = self.quota.get(policy_request.resource_type, 0) - - if policy_request.policy_type == "quota": - if policy_request.amount <= quota_limit: - response_data = PolicyResponse(approved=True, reason=f"Within quota ({quota_limit})") - print(f" ✅ Policy approved: {policy_request.amount} <= {quota_limit}") - self.results.append(response_data) - - # Send response back to sub-workflow - response = RequestResponse(data=response_data, original_request=request, 
request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) - return - - # Exceeds quota - forward to external for review - print(f" ❌ Policy exceeds quota: {policy_request.amount} > {quota_limit}, forwarding to external") - await ctx.send_message(request) + """Handles requests from sub-workflows.""" + source_event: RequestInfoEvent = request.source_event + if not isinstance(source_event.data, PolicyRequest): return - # Unknown policy type - forward to external - print(f" ❓ Unknown policy type: {policy_request.policy_type}, forwarding") - await ctx.send_message(request) + request_payload: PolicyRequest = source_event.data + # Simple policy logic for demonstration + if request_payload.policy_type == "quota": + allowed_amount = self._quota.get(request_payload.resource_type, 0) + if request_payload.amount <= allowed_amount: + response = PolicyResponse(True, "Within quota limits") + else: + response = PolicyResponse(False, "Exceeds quota limits") + await ctx.send_message(request.create_response(response)) + else: + # For other policy types, forward to external system + self._pending_requests[request_payload.id] = source_event + await ctx.request_info(request_payload, PolicyRequest, PolicyResponse) - @handler - async def collect_policy_result( - self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext + @response_handler + async def handle_external_response( + self, + original_request: PolicyRequest, + response: PolicyResponse, + ctx: WorkflowContext[SubWorkflowResponseMessage], ) -> None: - """Collect policy results from external requests that were forwarded.""" - if response.data: - self.results.append(response.data) - print(f"🛡️ 🌐 Policy received external response: {response.data.approved} - {response.data.reason}") - - -class Coordinator(Executor): - def __init__(self): - super().__init__(id="coordinator") - - @handler - async def start(self, requests: list[dict[str, Any]], ctx: 
WorkflowContext[list[dict[str, Any]]]) -> None: - """Start the resource allocation process.""" - await ctx.send_message(requests, target_id="resource_workflow") - - @handler - async def handle_completion(self, completion: RequestFinished, ctx: WorkflowContext) -> None: - """Handle sub-workflow completion. - - It comes from the sub-workflow yielded output. - """ - print("🎯 Main workflow received completion.") + """Handles responses from external systems and routes them to the sub-workflow.""" + print(f"External policy check result: {response.approved} - {response.reason}") + source_event = self._pending_requests.pop(original_request.id, None) + if source_event is None: + raise ValueError("No matching pending request found for the policy response") + await ctx.send_message(SubWorkflowResponseMessage(data=response, source_event=source_event)) async def main() -> None: - """Demonstrate parallel request interception patterns.""" - print("🚀 Starting Sub-Workflow Parallel Request Interception Demo...") - print("=" * 60) + # Create executors in the main workflow + sub_workflow = build_resource_request_distribution_workflow() + resource_allocator = ResourceAllocator("resource_allocator") + policy_engine = PolicyEngine("policy_engine") - # 5. Create the sub-workflow - resource_requester = ResourceRequester() - sub_request_info = RequestInfoExecutor(id="sub_request_info") - - sub_workflow = ( - WorkflowBuilder() - .set_start_executor(resource_requester) - .add_edge(resource_requester, sub_request_info) - .add_edge(sub_request_info, resource_requester) - .build() - ) - - # 6. Create parent workflow with PROPER interceptor pattern - cache = ResourceCache() # Intercepts ResourceRequest - policy = PolicyEngine() # Intercepts PolicyCheckRequest (different type!) 
- workflow_executor = WorkflowExecutor(sub_workflow, id="resource_workflow") - main_request_info = RequestInfoExecutor(id="main_request_info") - - # Create a simple coordinator that starts the process - coordinator = Coordinator() - - # TYPED ROUTING: Each executor handles specific typed RequestInfoMessage messages + # Build the main workflow + sub_workflow_executor = WorkflowExecutor(sub_workflow, "sub_workflow_executor") main_workflow = ( WorkflowBuilder() - .set_start_executor(coordinator) - .add_edge(coordinator, workflow_executor) # Start sub-workflow - .add_edge(workflow_executor, coordinator) # Sub-workflow completion back to coordinator - .add_edge(workflow_executor, cache) # WorkflowExecutor sends ResourceRequest to cache - .add_edge(workflow_executor, policy) # WorkflowExecutor sends PolicyCheckRequest to policy - .add_edge(cache, workflow_executor) # Cache sends RequestResponse back - .add_edge(policy, workflow_executor) # Policy sends RequestResponse back - .add_edge(cache, main_request_info) # Cache forwards ResourceRequest to external - .add_edge(policy, main_request_info) # Policy forwards PolicyCheckRequest to external - .add_edge(main_request_info, workflow_executor) # External responses back to sub-workflow + .set_start_executor(sub_workflow_executor) + .add_edge(sub_workflow_executor, resource_allocator) + .add_edge(resource_allocator, sub_workflow_executor) + .add_edge(sub_workflow_executor, policy_engine) + .add_edge(policy_engine, sub_workflow_executor) .build() ) - # 7. 
Test with various requests (mixed resource and policy) + # Test requests test_requests = [ - {"request_type": "resource", "type": "cpu", "amount": 2, "priority": "normal"}, # Cache hit - {"request_type": "policy", "type": "cpu", "amount": 3, "policy_type": "quota"}, # Policy hit - {"request_type": "resource", "type": "memory", "amount": 15, "priority": "normal"}, # Cache hit - {"request_type": "policy", "type": "memory", "amount": 100, "policy_type": "quota"}, # Policy miss -> external - {"request_type": "resource", "type": "gpu", "amount": 1, "priority": "high"}, # Cache miss -> external - {"request_type": "policy", "type": "disk", "amount": 500, "policy_type": "quota"}, # Policy hit - {"request_type": "policy", "type": "cpu", "amount": 1, "policy_type": "security"}, # Unknown policy -> external + ComputingResourceRequest("resource", "cpu", 2, priority="normal"), # cache hit + ComputingResourceRequest("policy", "cpu", 3, policy_type="quota"), # policy hit + ComputingResourceRequest("resource", "memory", 15, priority="normal"), # cache hit + ComputingResourceRequest("policy", "memory", 100, policy_type="quota"), # policy miss -> external + ComputingResourceRequest("resource", "gpu", 1, priority="high"), # cache miss -> external + ComputingResourceRequest("policy", "disk", 500, policy_type="quota"), # policy hit + ComputingResourceRequest("policy", "cpu", 1, policy_type="security"), # unknown policy -> external ] - print(f"🧪 Testing with {len(test_requests)} mixed requests:") - for i, req in enumerate(test_requests, 1): - req_icon = "📦" if req["request_type"] == "resource" else "🛡️" - print( - f" {i}. {req_icon} {req['type']} x{req['amount']} " - f"({req.get('priority', req.get('policy_type', 'default'))})" - ) - print("=" * 70) - - # 8. Run the workflow - print("🎬 Running workflow...") - events = await main_workflow.run(test_requests) - - # 9. 
Handle any external requests that couldn't be intercepted - request_events = events.get_request_info_events() - if request_events: - print(f"\n🌐 Handling {len(request_events)} external request(s)...") - - external_responses: dict[str, Any] = {} - for event in request_events: + # Run the workflow + print(f"🧪 Testing with {len(test_requests)} mixed requests.") + print("🚀 Starting main workflow...") + run_result = await main_workflow.run(test_requests) + + # Handle request info events + request_info_events = run_result.get_request_info_events() + if request_info_events: + print(f"\n🔍 Handling {len(request_info_events)} request info events...\n") + + responses: dict[str, ResourceResponse | PolicyResponse] = {} + for event in request_info_events: if isinstance(event.data, ResourceRequest): - # Handle ResourceRequest - create ResourceResponse + # Simulate external resource allocation resource_response = ResourceResponse( resource_type=event.data.resource_type, allocated=event.data.amount, source="external_provider" ) - external_responses[event.request_id] = resource_response - print(f" 🏭 External provider: {resource_response.allocated} {resource_response.resource_type}") - elif isinstance(event.data, PolicyCheckRequest): - # Handle PolicyCheckRequest - create PolicyResponse - policy_response = PolicyResponse(approved=True, reason="External policy service approved") - external_responses[event.request_id] = policy_response - print(f" 🔒 External policy: {'✅ APPROVED' if policy_response.approved else '❌ DENIED'}") - - await main_workflow.send_responses(external_responses) - else: - print("\n🎯 All requests were intercepted internally!") - - # 10. 
Show results and analysis - print("\n" + "=" * 70) - print("📊 RESULTS ANALYSIS") - print("=" * 70) - - print(f"\n🏪 Cache Results ({len(cache.results)} handled):") - for result in cache.results: - print(f" ✅ {result.allocated} {result.resource_type} from {result.source}") - - print(f"\n🛡️ Policy Results ({len(policy.results)} handled):") - for result in policy.results: - status_icon = "✅" if result.approved else "❌" - print(f" {status_icon} Approved: {result.approved} - {result.reason}") - - print("\n💾 Final Cache State:") - for resource, amount in cache.cache.items(): - print(f" 📦 {resource}: {amount} remaining") - - print("\n📈 Summary:") - print(f" 🎯 Total requests: {len(test_requests)}") - print(f" 🏪 Resource requests handled: {len(cache.results)}") - print(f" 🛡️ Policy requests handled: {len(policy.results)}") - print(f" 🌐 External requests: {len(request_events) if request_events else 0}") - - print("\n" + "=" * 70) + responses[event.request_id] = resource_response + elif isinstance(event.data, PolicyRequest): + # Simulate external policy check + response = PolicyResponse(True, "External system approved") + responses[event.request_id] = response + else: + print(f"Unknown request info event data type: {type(event.data)}") + + run_result = await main_workflow.send_responses(responses) + + outputs = run_result.get_outputs() + if outputs: + print("\nWorkflow completed with outputs:") + for output in outputs: + # TODO(@taochen): Allow the sub-workflow to output directly + print(f"- {output}") if __name__ == "__main__": From 598830c655310e02c7199453385e4936818b3b13 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 17 Oct 2025 15:56:05 -0700 Subject: [PATCH 05/26] WIP: Make checkpointing work --- .../agent_framework/_workflows/_checkpoint.py | 1 + .../_workflows/_checkpoint_encoding.py | 250 +++++++++++++ .../_workflows/_checkpoint_summary.py | 6 +- .../agent_framework/_workflows/_executor.py | 18 +- .../agent_framework/_workflows/_runner.py | 13 +- 
.../_workflows/_runner_context.py | 353 ++++-------------- .../agent_framework/_workflows/_workflow.py | 22 +- .../checkpoint_with_human_in_the_loop.py | 302 +++++---------- 8 files changed, 436 insertions(+), 529 deletions(-) create mode 100644 python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint.py b/python/packages/core/agent_framework/_workflows/_checkpoint.py index d47287237f..cc21b32f38 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint.py @@ -28,6 +28,7 @@ class WorkflowCheckpoint: messages: dict[str, list[dict[str, Any]]] = field(default_factory=dict) # type: ignore[misc] shared_state: dict[str, Any] = field(default_factory=dict) # type: ignore[misc] executor_states: dict[str, dict[str, Any]] = field(default_factory=dict) # type: ignore[misc] + pending_request_info_events: dict[str, dict[str, Any]] = field(default_factory=dict) # type: ignore[misc] # Runtime state iteration_count: int = 0 diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py b/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py new file mode 100644 index 0000000000..191674fa23 --- /dev/null +++ b/python/packages/core/agent_framework/_workflows/_checkpoint_encoding.py @@ -0,0 +1,250 @@ +# Copyright (c) Microsoft. All rights reserved. + +import contextlib +import importlib +import logging +import sys +from dataclasses import fields, is_dataclass +from typing import Any, cast + +logger = logging.getLogger(__name__) + +# Checkpoint serialization helpers +MODEL_MARKER = "__af_model__" +DATACLASS_MARKER = "__af_dataclass__" + +# Guards to prevent runaway recursion while encoding arbitrary user data +_MAX_ENCODE_DEPTH = 100 +_CYCLE_SENTINEL = "" + + +def encode_checkpoint_value(value: Any) -> Any: + """Recursively encode values into JSON-serializable structures. 
+ + - Objects exposing to_dict/to_json -> { MODEL_MARKER: "module:Class", value: encoded } + - dataclass instances -> { DATACLASS_MARKER: "module:Class", value: {field: encoded} } + - dict -> encode keys as str and values recursively + - list/tuple/set -> list of encoded items + - other -> returned as-is if already JSON-serializable + + Includes cycle and depth protection to avoid infinite recursion. + """ + + def _enc(v: Any, stack: set[int], depth: int) -> Any: + # Depth guard + if depth > _MAX_ENCODE_DEPTH: + logger.debug(f"Max encode depth reached at depth={depth} for type={type(v)}") + return "" + + # Structured model handling (objects exposing to_dict/to_json) + if _supports_model_protocol(v): + cls = cast(type[Any], type(v)) # type: ignore + try: + if hasattr(v, "to_dict") and callable(getattr(v, "to_dict", None)): + raw = v.to_dict() # type: ignore[attr-defined] + strategy = "to_dict" + elif hasattr(v, "to_json") and callable(getattr(v, "to_json", None)): + serialized = v.to_json() # type: ignore[attr-defined] + if isinstance(serialized, (bytes, bytearray)): + try: + serialized = serialized.decode() + except Exception: + serialized = serialized.decode(errors="replace") + raw = serialized + strategy = "to_json" + else: + raise AttributeError("Structured model lacks serialization hooks") + return { + MODEL_MARKER: f"{cls.__module__}:{cls.__name__}", + "strategy": strategy, + "value": _enc(raw, stack, depth + 1), + } + except Exception as exc: # best-effort fallback + logger.debug(f"Structured model serialization failed for {cls}: {exc}") + return str(v) + + # Dataclasses (instances only) + if is_dataclass(v) and not isinstance(v, type): + oid = id(v) + if oid in stack: + logger.debug("Cycle detected while encoding dataclass instance") + return _CYCLE_SENTINEL + stack.add(oid) + try: + # type(v) already narrows sufficiently; cast was redundant + dc_cls: type[Any] = type(v) + field_values: dict[str, Any] = {} + for f in fields(v): + field_values[f.name] = 
_enc(getattr(v, f.name), stack, depth + 1) + return { + DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}", + "value": field_values, + } + finally: + stack.remove(oid) + + # Collections + if isinstance(v, dict): + v_dict = cast("dict[object, object]", v) + oid = id(v_dict) + if oid in stack: + logger.debug("Cycle detected while encoding dict") + return _CYCLE_SENTINEL + stack.add(oid) + try: + json_dict: dict[str, Any] = {} + for k_any, val_any in v_dict.items(): # type: ignore[assignment] + k_str: str = str(k_any) + json_dict[k_str] = _enc(val_any, stack, depth + 1) + return json_dict + finally: + stack.remove(oid) + + if isinstance(v, (list, tuple, set)): + iterable_v = cast("list[object] | tuple[object, ...] | set[object]", v) + oid = id(iterable_v) + if oid in stack: + logger.debug("Cycle detected while encoding iterable") + return _CYCLE_SENTINEL + stack.add(oid) + try: + seq: list[object] = list(iterable_v) + encoded_list: list[Any] = [] + for item in seq: + encoded_list.append(_enc(item, stack, depth + 1)) + return encoded_list + finally: + stack.remove(oid) + + # Primitives (or unknown objects): ensure JSON-serializable + if isinstance(v, (str, int, float, bool)) or v is None: + return v + # Fallback: stringify unknown objects to avoid JSON serialization errors + try: + return str(v) + except Exception: + return f"<{type(v).__name__}>" + + return _enc(value, set(), 0) + + +def decode_checkpoint_value(value: Any) -> Any: + """Recursively decode values previously encoded by encode_checkpoint_value.""" + if isinstance(value, dict): + value_dict = cast(dict[str, Any], value) # encoded form always uses string keys + # Structured model marker handling + if MODEL_MARKER in value_dict and "value" in value_dict: + type_key: str | None = value_dict.get(MODEL_MARKER) # type: ignore[assignment] + strategy: str | None = value_dict.get("strategy") # type: ignore[assignment] + raw_encoded: Any = value_dict.get("value") + decoded_payload = 
decode_checkpoint_value(raw_encoded) + if isinstance(type_key, str): + try: + cls = _import_qualified_name(type_key) + except Exception as exc: + logger.debug(f"Failed to import structured model {type_key}: {exc}") + cls = None + + if cls is not None: + if strategy == "to_dict" and hasattr(cls, "from_dict"): + with contextlib.suppress(Exception): + return cls.from_dict(decoded_payload) + if strategy == "to_json" and hasattr(cls, "from_json"): + if isinstance(decoded_payload, (str, bytes, bytearray)): + with contextlib.suppress(Exception): + return cls.from_json(decoded_payload) + if isinstance(decoded_payload, dict) and hasattr(cls, "from_dict"): + with contextlib.suppress(Exception): + return cls.from_dict(decoded_payload) + return decoded_payload + # Dataclass marker handling + if DATACLASS_MARKER in value_dict and "value" in value_dict: + type_key_dc: str | None = value_dict.get(DATACLASS_MARKER) # type: ignore[assignment] + raw_dc: Any = value_dict.get("value") + decoded_raw = decode_checkpoint_value(raw_dc) + if isinstance(type_key_dc, str): + try: + module_name, class_name = type_key_dc.split(":", 1) + module = sys.modules.get(module_name) + if module is None: + module = importlib.import_module(module_name) + cls_dc: Any = getattr(module, class_name) + constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw) + if constructed is not None: + return constructed + except Exception as exc: + logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value") + return decoded_raw + + # Regular dict: decode recursively + decoded: dict[str, Any] = {} + for k_any, v_any in value_dict.items(): + decoded[k_any] = decode_checkpoint_value(v_any) + return decoded + if isinstance(value, list): + # After isinstance check, treat value as list[Any] for decoding + value_list: list[Any] = value # type: ignore[assignment] + return [decode_checkpoint_value(v_any) for v_any in value_list] + return value + + +def _supports_model_protocol(obj: object) -> 
bool: + """Detect objects that expose dictionary serialization hooks.""" + try: + obj_type: type[Any] = type(obj) + except Exception: + return False + + has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type] + has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None)) + + has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type] + has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None)) + + return (has_to_dict and has_from_dict) or (has_to_json and has_from_json) + + +def _import_qualified_name(qualname: str) -> type[Any] | None: + if ":" not in qualname: + return None + module_name, class_name = qualname.split(":", 1) + module = sys.modules.get(module_name) + if module is None: + module = importlib.import_module(module_name) + attr: Any = module + for part in class_name.split("."): + attr = getattr(attr, part) + return attr if isinstance(attr, type) else None + + +def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None: + if not isinstance(cls, type): + logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}") + return None + + if isinstance(payload, dict): + try: + return cls(**payload) # type: ignore[arg-type] + except TypeError as exc: + logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}") + except Exception as exc: + logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}") + try: + instance = object.__new__(cls) + except Exception as exc: + logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}") + return None + for key, val in payload.items(): # type: ignore[attr-defined] + try: + setattr(instance, key, val) # type: ignore[arg-type] + except Exception as exc: + logger.debug(f"Checkpoint decoder could not set 
attribute {key} on {cls.__name__}: {exc}") + return instance + + try: + return cls(payload) # type: ignore[call-arg] + except TypeError as exc: + logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}") + except Exception as exc: + logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}") + return None diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py b/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py index e42e05dd91..1c710f2183 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py +++ b/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py @@ -7,8 +7,8 @@ from typing import Any from ._checkpoint import WorkflowCheckpoint +from ._checkpoint_encoding import decode_checkpoint_value # type: ignore from ._request_info_executor import PendingRequestDetails, RequestInfoMessage, RequestResponse -from ._runner_context import _decode_checkpoint_value # type: ignore logger = logging.getLogger(__name__) @@ -18,6 +18,7 @@ class WorkflowCheckpointSummary: """Human-readable summary of a workflow checkpoint.""" checkpoint_id: str + timestamp: str iteration_count: int targets: list[str] executor_ids: list[str] @@ -54,6 +55,7 @@ def get_checkpoint_summary( return WorkflowCheckpointSummary( checkpoint_id=checkpoint.checkpoint_id, + timestamp=checkpoint.timestamp, iteration_count=checkpoint.iteration_count, targets=targets, executor_ids=executor_ids, @@ -90,7 +92,7 @@ def _pending_requests_from_checkpoint( for message in message_list: if not isinstance(message, Mapping): continue - payload = _decode_checkpoint_value(message.get("data")) + payload = decode_checkpoint_value(message.get("data")) _merge_message_payload(pending, payload, message) return list(pending.values()) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py 
b/python/packages/core/agent_framework/_workflows/_executor.py index d9f9f65fa0..e9ca84d53c 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -17,7 +17,7 @@ ) from ._model_utils import DictConvertible from ._request_info_mixin import RequestInfoMixin -from ._runner_context import Message, ResponseMessage, RunnerContext +from ._runner_context import Message, MessageType, RunnerContext from ._shared_state import SharedState from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext, validate_workflow_context_annotation @@ -236,14 +236,14 @@ async def execute( # Default to find handler in regular handlers target_handlers = self._handlers - if isinstance(message, ResponseMessage): - # Wrap the response handlers to include original_request parameter - target_handlers = { - message_type: functools.partial(handler, message.original_request) - for message_type, handler in self._response_handlers.items() - } - # Handle case where Message wrapper is passed instead of raw data if isinstance(message, Message): + # Wrap the response handlers to include original_request parameter + if message.type == MessageType.RESPONSE: + target_handlers = { + message_type: functools.partial(handler, message.original_request) + for message_type, handler in self._response_handlers.items() + } + # Handle case where Message wrapper is passed instead of raw data message = message.data # Create processing span for tracing (gracefully handles disabled tracing) @@ -362,7 +362,7 @@ def can_handle(self, message: Message) -> bool: Returns: True if the executor can handle the message type, False otherwise. 
""" - if isinstance(message, ResponseMessage): + if message.type == MessageType.RESPONSE: return any(is_instance_of(message.data, message_type) for message_type in self._response_handlers) return any(is_instance_of(message.data, message_type) for message_type in self._handlers) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index c3c11537f5..3d01c7b85f 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -7,17 +7,15 @@ from typing import TYPE_CHECKING, Any from ._checkpoint import CheckpointStorage, WorkflowCheckpoint +from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpoint_value from ._edge import EdgeGroup from ._edge_runner import EdgeRunner, create_edge_runner from ._events import WorkflowEvent from ._executor import Executor from ._runner_context import ( - _DATACLASS_MARKER, # type: ignore - _MODEL_MARKER, # type: ignore CheckpointState, Message, - RunnerContext, - _decode_checkpoint_value, # type: ignore + RunnerContext, # type: ignore ) from ._shared_state import SharedState @@ -151,6 +149,8 @@ async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]: raise RuntimeError(f"Runner did not converge after {self._max_iterations} iterations.") logger.info(f"Workflow completed after {self._iteration} supersteps") + # TODO(@taochen): iteration is reset to zero, even in the event of a request info event. + # Should iteration be preserved in the event of a request info event? 
self._iteration = 0 self._resumed_from_checkpoint = False # Reset resume flag for next run finally: @@ -168,10 +168,10 @@ def _normalize_message_payload(message: Message) -> None: data = message.data if not isinstance(data, dict): return - if _MODEL_MARKER not in data and _DATACLASS_MARKER not in data: + if MODEL_MARKER not in data and DATACLASS_MARKER not in data: return try: - decoded = _decode_checkpoint_value(data) + decoded = decode_checkpoint_value(data) except Exception as exc: # pragma: no cover - defensive logger.debug("Failed to decode checkpoint payload during delivery: %s", exc) return @@ -374,6 +374,7 @@ def _checkpoint_to_state(checkpoint: WorkflowCheckpoint) -> CheckpointState: "executor_states": checkpoint.executor_states, "iteration_count": checkpoint.iteration_count, "max_iterations": checkpoint.max_iterations, + "pending_request_info_events": checkpoint.pending_request_info_events, } def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[EdgeRunner]]: diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index c02245c123..1be6462a93 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -1,16 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. 
import asyncio -import contextlib -import importlib import logging -import sys import uuid from copy import copy -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass +from enum import Enum from typing import Any, Protocol, TypedDict, TypeVar, cast, runtime_checkable from ._checkpoint import CheckpointStorage, WorkflowCheckpoint +from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._const import DEFAULT_MAX_ITERATIONS, INTERNAL_SOURCE_ID from ._events import RequestInfoEvent, WorkflowEvent from ._shared_state import SharedState @@ -20,6 +19,16 @@ T = TypeVar("T") +class MessageType(Enum): + """Enumeration of message types in the workflow.""" + + STANDARD = "standard" + """A standard message between executors.""" + + RESPONSE = "response" + """A response message to a pending request.""" + + @dataclass class Message: """A class representing a message in the workflow.""" @@ -27,12 +36,16 @@ class Message: data: Any source_id: str target_id: str | None = None + type: MessageType = MessageType.STANDARD # OpenTelemetry trace context fields for message propagation # These are plural to support fan-in scenarios where multiple messages are aggregated trace_contexts: list[dict[str, str]] | None = None # W3C Trace Context headers from multiple sources source_span_ids: list[str] | None = None # Publishing span IDs for linking from multiple sources + # For response messages, the original request data + original_request: Any = None + # Backward compatibility properties @property def trace_context(self) -> dict[str, str] | None: @@ -44,12 +57,30 @@ def source_span_id(self) -> str | None: """Get the first source span ID for backward compatibility.""" return self.source_span_ids[0] if self.source_span_ids else None + def to_dict(self) -> dict[str, Any]: + """Convert the Message to a dictionary for serialization.""" + return { + "data": encode_checkpoint_value(self.data), + "source_id": self.source_id, + 
"target_id": self.target_id, + "type": self.type.value, + "trace_contexts": self.trace_contexts, + "source_span_ids": self.source_span_ids, + "original_request": self.original_request, + } -@dataclass -class ResponseMessage(Message): - """A message representing a response to a pending request.""" - - original_request: Any = None + @staticmethod + def from_dict(data: dict[str, Any]) -> "Message": + """Create a Message from a dictionary.""" + return Message( + data=decode_checkpoint_value(data.get("data")), + source_id=data["source_id"], + target_id=data.get("target_id"), + type=MessageType(data.get("type", "standard")), + trace_contexts=data.get("trace_contexts"), + source_span_ids=data.get("source_span_ids"), + original_request=data.get("original_request"), + ) class CheckpointState(TypedDict): @@ -58,248 +89,7 @@ class CheckpointState(TypedDict): executor_states: dict[str, dict[str, Any]] iteration_count: int max_iterations: int - - -# Checkpoint serialization helpers -_MODEL_MARKER = "__af_model__" -_DATACLASS_MARKER = "__af_dataclass__" -_AF_MARKER = "__af__" - -# Guards to prevent runaway recursion while encoding arbitrary user data -_MAX_ENCODE_DEPTH = 100 -_CYCLE_SENTINEL = "" - - -def _instantiate_checkpoint_dataclass(cls: type[Any], payload: Any) -> Any | None: - if not isinstance(cls, type): - logger.debug(f"Checkpoint decoder received non-type dataclass reference: {cls!r}") - return None - - if isinstance(payload, dict): - try: - return cls(**payload) # type: ignore[arg-type] - except TypeError as exc: - logger.debug(f"Checkpoint decoder could not call {cls.__name__}(**payload): {exc}") - except Exception as exc: - logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}(**payload): {exc}") - try: - instance = object.__new__(cls) - except Exception as exc: - logger.debug(f"Checkpoint decoder could not allocate {cls.__name__} without __init__: {exc}") - return None - for key, val in payload.items(): # type: 
ignore[attr-defined] - try: - setattr(instance, key, val) # type: ignore[arg-type] - except Exception as exc: - logger.debug(f"Checkpoint decoder could not set attribute {key} on {cls.__name__}: {exc}") - return instance - - try: - return cls(payload) # type: ignore[call-arg] - except TypeError as exc: - logger.debug(f"Checkpoint decoder could not call {cls.__name__}({payload!r}): {exc}") - except Exception as exc: - logger.warning(f"Checkpoint decoder encountered unexpected error calling {cls.__name__}({payload!r}): {exc}") - return None - - -def _supports_model_protocol(obj: object) -> bool: - """Detect objects that expose dictionary serialization hooks.""" - try: - obj_type: type[Any] = type(obj) - except Exception: - return False - - has_to_dict = hasattr(obj, "to_dict") and callable(getattr(obj, "to_dict", None)) # type: ignore[arg-type] - has_from_dict = hasattr(obj_type, "from_dict") and callable(getattr(obj_type, "from_dict", None)) - - has_to_json = hasattr(obj, "to_json") and callable(getattr(obj, "to_json", None)) # type: ignore[arg-type] - has_from_json = hasattr(obj_type, "from_json") and callable(getattr(obj_type, "from_json", None)) - - return (has_to_dict and has_from_dict) or (has_to_json and has_from_json) - - -def _import_qualified_name(qualname: str) -> type[Any] | None: - if ":" not in qualname: - return None - module_name, class_name = qualname.split(":", 1) - module = sys.modules.get(module_name) - if module is None: - module = importlib.import_module(module_name) - attr: Any = module - for part in class_name.split("."): - attr = getattr(attr, part) - return attr if isinstance(attr, type) else None - - -def _encode_checkpoint_value(value: Any) -> Any: - """Recursively encode values into JSON-serializable structures. 
- - - Objects exposing to_dict/to_json -> { _MODEL_MARKER: "module:Class", value: encoded } - - dataclass instances -> { _DATACLASS_MARKER: "module:Class", value: {field: encoded} } - - dict -> encode keys as str and values recursively - - list/tuple/set -> list of encoded items - - other -> returned as-is if already JSON-serializable - - Includes cycle and depth protection to avoid infinite recursion. - """ - - def _enc(v: Any, stack: set[int], depth: int) -> Any: - # Depth guard - if depth > _MAX_ENCODE_DEPTH: - logger.debug(f"Max encode depth reached at depth={depth} for type={type(v)}") - return "" - - # Structured model handling (objects exposing to_dict/to_json) - if _supports_model_protocol(v): - cls = cast(type[Any], type(v)) # type: ignore - try: - if hasattr(v, "to_dict") and callable(getattr(v, "to_dict", None)): - raw = v.to_dict() # type: ignore[attr-defined] - strategy = "to_dict" - elif hasattr(v, "to_json") and callable(getattr(v, "to_json", None)): - serialized = v.to_json() # type: ignore[attr-defined] - if isinstance(serialized, (bytes, bytearray)): - try: - serialized = serialized.decode() - except Exception: - serialized = serialized.decode(errors="replace") - raw = serialized - strategy = "to_json" - else: - raise AttributeError("Structured model lacks serialization hooks") - return { - _MODEL_MARKER: f"{cls.__module__}:{cls.__name__}", - "strategy": strategy, - "value": _enc(raw, stack, depth + 1), - } - except Exception as exc: # best-effort fallback - logger.debug(f"Structured model serialization failed for {cls}: {exc}") - return str(v) - - # Dataclasses (instances only) - if is_dataclass(v) and not isinstance(v, type): - oid = id(v) - if oid in stack: - logger.debug("Cycle detected while encoding dataclass instance") - return _CYCLE_SENTINEL - stack.add(oid) - try: - # type(v) already narrows sufficiently; cast was redundant - dc_cls: type[Any] = type(v) - field_values: dict[str, Any] = {} - for f in fields(v): # type: ignore[arg-type] - 
field_values[f.name] = _enc(getattr(v, f.name), stack, depth + 1) - return { - _DATACLASS_MARKER: f"{dc_cls.__module__}:{dc_cls.__name__}", - "value": field_values, - } - finally: - stack.remove(oid) - - # Collections - if isinstance(v, dict): - v_dict = cast("dict[object, object]", v) - oid = id(v_dict) - if oid in stack: - logger.debug("Cycle detected while encoding dict") - return _CYCLE_SENTINEL - stack.add(oid) - try: - json_dict: dict[str, Any] = {} - for k_any, val_any in v_dict.items(): # type: ignore[assignment] - k_str: str = str(k_any) - json_dict[k_str] = _enc(val_any, stack, depth + 1) - return json_dict - finally: - stack.remove(oid) - - if isinstance(v, (list, tuple, set)): - iterable_v = cast("list[object] | tuple[object, ...] | set[object]", v) - oid = id(iterable_v) - if oid in stack: - logger.debug("Cycle detected while encoding iterable") - return _CYCLE_SENTINEL - stack.add(oid) - try: - seq: list[object] = list(iterable_v) - encoded_list: list[Any] = [] - for item in seq: - encoded_list.append(_enc(item, stack, depth + 1)) - return encoded_list - finally: - stack.remove(oid) - - # Primitives (or unknown objects): ensure JSON-serializable - if isinstance(v, (str, int, float, bool)) or v is None: - return v - # Fallback: stringify unknown objects to avoid JSON serialization errors - try: - return str(v) - except Exception: - return f"<{type(v).__name__}>" - - return _enc(value, set(), 0) - - -def _decode_checkpoint_value(value: Any) -> Any: - """Recursively decode values previously encoded by _encode_checkpoint_value.""" - if isinstance(value, dict): - value_dict = cast(dict[str, Any], value) # encoded form always uses string keys - # Structured model marker handling - if _MODEL_MARKER in value_dict and "value" in value_dict: - type_key: str | None = value_dict.get(_MODEL_MARKER) # type: ignore[assignment] - strategy: str | None = value_dict.get("strategy") # type: ignore[assignment] - raw_encoded: Any = value_dict.get("value") - decoded_payload 
= _decode_checkpoint_value(raw_encoded) - if isinstance(type_key, str): - try: - cls = _import_qualified_name(type_key) - except Exception as exc: - logger.debug(f"Failed to import structured model {type_key}: {exc}") - cls = None - - if cls is not None: - if strategy == "to_dict" and hasattr(cls, "from_dict"): - with contextlib.suppress(Exception): - return cls.from_dict(decoded_payload) - if strategy == "to_json" and hasattr(cls, "from_json"): - if isinstance(decoded_payload, (str, bytes, bytearray)): - with contextlib.suppress(Exception): - return cls.from_json(decoded_payload) - if isinstance(decoded_payload, dict) and hasattr(cls, "from_dict"): - with contextlib.suppress(Exception): - return cls.from_dict(decoded_payload) - return decoded_payload - # Dataclass marker handling - if _DATACLASS_MARKER in value_dict and "value" in value_dict: - type_key_dc: str | None = value_dict.get(_DATACLASS_MARKER) # type: ignore[assignment] - raw_dc: Any = value_dict.get("value") - decoded_raw = _decode_checkpoint_value(raw_dc) - if isinstance(type_key_dc, str): - try: - module_name, class_name = type_key_dc.split(":", 1) - module = sys.modules.get(module_name) - if module is None: - module = importlib.import_module(module_name) - cls_dc: Any = getattr(module, class_name) - constructed = _instantiate_checkpoint_dataclass(cls_dc, decoded_raw) - if constructed is not None: - return constructed - except Exception as exc: - logger.debug(f"Failed to decode dataclass {type_key_dc}: {exc}; returning raw value") - return decoded_raw - - # Regular dict: decode recursively - decoded: dict[str, Any] = {} - for k_any, v_any in value_dict.items(): - decoded[k_any] = _decode_checkpoint_value(v_any) - return decoded - if isinstance(value, list): - # After isinstance check, treat value as list[Any] for decoding - value_list: list[Any] = value # type: ignore[assignment] - return [_decode_checkpoint_value(v_any) for v_any in value_list] - return value + pending_request_info_events: dict[str, 
dict[str, Any]] @runtime_checkable @@ -619,6 +409,7 @@ async def restore_from_checkpoint(self, checkpoint_id: str) -> bool: "executor_states": checkpoint.executor_states, "iteration_count": checkpoint.iteration_count, "max_iterations": checkpoint.max_iterations, + "pending_request_info_events": checkpoint.pending_request_info_events, } await self.set_checkpoint_state(state) self._workflow_id = checkpoint.workflow_id @@ -631,49 +422,33 @@ async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None return await self._checkpoint_storage.load_checkpoint(checkpoint_id) async def get_checkpoint_state(self) -> CheckpointState: - serializable_messages: dict[str, list[dict[str, Any]]] = {} - for source_id, message_list in self._messages.items(): - serializable_messages[source_id] = [ - { - "data": _encode_checkpoint_value(msg.data), - "source_id": msg.source_id, - "target_id": msg.target_id, - "trace_contexts": msg.trace_contexts, - "source_span_ids": msg.source_span_ids, - } - for msg in message_list - ] return { - "messages": serializable_messages, - "shared_state": _encode_checkpoint_value(self._shared_state), - "executor_states": _encode_checkpoint_value(self._executor_states), + "messages": { + source_id: [msg.to_dict() for msg in message_list] for source_id, message_list in self._messages.items() + }, + "shared_state": encode_checkpoint_value(self._shared_state), + "executor_states": encode_checkpoint_value(self._executor_states), "iteration_count": self._iteration_count, "max_iterations": self._max_iterations, + "pending_request_info_events": encode_checkpoint_value(self._pending_request_info_events), } async def set_checkpoint_state(self, state: CheckpointState) -> None: + # Restore messages self._messages.clear() messages_data = state.get("messages", {}) for source_id, message_list in messages_data.items(): - self._messages[source_id] = [ - Message( - data=_decode_checkpoint_value(msg.get("data")), - source_id=msg.get("source_id", ""), - 
target_id=msg.get("target_id"), - trace_contexts=msg.get("trace_contexts"), - source_span_ids=msg.get("source_span_ids"), - ) - for msg in message_list - ] + self._messages[source_id] = [Message.from_dict(msg_data) for msg_data in message_list] + # Restore shared_state - decoded_shared_raw = _decode_checkpoint_value(state.get("shared_state", {})) + decoded_shared_raw = decode_checkpoint_value(state.get("shared_state", {})) if isinstance(decoded_shared_raw, dict): self._shared_state = cast(dict[str, Any], decoded_shared_raw) else: # fallback to empty dict if corrupted self._shared_state = {} # Restore executor_states ensuring value types are dicts - decoded_exec_raw = _decode_checkpoint_value(state.get("executor_states", {})) + decoded_exec_raw = decode_checkpoint_value(state.get("executor_states", {})) if isinstance(decoded_exec_raw, dict): typed_exec: dict[str, dict[str, Any]] = {} for k_raw, v_raw in decoded_exec_raw.items(): # type: ignore[assignment] @@ -691,6 +466,12 @@ async def set_checkpoint_state(self, state: CheckpointState) -> None: self._iteration_count = state.get("iteration_count", 0) self._max_iterations = state.get("max_iterations", 100) + # Pending request info events + self._pending_request_info_events = decode_checkpoint_value(state.get("pending_request_info_events", {})) + await asyncio.gather( + *(self.add_event(pending_request) for pending_request in self._pending_request_info_events.values()) + ) + async def add_request_info_event(self, event: RequestInfoEvent) -> None: """Add a RequestInfoEvent to the context and track it for correlation. 
@@ -718,15 +499,17 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No f"expected {event.response_type.__name__}, got {type(response).__name__}" ) - await self.send_message( - ResponseMessage( - data=response, - source_id=INTERNAL_SOURCE_ID(event.source_executor_id), - target_id=event.source_executor_id, - original_request=event.data, - ) + # Create ResponseMessage instance + response_msg = Message( + data=response, + source_id=INTERNAL_SOURCE_ID(event.source_executor_id), + target_id=event.source_executor_id, + type=MessageType.RESPONSE, + original_request=event.data, ) + await self.send_message(response_msg) + # Clear the event from pending requests self._pending_request_info_events.pop(request_id, None) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index ecfcf92788..24f4dacb70 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -391,19 +391,19 @@ async def run_stream(self, message: Any) -> AsyncIterable[WorkflowEvent]: WorkflowEvent: The events generated during the workflow execution. 
""" self._ensure_not_running() - try: - async def initial_execution() -> None: - executor = self.get_start_executor() - await executor.execute( - message, - [self.__class__.__name__], # source_executor_ids - self._shared_state, # shared_state - self._runner.context, # runner_context - trace_contexts=None, # No parent trace context for workflow start - source_span_ids=None, # No source span for workflow start - ) + async def initial_execution() -> None: + executor = self.get_start_executor() + await executor.execute( + message, + [self.__class__.__name__], # source_executor_ids + self._shared_state, # shared_state + self._runner.context, # runner_context + trace_contexts=None, # No parent trace context for workflow start + source_span_ids=None, # No source span for workflow start + ) + try: async for event in self._run_workflow_with_tracing( initial_executor_fn=initial_execution, reset_context=True, streaming=True ): diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py index 3dc80339bc..be24187ab9 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py @@ -1,11 +1,13 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from collections.abc import AsyncIterable from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +# NOTE: the Azure client imports above are real dependencies. When running this +# sample outside of Azure-enabled environments you may wish to swap in the +# `agent_framework.builtin` chat client or mock the writer executor. We keep the +# concrete import here so readers can see an end-to-end configuration. 
from agent_framework import ( AgentExecutor, AgentExecutorRequest, @@ -14,30 +16,20 @@ Executor, FileCheckpointStorage, RequestInfoEvent, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, Role, + Workflow, WorkflowBuilder, + WorkflowCheckpoint, WorkflowContext, WorkflowOutputEvent, - WorkflowRunState, WorkflowStatusEvent, get_checkpoint_summary, handler, + response_handler, ) from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential -# NOTE: the Azure client imports above are real dependencies. When running this -# sample outside of Azure-enabled environments you may wish to swap in the -# `agent_framework.builtin` chat client or mock the writer executor. We keep the -# concrete import here so readers can see an end-to-end configuration. - -if TYPE_CHECKING: - from agent_framework import Workflow - from agent_framework._workflows._checkpoint import WorkflowCheckpoint - """ Sample: Checkpoint + human-in-the-loop quickstart. @@ -110,7 +102,7 @@ async def prepare(self, brief: str, ctx: WorkflowContext[AgentExecutorRequest, s @dataclass -class HumanApprovalRequest(RequestInfoMessage): +class HumanApprovalRequest: """Message sent to the human reviewer via RequestInfoExecutor.""" # These fields are intentionally simple because they are serialised into @@ -124,18 +116,12 @@ class HumanApprovalRequest(RequestInfoMessage): class ReviewGateway(Executor): """Routes agent drafts to humans and optionally back for revisions.""" - def __init__(self, id: str, reviewer_id: str, writer_id: str, finalize_id: str) -> None: + def __init__(self, id: str, writer_id: str) -> None: super().__init__(id=id) - self._reviewer_id = reviewer_id self._writer_id = writer_id - self._finalize_id = finalize_id @handler - async def on_agent_response( - self, - response: AgentExecutorResponse, - ctx: WorkflowContext[HumanApprovalRequest, str], - ) -> None: + async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) 
-> None: # Capture the agent output so we can surface it to the reviewer and # persist iterations. The `RequestInfoExecutor` relies on this state to # rehydrate when checkpoints are restored. @@ -145,31 +131,32 @@ async def on_agent_response( # Emit a human approval request. Because this flows through # RequestInfoExecutor it will pause the workflow until an answer is # supplied either interactively or via pre-supplied responses. - await ctx.send_message( + await ctx.request_info( HumanApprovalRequest( prompt="Review the draft. Reply 'approve' or provide edit instructions.", draft=draft, iteration=iteration, ), - target_id=self._reviewer_id, + HumanApprovalRequest, + str, ) - @handler + @response_handler async def on_human_feedback( self, - feedback: RequestResponse[HumanApprovalRequest, str], + original_request: HumanApprovalRequest, + feedback: str, ctx: WorkflowContext[AgentExecutorRequest | str, str], ) -> None: # The RequestResponse wrapper gives us both the human data and the # original request message, even when resuming from checkpoints. - reply = (feedback.data or "").strip() + reply = feedback.strip() state = await ctx.get_state() or {} - draft = state.get("last_draft") or (feedback.original_request.draft if feedback.original_request else "") + draft = state.get("last_draft") or (original_request.draft or "") if reply.lower() == "approve": - # When the human signs off we can short-circuit the workflow and - # send the approved draft to the final executor. - await ctx.send_message(draft, target_id=self._finalize_id) + # Workflow is completed when the human approves. + await ctx.yield_output(draft) return # Any other response loops us back to the writer with fresh guidance. @@ -187,63 +174,34 @@ async def on_human_feedback( ) -class FinaliseExecutor(Executor): - """Publishes the approved text.""" - - @handler - async def publish(self, text: str, ctx: WorkflowContext[Any, str]) -> None: - # Store the output so diagnostics or a UI could fetch the final copy. 
- await ctx.set_state({"published_text": text}) - # Yield the final output so the workflow completes cleanly. - await ctx.yield_output(text) - - -def create_workflow(*, checkpoint_storage: FileCheckpointStorage | None = None) -> "Workflow": +def create_workflow(checkpoint_storage: FileCheckpointStorage) -> Workflow: """Assemble the workflow graph used by both the initial run and resume.""" - # The Azure client is created once so our agent executor can issue calls to - # the hosted model. The agent id is stable across runs which keeps - # checkpoints deterministic. + # The Azure client is created once so our agent executor can issue calls to the hosted + # model. The agent id is stable across runs which keeps checkpoints deterministic. chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) - writer = AgentExecutor( - chat_client.create_agent( - instructions="Write concise, warm release notes that sound human and helpful.", - ), - id="writer", - ) - # RequestInfoExecutor is the lynchpin for human-in-the-loop: every draft is - # routed through it so checkpoints can pause while waiting for responses. - review = RequestInfoExecutor(id="request_info") - finalise = FinaliseExecutor(id="finalise") - gateway = ReviewGateway( - id="review_gateway", - reviewer_id=review.id, - writer_id=writer.id, - finalize_id=finalise.id, - ) + agent = chat_client.create_agent(instructions="Write concise, warm release notes that sound human and helpful.") + + writer = AgentExecutor(agent, id="writer") + gateway = ReviewGateway(id="review_gateway", writer_id=writer.id) prepare = BriefPreparer(id="prepare_brief", agent_id=writer.id) # Wire the workflow DAG. Edges mirror the numbered steps described in the # module docstring. Because `WorkflowBuilder` is declarative, reading these # edges is often the quickest way to understand execution order. 
- builder = ( + workflow_builder = ( WorkflowBuilder(max_iterations=6) .set_start_executor(prepare) .add_edge(prepare, writer) .add_edge(writer, gateway) - .add_edge(gateway, review) - .add_edge(review, gateway) # human resumes loop - .add_edge(gateway, writer) # revisions - .add_edge(gateway, finalise) + .add_edge(gateway, writer) # revisions loop + .with_checkpointing(checkpoint_storage=checkpoint_storage) ) - # Opt-in to persistence when the caller provides storage. The workflow - # object itself is identical whether or not checkpointing is enabled. - if checkpoint_storage: - builder = builder.with_checkpointing(checkpoint_storage=checkpoint_storage) - return builder.build() + return workflow_builder.build() -def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: + +def render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: """Pretty-print saved checkpoints with the new framework summaries.""" print("\nCheckpoint summary:") @@ -251,7 +209,7 @@ def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: # Compose a single line per checkpoint so the user can scan the output # and pick the resume point that still has outstanding human work. 
line = ( - f"- {summary.checkpoint_id} | iter={summary.iteration_count} " + f"- {summary.checkpoint_id} | timestamp={summary.timestamp} | iter={summary.iteration_count} " f"| targets={summary.targets} | states={summary.executor_ids}" ) if summary.status: @@ -263,152 +221,70 @@ def _render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: print(line) -def _print_events(events: list[Any]) -> tuple[str | None, list[tuple[str, HumanApprovalRequest]]]: - """Echo workflow events to the console and collect outstanding requests.""" - - completed_output: str | None = None - requests: list[tuple[str, HumanApprovalRequest]] = [] - - for event in events: - print(f"Event: {event}") - if isinstance(event, WorkflowOutputEvent): - completed_output = event.data - if isinstance(event, RequestInfoEvent) and isinstance(event.data, HumanApprovalRequest): - # Capture pending human approvals so the caller can ask the user for - # input after the current batch of events is processed. - requests.append((event.request_id, event.data)) - elif isinstance(event, WorkflowStatusEvent) and event.state in { - WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS, - WorkflowRunState.IDLE_WITH_PENDING_REQUESTS, - }: - print(f"Workflow state: {event.state.name}") - - return completed_output, requests - - -def _prompt_for_responses(requests: list[tuple[str, HumanApprovalRequest]]) -> dict[str, str] | None: +def prompt_for_responses(requests: dict[str, HumanApprovalRequest]) -> dict[str, str]: """Interactive CLI prompt for any live RequestInfo requests.""" - if not requests: - return None - answers: dict[str, str] = {} - for request_id, request in requests: - # Keep the prompt conversational so testers can use the script without - # memorising the workflow APIs. 
+ responses: dict[str, str] = {} + for request_id, request in requests.items(): print("\n=== Human approval needed ===") print(f"request_id: {request_id}") - if request.iteration: - print(f"Iteration: {request.iteration}") + print(f"Iteration: {request.iteration}") print(request.prompt) print("Draft: \n---\n" + request.draft + "\n---") - answer = input("Type 'approve' or enter revision guidance (or 'exit' to quit): ").strip() # noqa: ASYNC250 - if answer.lower() == "exit": + response = input("Type 'approve' or enter revision guidance (or 'exit' to quit): ").strip() + if response.lower() == "exit": raise SystemExit("Stopped by user.") - answers[request_id] = answer - return answers - - -def _maybe_pre_supply_responses(cp: "WorkflowCheckpoint") -> dict[str, str] | None: - """Offer to collect responses before resuming a checkpoint.""" - - pending = get_checkpoint_summary(cp).pending_requests - if not pending: - return None - - print( - "This checkpoint still has pending human input. Provide the responses now so the resume step " - "applies them immediately and does not re-emit the original RequestInfo event." - ) - choice = input("Pre-supply responses for this checkpoint? [y/N]: ").strip().lower() # noqa: ASYNC250 - if choice not in {"y", "yes"}: - return None - - answers: dict[str, str] = {} - for item in pending: - iteration = item.iteration or 0 - print(f"\nPending draft (iteration {iteration} | request_id={item.request_id}):") - draft_text = (item.draft or "").strip() - if draft_text: - # The shortened preview in the summary may truncate text; here we - # show the full draft so the reviewer can make an informed choice. 
- print("Draft:\n---\n" + draft_text + "\n---") - else: - print("Draft: [not captured in checkpoint payload - refer to your notes/log]") - prompt_text = (item.prompt or "Review the draft").strip() - print(prompt_text) - answer = input("Response ('approve' or guidance, 'exit' to abort): ").strip() # noqa: ASYNC250 - if answer.lower() == "exit": - raise SystemExit("Resume aborted by user.") - answers[item.request_id] = answer - return answers - - -async def _consume(stream: AsyncIterable[Any]) -> list[Any]: - """Materialise an async event stream into a list.""" + responses[request_id] = response - return [event async for event in stream] + return responses -async def run_interactive_session(workflow: "Workflow", initial_message: str) -> str | None: +async def run_interactive_session( + workflow: Workflow, + initial_message: str | None = None, + checkpoint_id: str | None = None, +) -> str: """Run the workflow until it either finishes or pauses for human input.""" - pending_responses: dict[str, str] | None = None + requests: dict[str, HumanApprovalRequest] = {} + responses: dict[str, str] | None = None completed_output: str | None = None - first = True - - while completed_output is None: - if first: - # Kick off the workflow with the initial brief. The returned events - # include RequestInfo events when the agent produces a draft. - events = await _consume(workflow.run_stream(initial_message)) - first = False - elif pending_responses: - # Feed any answers the user just typed back into the workflow. 
- events = await _consume(workflow.send_responses_streaming(pending_responses)) + + while True: + if responses: + event_stream = workflow.send_responses_streaming(responses) + requests.clear() else: + if initial_message: + print(f"\nStarting workflow with brief: {initial_message}\n") + event_stream = workflow.run_stream(initial_message) + elif checkpoint_id: + print("\nStarting workflow from checkpoint...\n") + event_stream = workflow.run_stream_from_checkpoint(checkpoint_id) + else: + raise ValueError("Either initial_message or checkpoint_id must be provided") + + async for event in event_stream: + if isinstance(event, WorkflowStatusEvent): + print(event) + if isinstance(event, WorkflowOutputEvent): + completed_output = event.data + if isinstance(event, RequestInfoEvent): + if isinstance(event.data, HumanApprovalRequest): + requests[event.request_id] = event.data + else: + raise ValueError("Unexpected request data type") + + if completed_output: break - completed_output, requests = _print_events(events) - if completed_output is None: - pending_responses = _prompt_for_responses(requests) - - return completed_output + if requests: + responses = prompt_for_responses(requests) + continue + raise RuntimeError("Workflow stopped without completing or requesting input") -async def resume_from_checkpoint( - workflow: "Workflow", - checkpoint_id: str, - storage: FileCheckpointStorage, - pre_supplied: dict[str, str] | None, -) -> None: - """Resume a stored checkpoint and continue until completion or another pause.""" - - print(f"\nResuming from checkpoint: {checkpoint_id}") - events = await _consume( - workflow.run_stream_from_checkpoint( - checkpoint_id, - checkpoint_storage=storage, - responses=pre_supplied, - ) - ) - completed_output, requests = _print_events(events) - if pre_supplied and not requests and completed_output is None: - # When the checkpoint only needed the provided answers we let the user - # know the workflow is waiting for the next superstep (usually 
another - # agent response). - print("Pre-supplied responses applied automatically; workflow is now waiting for the next step.") - - pending = _prompt_for_responses(requests) - while completed_output is None and pending: - events = await _consume(workflow.send_responses_streaming(pending)) - completed_output, requests = _print_events(events) - if completed_output is None: - pending = _prompt_for_responses(requests) - else: - break - - if completed_output: - print(f"Workflow completed with: {completed_output}") + return completed_output async def main() -> None: @@ -428,11 +304,8 @@ async def main() -> None: ) print("Running workflow (human approval required)...") - completed = await run_interactive_session(workflow, initial_message=brief) - if completed: - print(f"Initial run completed with final copy: {completed}") - else: - print("Initial run paused for human input.") + result = await run_interactive_session(workflow, initial_message=brief) + print(f"Workflow completed with: {result}") checkpoints = await storage.list_checkpoints() if not checkpoints: @@ -441,7 +314,7 @@ async def main() -> None: # Show the user what is available before we prompt for the index. The # summary helper keeps this output consistent with other tooling. - _render_checkpoint_summary(checkpoints) + render_checkpoint_summary(checkpoints) sorted_cps = sorted(checkpoints, key=lambda c: c.timestamp) print("\nAvailable checkpoints:") @@ -472,14 +345,11 @@ async def main() -> None: print("Selected checkpoint already reflects a completed workflow; nothing to resume.") return - # If the user wants, capture their decisions now so the resume call can - # push them into the workflow and avoid re-prompting. - pre_responses = _maybe_pre_supply_responses(chosen) - - resumed_workflow = create_workflow() + new_workflow = create_workflow(checkpoint_storage=storage) # Resume with a fresh workflow instance. The checkpoint carries the # persistent state while this object holds the runtime wiring. 
- await resume_from_checkpoint(resumed_workflow, chosen.checkpoint_id, storage, pre_responses) + result = await run_interactive_session(new_workflow, checkpoint_id=chosen.checkpoint_id) + print(f"Workflow completed with: {result}") if __name__ == "__main__": From 2c5f595f44bae4d894e1646d09502bb2a889aa43 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 21 Oct 2025 11:32:33 -0700 Subject: [PATCH 06/26] checkpointing with sub workflow --- .../agent_framework/_workflows/_events.py | 29 ++ .../agent_framework/_workflows/_runner.py | 1 + .../_workflows/_runner_context.py | 48 +-- .../_workflows/_typing_utils.py | 25 ++ .../agent_framework/_workflows/_workflow.py | 135 ++---- .../_workflows/_workflow_builder.py | 19 +- .../_workflows/_workflow_executor.py | 406 +++++++----------- .../core/tests/workflow/test_validation.py | 6 +- .../checkpoint/sub_workflow_checkpoint.py | 169 +++++--- 9 files changed, 382 insertions(+), 456 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index a339ea9cb3..0df8e8001e 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -10,6 +10,9 @@ from agent_framework import AgentRunResponse, AgentRunResponseUpdate +from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value +from ._typing_utils import deserialize_type, serialize_type + if TYPE_CHECKING: from ._request_info_executor import RequestInfoMessage @@ -240,6 +243,32 @@ def __repr__(self) -> str: f"response_type={self.response_type.__name__})" ) + def to_dict(self) -> dict[str, Any]: + """Convert the request info event to a dictionary for serialization.""" + return { + "data": encode_checkpoint_value(self.data), + "request_id": self.request_id, + "source_executor_id": self.source_executor_id, + "request_type": serialize_type(self.request_type), + "response_type": 
serialize_type(self.response_type), + } + + @staticmethod + def from_dict(data: dict[str, Any]) -> "RequestInfoEvent": + """Create a RequestInfoEvent from a dictionary.""" + # Validation + for property in ["data", "request_id", "source_executor_id", "request_type", "response_type"]: + if property not in data: + raise KeyError(f"Missing '{property}' field in RequestInfoEvent dictionary.") + + return RequestInfoEvent( + request_id=data["request_id"], + source_executor_id=data["source_executor_id"], + request_type=deserialize_type(data["request_type"]), + request_data=decode_checkpoint_value(data["data"]), + response_type=deserialize_type(data["response_type"]), + ) + class WorkflowOutputEvent(WorkflowEvent): """Event triggered when a workflow executor yields output.""" diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index e9979f59b0..f90a088633 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -348,6 +348,7 @@ async def _restore_executor_states(self, executor_states: dict[str, dict[str, An await maybe # type: ignore[arg-type] restored = True except Exception as ex: # pragma: no cover - defensive + # TODO(@taochen): should we swallow the exception? 
logger.debug(f"Executor {exec_id} restore_state failed: {ex}") if not restored: diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index 3010ea2281..fdf4ae3426 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -72,8 +72,15 @@ def to_dict(self) -> dict[str, Any]: @staticmethod def from_dict(data: dict[str, Any]) -> "Message": """Create a Message from a dictionary.""" + # Validation + if "data" not in data: + raise KeyError("Missing 'data' field in Message dictionary.") + + if "source_id" not in data: + raise KeyError("Missing 'source_id' field in Message dictionary.") + return Message( - data=decode_checkpoint_value(data.get("data")), + data=decode_checkpoint_value(data["data"]), source_id=data["source_id"], target_id=data.get("target_id"), type=MessageType(data.get("type", "standard")), @@ -382,6 +389,7 @@ async def create_checkpoint(self, metadata: dict[str, Any] | None = None) -> str executor_states=state.get("executor_states", {}), iteration_count=state.get("iteration_count", 0), max_iterations=state.get("max_iterations", DEFAULT_MAX_ITERATIONS), + pending_request_info_events=state.get("pending_request_info_events", {}), metadata=metadata or {}, ) checkpoint_id = await self._checkpoint_storage.save_checkpoint(checkpoint) @@ -394,42 +402,29 @@ async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None return await self._checkpoint_storage.load_checkpoint(checkpoint_id) async def get_workflow_state(self) -> WorkflowState: - serializable_messages: dict[str, list[dict[str, Any]]] = {} + serialized_messages: dict[str, list[dict[str, Any]]] = {} for source_id, message_list in self._messages.items(): - serializable_messages[source_id] = [ - { - "data": encode_checkpoint_value(msg.data), - "source_id": msg.source_id, - "target_id": msg.target_id, - 
"trace_contexts": msg.trace_contexts, - "source_span_ids": msg.source_span_ids, - } - for msg in message_list - ] + serialized_messages[source_id] = [msg.to_dict() for msg in message_list] + + serialized_pending_request_info_events: dict[str, dict[str, Any]] = { + request_id: request.to_dict() for request_id, request in self._pending_request_info_events.items() + } return { - "messages": serializable_messages, + "messages": serialized_messages, "shared_state": encode_checkpoint_value(self._shared_state), "executor_states": encode_checkpoint_value(self._executor_states), "iteration_count": self._iteration_count, "max_iterations": self._max_iterations, - "pending_request_info_events": encode_checkpoint_value(self._pending_request_info_events), + "pending_request_info_events": serialized_pending_request_info_events, } async def set_workflow_state(self, state: WorkflowState) -> None: self._messages.clear() messages_data = state.get("messages", {}) for source_id, message_list in messages_data.items(): - self._messages[source_id] = [ - Message( - data=decode_checkpoint_value(msg.get("data")), - source_id=msg.get("source_id", ""), - target_id=msg.get("target_id"), - trace_contexts=msg.get("trace_contexts"), - source_span_ids=msg.get("source_span_ids"), - ) - for msg in message_list - ] + self._messages[source_id] = [Message.from_dict(msg) for msg in message_list] + # Restore shared_state decoded_shared_raw = decode_checkpoint_value(state.get("shared_state", {})) if isinstance(decoded_shared_raw, dict): @@ -457,7 +452,10 @@ async def set_workflow_state(self, state: WorkflowState) -> None: self._max_iterations = state.get("max_iterations", 100) # Pending request info events - self._pending_request_info_events = decode_checkpoint_value(state.get("pending_request_info_events", {})) + self._pending_request_info_events = { + request_id: RequestInfoEvent.from_dict(request) + for request_id, request in state.get("pending_request_info_events", {}).items() + } await asyncio.gather( 
*(self.add_event(pending_request) for pending_request in self._pending_request_info_events.values()) ) diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index f085fee5b1..8be339ab94 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -153,3 +153,28 @@ def is_instance_of(data: Any, target_type: type | UnionType | Any) -> bool: # Fallback: if we reach here, we assume data is an instance of the target_type return isinstance(data, target_type) + + +def serialize_type(t: type) -> str: + """Serialize a type to a string. + + For example, + + serialize_type(int) => "builtins.int" + """ + return f"{t.__module__}.{t.__qualname__}" + + +def deserialize_type(serialized_type_string: str) -> type: + """Deserialize a serialized type string. + + For example, + + deserialize_type("builtins.int") => int + """ + import importlib + + module_name, _, type_name = serialized_type_string.partition(".") + module = importlib.import_module(module_name) + + return getattr(module, type_name) diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index e2f768a960..148b6e20d8 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -35,7 +35,6 @@ from ._runner import Runner from ._runner_context import RunnerContext from ._shared_state import SharedState -from ._workflow_context import WorkflowContext if sys.version_info >= (3, 11): pass # pragma: no cover @@ -370,6 +369,7 @@ async def _run_workflow_with_tracing( capture_exception(span, exception=exc) raise + # region Streaming Run async def run_stream(self, message: Any) -> AsyncIterable[WorkflowEvent]: """Run the workflow with a starting message and stream events. 
@@ -404,16 +404,13 @@ async def run_stream_from_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None = None, - responses: dict[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Resume workflow execution from a checkpoint and stream events. Args: checkpoint_id: The ID of the checkpoint to restore from. checkpoint_storage: Optional checkpoint storage to use for restoration. - If not provided, the workflow must have been built with checkpointing enabled. - responses: Optional dictionary of responses to inject into the workflow - after restoration. Keys are request IDs, values are response data. + If not provided, the workflow must have been built with checkpointing enabled. Yields: WorkflowEvent: Events generated during workflow execution. @@ -424,57 +421,8 @@ async def run_stream_from_checkpoint( """ self._ensure_not_running() try: - - async def checkpoint_restoration() -> None: - has_checkpointing = self._runner.context.has_checkpointing() - - if not has_checkpointing and checkpoint_storage is None: - raise ValueError( - "Cannot restore from checkpoint: either provide checkpoint_storage parameter " - "or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)." 
- ) - - restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) - - if not restored: - raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}") - - # Process any pending messages from the checkpoint first - # This ensures that RequestInfoExecutor state is properly populated - # before we try to handle responses - if await self._runner.context.has_messages(): - # Run one iteration to process pending messages - # This will populate RequestInfoExecutor._request_events properly - await self._runner._run_iteration() # type: ignore - - if responses: - request_info_executor = self._find_request_info_executor() - if request_info_executor: - for request_id, response_data in responses.items(): - ctx: WorkflowContext[Any] = WorkflowContext( - request_info_executor.id, - [self.__class__.__name__], - self._shared_state, - self._runner.context, - trace_contexts=None, # No parent trace context for new workflow span - source_span_ids=None, # No source span for response handling - ) - - if not await request_info_executor.has_pending_request(request_id, ctx): - logger.debug( - f"Skipping pre-supplied response for request {request_id}; " - f"no pending request found after checkpoint restoration." - ) - continue - - await request_info_executor.handle_response( - response_data, - request_id, - ctx, - ) - async for event in self._run_workflow_with_tracing( - initial_executor_fn=checkpoint_restoration, + initial_executor_fn=functools.partial(self._checkpoint_restoration, checkpoint_id, checkpoint_storage), reset_context=False, # Don't reset context when resuming from checkpoint streaming=True, ): @@ -503,6 +451,10 @@ async def send_responses_streaming(self, responses: dict[str, Any]) -> AsyncIter finally: self._reset_running_flag() + # endregion: Streaming Run + + # region: Run + async def run(self, message: Any, *, include_status_events: bool = False) -> WorkflowRunResult: """Run the workflow with the given message. 
@@ -559,7 +511,6 @@ async def run_from_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None = None, - responses: dict[str, Any] | None = None, ) -> WorkflowRunResult: """Resume workflow execution from a checkpoint. @@ -567,8 +518,6 @@ async def run_from_checkpoint( checkpoint_id: The ID of the checkpoint to restore from. checkpoint_storage: Optional checkpoint storage to use for restoration. If not provided, the workflow must have been built with checkpointing enabled. - responses: Optional dictionary of responses to inject into the workflow - after restoration. Keys are request IDs, values are response data. Returns: A WorkflowRunResult instance containing a list of events generated during the workflow execution. @@ -579,59 +528,12 @@ async def run_from_checkpoint( """ self._ensure_not_running() try: - - async def checkpoint_restoration() -> None: - has_checkpointing = self._runner.context.has_checkpointing() - - if not has_checkpointing and checkpoint_storage is None: - raise ValueError( - "Cannot restore from checkpoint: either provide checkpoint_storage parameter " - "or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)." 
- ) - - restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) - - if not restored: - raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}") - - # Process any pending messages from the checkpoint first - # This ensures that RequestInfoExecutor state is properly populated - # before we try to handle responses - if await self._runner.context.has_messages(): - # Run one iteration to process pending messages - # This will populate RequestInfoExecutor._request_events properly - await self._runner._run_iteration() # type: ignore - - if responses: - request_info_executor = self._find_request_info_executor() - if request_info_executor: - for request_id, response_data in responses.items(): - ctx: WorkflowContext[Any] = WorkflowContext( - request_info_executor.id, - [self.__class__.__name__], - self._shared_state, - self._runner.context, - trace_contexts=None, # No parent trace context for new workflow span - source_span_ids=None, # No source span for response handling - ) - - if not await request_info_executor.has_pending_request(request_id, ctx): - logger.debug( - f"Skipping pre-supplied response for request {request_id}; " - f"no pending request found after checkpoint restoration." 
- ) - continue - - await request_info_executor.handle_response( - response_data, - request_id, - ctx, - ) - events = [ event async for event in self._run_workflow_with_tracing( - initial_executor_fn=checkpoint_restoration, + initial_executor_fn=functools.partial( + self._checkpoint_restoration, checkpoint_id, checkpoint_storage + ), reset_context=False, # Don't reset context when resuming from checkpoint ) ] @@ -665,6 +567,8 @@ async def send_responses(self, responses: dict[str, Any]) -> WorkflowRunResult: finally: self._reset_running_flag() + # endregion: Run + async def _send_responses_internal(self, responses: dict[str, Any]) -> None: """Internal method to validate and send responses to the executors.""" pending_requests = await self._runner_context.get_pending_request_info_events() @@ -687,6 +591,21 @@ async def _send_responses_internal(self, responses: dict[str, Any]) -> None: for request_id, response in responses.items() ]) + async def _checkpoint_restoration(self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None) -> None: + """Internal method to restore a run from a checkpoint.""" + has_checkpointing = self._runner.context.has_checkpointing() + + if not has_checkpointing and checkpoint_storage is None: + raise ValueError( + "Cannot restore from checkpoint: either provide checkpoint_storage parameter " + "or build workflow with WorkflowBuilder.with_checkpointing(checkpoint_storage)." + ) + + restored = await self._runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage) + + if not restored: + raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}") + def _get_executor_by_id(self, executor_id: str) -> Executor: """Get an executor by its ID. 
diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index d282fdd3e2..8bc1888f86 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -16,6 +16,7 @@ EdgeGroup, FanInEdgeGroup, FanOutEdgeGroup, + InternalEdgeGroup, SingleEdgeGroup, SwitchCaseEdgeGroup, SwitchCaseEdgeGroupCase, @@ -56,7 +57,6 @@ def __init__( """ self._edge_groups: list[EdgeGroup] = [] self._executors: dict[str, Executor] = {} - self._duplicate_executor_ids: set[str] = set() self._start_executor: Executor | str | None = None self._checkpoint_storage: CheckpointStorage | None = None self._max_iterations: int = max_iterations @@ -73,10 +73,18 @@ def __init__( def _add_executor(self, executor: Executor) -> str: """Add an executor to the map and return its ID.""" existing = self._executors.get(executor.id) - if existing is not None and existing is not executor: - self._duplicate_executor_ids.add(executor.id) - else: - self._executors[executor.id] = executor + if existing is not None: + if existing is executor: + # Already added + return executor.id + # ID conflict + raise ValueError(f"Duplicate executor ID '{executor.id}' detected in workflow.") + + # New executor + self._executors[executor.id] = executor + # Add an internal edge group for each unique executor + self._edge_groups.append(InternalEdgeGroup(executor.id)) + return executor.id def _maybe_wrap_agent( @@ -408,7 +416,6 @@ def build(self) -> Workflow: self._edge_groups, self._executors, self._start_executor, - duplicate_executor_ids=tuple(self._duplicate_executor_ids), ) # Add validation completed event diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 5a3f3844e4..9bf69d93bb 100644 --- 
a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -1,15 +1,16 @@ # Copyright (c) Microsoft. All rights reserved. +import asyncio import contextlib import logging import uuid -from collections.abc import Mapping from dataclasses import dataclass from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from ._workflow import Workflow +from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._events import ( RequestInfoEvent, WorkflowErrorEvent, @@ -20,13 +21,9 @@ Executor, handler, ) -from ._request_info_executor import ( - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, -) from ._runner_context import Message from ._typing_utils import is_instance_of +from ._workflow import WorkflowRunResult from ._workflow_context import WorkflowContext logger = logging.getLogger(__name__) @@ -36,9 +33,19 @@ class ExecutionContext: """Context for tracking a single sub-workflow execution.""" + # The ID of the execution context execution_id: str + + # Responses that have been collected so far for requests that + # were sent out in the previous iteration collected_responses: dict[str, Any] # request_id -> response_data + + # Number of responses to be expected. If the WorkflowExecutor has + # not received all responses, it won't run the sub workflow. expected_response_count: int + + # Pending requests to be fulfilled. This will get updated as the + # WorkflowExecutor receives responses. 
pending_requests: dict[str, RequestInfoEvent] # request_id -> request_info_event @@ -269,7 +276,6 @@ def __init__(self, workflow: "Workflow", id: str, **kwargs: Any): self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext # Map request_id to execution_id for response routing self._request_to_execution: dict[str, str] = {} # request_id -> execution_id - self._active_executions: int = 0 # Count of active sub-workflow executions self._state_loaded: bool = False @property @@ -324,7 +330,7 @@ def can_handle(self, message: Message) -> bool: # For other messages, only handle if the wrapped workflow can accept them as input return any(is_instance_of(message.data, input_type) for input_type in self.workflow.input_types) - @handler # No output_types - can send any completion data type + @handler async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) -> None: """Execute the sub-workflow with raw input data. @@ -335,11 +341,6 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) input_data: The input data to send to the sub-workflow. ctx: The workflow context from the parent. 
""" - # Skip RequestResponse - it has a specific handler - if isinstance(input_data, RequestResponse): - logger.debug(f"WorkflowExecutor {self.id} ignoring input of type {type(input_data)}") - return - await self._ensure_state_loaded(ctx) # Create execution context for this sub-workflow run @@ -352,9 +353,6 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) ) self._execution_contexts[execution_id] = execution_context - # Track this execution - self._active_executions += 1 - logger.debug(f"WorkflowExecutor {self.id} starting sub-workflow {self.workflow.id} execution {execution_id}") try: @@ -374,95 +372,6 @@ async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) exec_ctx = self._execution_contexts[execution_id] if not exec_ctx.pending_requests: del self._execution_contexts[execution_id] - self._active_executions -= 1 - - async def _process_workflow_result( - self, result: Any, execution_context: ExecutionContext, ctx: WorkflowContext[Any] - ) -> None: - """Process the result from a workflow execution. - - This method handles the common logic for processing outputs, request info events, - and final states that is shared between process_workflow and handle_response. - - Args: - result: The workflow execution result. - execution_context: The execution context for this sub-workflow run. - ctx: The workflow context. 
- """ - # Collect all events from the workflow - request_info_events = result.get_request_info_events() - outputs = result.get_outputs() - final_state = result.get_final_state() - logger.debug( - f"WorkflowExecutor {self.id} processing workflow result with " - f"{len(outputs)} outputs and {len(request_info_events)} request info events, " - f"final state: {final_state}" - ) - - # Process outputs - for output in outputs: - # TODO(@taochen): Allow the sub-workflow to output directly - await ctx.send_message(output) - - # Process request info events - for event in request_info_events: - # Track the pending request in execution context - execution_context.pending_requests[event.request_id] = event - # Map request to execution for response routing - self._request_to_execution[event.request_id] = execution_context.execution_id - # TODO(@taochen): There should be two ways a sub-workflow can make a request: - # 1. In a workflow where the parent workflow has an executor that may intercept the - # request and handle it directly, a message should be sent. - # 2. In a workflow where the parent workflow does not handle the request, the request - # should be propagated via the `request_info` mechanism to an external source. And - # a @response_handler would be required in the WorkflowExecutor to handle the response. - await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) - - # Update expected response count for this execution - execution_context.expected_response_count = len(request_info_events) - - # Handle final state - if final_state == WorkflowRunState.FAILED: - # Find the WorkflowFailedEvent. 
- failed_events = [e for e in result if isinstance(e, WorkflowFailedEvent)] - if failed_events: - failed_event = failed_events[0] - error_type = failed_event.details.error_type - error_message = failed_event.details.message - exception = Exception( - f"Sub-workflow {self.workflow.id} failed with error: {error_type} - {error_message}" - ) - error_event = WorkflowErrorEvent( - data=exception, - ) - await ctx.add_event(error_event) - self._active_executions -= 1 - elif final_state == WorkflowRunState.IDLE: - # Sub-workflow is idle - nothing more to do now - logger.debug(f"Sub-workflow {self.workflow.id} is idle with {self._active_executions} active executions") - self._active_executions -= 1 # Treat idle as completion for now - elif final_state == WorkflowRunState.CANCELLED: - # Sub-workflow was cancelled - treat as completion - logger.debug( - f"Sub-workflow {self.workflow.id} was cancelled with {self._active_executions} active executions" - ) - self._active_executions -= 1 - elif final_state == WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS: - # Sub-workflow is still running with pending requests - logger.debug( - f"Sub-workflow {self.workflow.id} is still in progress with {len(request_info_events)} " - f"pending requests with {self._active_executions} active executions" - ) - elif final_state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: - # Sub-workflow is idle but has pending requests - logger.debug( - f"Sub-workflow {self.workflow.id} is idle with pending requests: " - f"{len(request_info_events)} with {self._active_executions} active executions" - ) - else: - raise RuntimeError(f"Unexpected final state: {final_state}") - - await self._persist_execution_state(ctx) @handler async def handle_response(self, response: SubWorkflowResponseMessage, ctx: WorkflowContext[Any]) -> None: @@ -528,7 +437,6 @@ async def handle_response(self, response: SubWorkflowResponseMessage, ctx: Workf # Clean up execution context if it's completed (no pending requests) if not 
execution_context.pending_requests: del self._execution_contexts[execution_id] - self._active_executions -= 1 async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None: if self._state_loaded: @@ -547,159 +455,161 @@ async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None: else: self._state_loaded = True - def restore_state(self, state: dict[str, Any]) -> None: + async def restore_state(self, state: dict[str, Any]) -> None: """Restore pending request bookkeeping from a checkpoint snapshot.""" - self._execution_contexts = {} - self._request_to_execution = {} - - executions_payload = state.get("executions") - if isinstance(executions_payload, Mapping) and executions_payload: - for execution_id, payload in executions_payload.items(): - if not isinstance(execution_id, str) or not isinstance(payload, Mapping): - continue - - pending_ids_raw = payload.get("pending_request_ids", []) - if not isinstance(pending_ids_raw, list): - continue - pending_ids = [rid for rid in pending_ids_raw if isinstance(rid, str)] - - expected = payload.get("expected_response_count", len(pending_ids)) - try: - expected_count = int(expected) - except (TypeError, ValueError): - expected_count = len(pending_ids) - - collected_ids_raw = payload.get("collected_response_ids", []) - collected: dict[str, Any] = {} - if isinstance(collected_ids_raw, list): - for rid in collected_ids_raw: - if isinstance(rid, str): - collected[rid] = None - - exec_ctx = ExecutionContext( - execution_id=execution_id, - collected_responses=collected, - expected_response_count=expected_count, - pending_requests={rid: None for rid in pending_ids}, - ) - - if exec_ctx.pending_requests or exec_ctx.collected_responses: - self._execution_contexts[execution_id] = exec_ctx - for rid in exec_ctx.pending_requests: - self._request_to_execution[rid] = execution_id - else: - pending_ids = state.get("pending_request_ids", []) - if isinstance(pending_ids, list): - pending = [rid for rid in pending_ids if 
isinstance(rid, str)] - if pending: - try: - expected = int(state.get("expected_response_count", len(pending))) - except (TypeError, ValueError): - expected = len(pending) - - execution_id = str(uuid.uuid4()) - exec_ctx = ExecutionContext( - execution_id=execution_id, - collected_responses={}, - expected_response_count=expected, - pending_requests={rid: None for rid in pending}, - ) - self._execution_contexts[execution_id] = exec_ctx - for rid in pending: - self._request_to_execution[rid] = execution_id - + # Validate the state contains the right keys + if "execution_contexts" not in state: + raise KeyError("Missing 'execution_contexts' in WorkflowExecutor state.") + if "request_to_execution" not in state: + raise KeyError("Missing 'request_to_execution' in WorkflowExecutor state.") + + # Validate the execution contexts stored in the state have the right keys and values + execution_contexts: dict[str, ExecutionContext] | None = None try: - self._active_executions = int(state.get("active_executions", len(self._execution_contexts))) - except (TypeError, ValueError): - self._active_executions = len(self._execution_contexts) - - helper_states = state.get("request_info_executor_states", {}) - restored_request_data: dict[str, RequestInfoMessage] = {} - if isinstance(helper_states, Mapping): - for exec_id, helper_state in helper_states.items(): - helper_executor = self.workflow.executors.get(exec_id) - if not isinstance(helper_executor, RequestInfoExecutor) or not isinstance(helper_state, Mapping): - continue - with contextlib.suppress(Exception): - helper_executor.restore_state(dict(helper_state)) - for req_id, event in getattr(helper_executor, "_request_events", {}).items(): # type: ignore[attr-defined] - if ( - isinstance(req_id, str) - and isinstance(event, RequestInfoEvent) - and isinstance(event.data, RequestInfoMessage) - ): - restored_request_data[req_id] = event.data - - if restored_request_data: - for req_id, data in restored_request_data.items(): - execution_id 
= self._request_to_execution.get(req_id) - if execution_id and execution_id in self._execution_contexts: - self._execution_contexts[execution_id].pending_requests[req_id] = data - - for execution_id, exec_ctx in self._execution_contexts.items(): - for req_id in exec_ctx.pending_requests: - self._request_to_execution.setdefault(req_id, execution_id) - - request_map = state.get("request_to_execution") - if isinstance(request_map, Mapping): - for req_id, execution_id in request_map.items(): - if ( - isinstance(req_id, str) - and isinstance(execution_id, str) - and execution_id in self._execution_contexts - ): - self._request_to_execution.setdefault(req_id, execution_id) - - self._state_loaded = True + execution_contexts = { + key: decode_checkpoint_value(value) for key, value in state["execution_contexts"].items() + } + except Exception as ex: + raise RuntimeError("Failed to deserialize execution context.") from ex + + if not all( + isinstance(key, str) and isinstance(value, ExecutionContext) for key, value in execution_contexts.items() + ): + raise ValueError("Execution contexts must have 'str' as key and 'ExecutionContext' as value.") + if not all(key == value.execution_id for key, value in execution_contexts.items()): + raise ValueError("Execution contexts must have matching keys and IDs.") + + # Validate the request_to_execution map contain the right data + request_to_execution = state["request_to_execution"] + if not all(isinstance(key, str) and isinstance(value, str) for key, value in request_to_execution.items()): + raise ValueError("Request to execution map must have 'str' as key and 'str' as value.") + if not all(value in execution_contexts for value in request_to_execution.values()): + raise ValueError( + "'request_to_execution` contains unknown execution ID that is not part of the execution contexts." 
+ ) - def _build_state_snapshot(self) -> dict[str, Any]: - executions: dict[str, Any] = {} - pending_request_ids: list[str] = [] + self._execution_contexts = execution_contexts + self._request_to_execution = request_to_execution + + # Add the `request_info_event`s back to the sub workflow. + # This is only a temporary solution to rehydrate the sub workflow with the requests. + # The proper way would be to rehydrate the workflow from a checkpoint on a Workflow + # API instead of the '_runner_context' object that should be hidden. And the sub workflow + # should be rehydrated from a checkpoint object instead of from a subset of the state. + # TODO(@taochen): how to handle the case when the parent workflow has checkpointing + # set up but not the sub workflow? + request_info_events = [ + request_info_event + for execution_context in self._execution_contexts.values() + for request_info_event in execution_context.pending_requests.values() + ] + await asyncio.gather(*[ + self.workflow._runner_context.add_request_info_event(event) for event in request_info_events + ]) - for execution_id, exec_ctx in self._execution_contexts.items(): - if not exec_ctx.pending_requests and not exec_ctx.collected_responses: - continue + self._state_loaded = True - request_ids = list(exec_ctx.pending_requests.keys()) - pending_request_ids.extend(request_ids) + async def _persist_execution_state(self, ctx: WorkflowContext) -> None: + """Persist the state of the WorkflowExecutor for checkpointing purposes.""" + state = { + "execution_contexts": { + execution_id: encode_checkpoint_value(execution_context) + for execution_id, execution_context in self._execution_contexts.items() + }, + "request_to_execution": dict(self._request_to_execution), + } - summary: dict[str, Any] = { - "pending_request_ids": request_ids, - "expected_response_count": exec_ctx.expected_response_count, - } + try: + await ctx.set_executor_state(state) + except Exception as exc: # pragma: no cover - transport specific + 
logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}") - if exec_ctx.collected_responses: - summary["collected_response_ids"] = list(exec_ctx.collected_responses.keys()) + async def _process_workflow_result( + self, + result: WorkflowRunResult, + execution_context: ExecutionContext, + ctx: WorkflowContext[Any], + ) -> None: + """Process the result from a workflow execution. - executions[execution_id] = summary + This method handles the common logic for processing outputs, request info events, + and final states that is shared between process_workflow and handle_response. - helper_states: dict[str, Any] = {} - for exec_id, executor in self.workflow.executors.items(): - if isinstance(executor, RequestInfoExecutor): - with contextlib.suppress(Exception): - snapshot = executor.snapshot_state() - if snapshot: - helper_states[exec_id] = snapshot + Args: + result: The workflow execution result. + execution_context: The execution context for this sub-workflow run. + ctx: The workflow context. + """ + # Collect all events from the workflow + request_info_events = result.get_request_info_events() + outputs = result.get_outputs() + workflow_run_state = result.get_final_state() + logger.debug( + f"WorkflowExecutor {self.id} processing workflow result with " + f"{len(outputs)} outputs and {len(request_info_events)} request info events. 
" + f"Workflow run state: {workflow_run_state}" + ) - has_state = bool(executions or helper_states or self._request_to_execution) - if not has_state: - return {} + # Process outputs + for output in outputs: + # TODO(@taochen): Allow the sub-workflow to output directly + await ctx.send_message(output) - state: dict[str, Any] = { - "executions": executions, - "request_to_execution": dict(self._request_to_execution), - "pending_request_ids": pending_request_ids, - "active_executions": self._active_executions, - } + # Process request info events + for event in request_info_events: + # Track the pending request in execution context + execution_context.pending_requests[event.request_id] = event + # Map request to execution for response routing + self._request_to_execution[event.request_id] = execution_context.execution_id + # TODO(@taochen): There should be two ways a sub-workflow can make a request: + # 1. In a workflow where the parent workflow has an executor that may intercept the + # request and handle it directly, a message should be sent. + # 2. In a workflow where the parent workflow does not handle the request, the request + # should be propagated via the `request_info` mechanism to an external source. And + # a @response_handler would be required in the WorkflowExecutor to handle the response. + await ctx.send_message(SubWorkflowRequestMessage(source_event=event, executor_id=self.id)) - if helper_states: - state["request_info_executor_states"] = helper_states + # Update expected response count for this execution + execution_context.expected_response_count = len(request_info_events) - return state + # Handle final state + if workflow_run_state == WorkflowRunState.FAILED: + # Find the WorkflowFailedEvent. 
+ failed_events = [e for e in result if isinstance(e, WorkflowFailedEvent)] + if failed_events: + failed_event = failed_events[0] + error_type = failed_event.details.error_type + error_message = failed_event.details.message + exception = Exception( + f"Sub-workflow {self.workflow.id} failed with error: {error_type} - {error_message}" + ) + error_event = WorkflowErrorEvent( + data=exception, + ) + await ctx.add_event(error_event) + elif workflow_run_state == WorkflowRunState.IDLE: + # Sub-workflow is idle - nothing more to do now + logger.debug( + f"Sub-workflow {self.workflow.id} is idle with {len(self._execution_contexts)} active executions" + ) + elif workflow_run_state == WorkflowRunState.CANCELLED: + # Sub-workflow was cancelled - treat as completion + logger.debug( + f"Sub-workflow {self.workflow.id} was cancelled with {len(self._execution_contexts)} active executions" + ) + elif workflow_run_state == WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS: + # Sub-workflow is still running with pending requests + logger.debug( + f"Sub-workflow {self.workflow.id} is still in progress with {len(request_info_events)} " + f"pending requests with {len(self._execution_contexts)} active executions" + ) + elif workflow_run_state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: + # Sub-workflow is idle but has pending requests + logger.debug( + f"Sub-workflow {self.workflow.id} is idle with pending requests: " + f"{len(request_info_events)} with {len(self._execution_contexts)} active executions" + ) + else: + raise RuntimeError(f"Unexpected workflow run state: {workflow_run_state}") - async def _persist_execution_state(self, ctx: WorkflowContext[Any]) -> None: - snapshot = self._build_state_snapshot() - try: - await ctx.set_executor_state(snapshot) - except Exception as exc: # pragma: no cover - transport specific - logger.warning(f"WorkflowExecutor {self.id} failed to persist state: {exc}") + await self._persist_execution_state(ctx) diff --git 
a/python/packages/core/tests/workflow/test_validation.py b/python/packages/core/tests/workflow/test_validation.py index 7c1a687bd4..37e7f18ed5 100644 --- a/python/packages/core/tests/workflow/test_validation.py +++ b/python/packages/core/tests/workflow/test_validation.py @@ -8,7 +8,6 @@ from agent_framework import ( EdgeDuplicationError, Executor, - ExecutorDuplicationError, GraphConnectivityError, TypeCompatibilityError, ValidationTypeEnum, @@ -83,11 +82,10 @@ def test_duplicate_executor_ids_fail_validation(): executor1 = StringExecutor(id="dup") executor2 = IntExecutor(id="dup") - with pytest.raises(ExecutorDuplicationError) as exc_info: + with pytest.raises(ValueError) as exc_info: (WorkflowBuilder().add_edge(executor1, executor2).set_start_executor(executor1).build()) - assert exc_info.value.executor_id == "dup" - assert exc_info.value.validation_type == ValidationTypeEnum.EXECUTOR_DUPLICATION + assert str(exc_info.value) == "Duplicate executor ID 'dup' detected in workflow." def test_edge_duplication_validation_fails(): diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py index 1591c6f049..21e385e3c6 100644 --- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py +++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py @@ -3,6 +3,7 @@ import asyncio import contextlib import json +import uuid from dataclasses import dataclass, field, replace from datetime import datetime, timedelta from pathlib import Path @@ -11,9 +12,8 @@ Executor, FileCheckpointStorage, RequestInfoEvent, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, + SubWorkflowRequestMessage, + SubWorkflowResponseMessage, Workflow, WorkflowBuilder, WorkflowContext, @@ -22,6 +22,7 @@ WorkflowRunState, WorkflowStatusEvent, handler, + response_handler, ) CHECKPOINT_DIR = Path(__file__).with_suffix("").parent / "tmp" / 
"sub_workflow_checkpoints" @@ -78,9 +79,10 @@ class FinalDraft: @dataclass -class ReviewRequest(RequestInfoMessage): +class ReviewRequest: """Human approval request surfaced via RequestInfoExecutor.""" + id: str = str(uuid.uuid4()) topic: str = "" iteration: int = 1 draft_excerpt: str = "" @@ -88,6 +90,14 @@ class ReviewRequest(RequestInfoMessage): reviewer_guidance: list[str] = field(default_factory=list) # type: ignore +@dataclass +class ReviewDecision: + """The review decision to be sent to downstream executors along with the original request.""" + + decision: str + original_request: ReviewRequest + + # --------------------------------------------------------------------------- # Sub-workflow executors # --------------------------------------------------------------------------- @@ -122,7 +132,8 @@ def __init__(self) -> None: super().__init__(id="draft_review") @handler - async def request_review(self, draft: DraftPackage, ctx: WorkflowContext[ReviewRequest]) -> None: + async def request_review(self, draft: DraftPackage, ctx: WorkflowContext) -> None: + """Request a review upon receiving a draft.""" excerpt = draft.content.splitlines()[0] request = ReviewRequest( topic=draft.topic, @@ -134,15 +145,17 @@ async def request_review(self, draft: DraftPackage, ctx: WorkflowContext[ReviewR "Confirm CTA is action-oriented", ], ) - await ctx.send_message(request, target_id="sub_review_requests") + await ctx.request_info(request, ReviewRequest, str) - @handler + @response_handler async def forward_decision( self, - decision: RequestResponse[ReviewRequest, str], - ctx: WorkflowContext[RequestResponse[ReviewRequest, str]], + original_request: ReviewRequest, + decision: str, + ctx: WorkflowContext[ReviewDecision], ) -> None: - await ctx.send_message(decision, target_id="draft_finaliser") + """Route the decision to the next executor.""" + await ctx.send_message(ReviewDecision(decision=decision, original_request=original_request)) class DraftFinaliser(Executor): @@ -154,11 
+167,11 @@ def __init__(self) -> None: @handler async def on_review_decision( self, - decision: RequestResponse[ReviewRequest, str], + review_decision: ReviewDecision, ctx: WorkflowContext[DraftTask, FinalDraft], ) -> None: - reply = (decision.data or "").strip().lower() - original = decision.original_request + reply = review_decision.decision.strip().lower() + original = review_decision.original_request topic = original.topic if original else "unknown topic" iteration = original.iteration if original else 1 @@ -192,12 +205,11 @@ class LaunchCoordinator(Executor): def __init__(self) -> None: super().__init__(id="launch_coordinator") - self._final: FinalDraft | None = None @handler async def kick_off(self, topic: str, ctx: WorkflowContext[DraftTask]) -> None: task = DraftTask(topic=topic, due=_utc_now() + timedelta(hours=2)) - await ctx.send_message(task, target_id="launch_subworkflow") + await ctx.send_message(task) @handler async def collect_final(self, draft: FinalDraft, ctx: WorkflowContext[None, FinalDraft]) -> None: @@ -209,8 +221,6 @@ async def collect_final(self, draft: FinalDraft, ctx: WorkflowContext[None, Fina normalised = replace(draft, approved_at=parsed) approved_at = parsed - self._final = normalised - approved_display = approved_at.isoformat() if hasattr(approved_at, "isoformat") else str(approved_at) print("\n>>> Parent workflow received approved draft:") @@ -221,9 +231,50 @@ async def collect_final(self, draft: FinalDraft, ctx: WorkflowContext[None, Fina await ctx.yield_output(normalised) - @property - def final_result(self) -> FinalDraft | None: - return self._final + @handler + async def handler_sub_workflow_request( + self, + request: SubWorkflowRequestMessage, + ctx: WorkflowContext, + ) -> None: + """Handle requests from the sub-workflow. + + Note that the message type must be SubWorkflowRequestMessage to intercept the request. 
+ """ + if not isinstance(request.source_event.data, ReviewRequest): + raise TypeError(f"Expected 'ReviewRequest', got {type(request.source_event.data)}") + + # Record the request to response matching + review_request = request.source_event.data + executor_state = await ctx.get_executor_state() or {} + executor_state[review_request.id] = request + await ctx.set_executor_state(executor_state) + + # Send the request without modification + await ctx.request_info(review_request, ReviewRequest, str) + + @response_handler + async def handle_request_response( + self, + original_request: ReviewRequest, + response: str, + ctx: WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + """Process the response and send it back to the sub-workflow. + + Note that the response must be sent back using SubWorkflowResponseMessage to route + the response back to the sub-workflow. + """ + executor_state = await ctx.get_executor_state() or {} + request_message = executor_state.pop(original_request.id, None) + + # Save the executor state back to the context + await ctx.set_executor_state(executor_state) + + if request_message is None: + raise ValueError("No matching pending request found for the resource response") + + await ctx.send_message(request_message.create_response(response)) # --------------------------------------------------------------------------- @@ -234,17 +285,13 @@ def final_result(self) -> FinalDraft | None: def build_sub_workflow() -> WorkflowExecutor: writer = DraftWriter() router = DraftReviewRouter() - request_info = RequestInfoExecutor(id="sub_review_requests") finaliser = DraftFinaliser() sub_workflow = ( WorkflowBuilder() .set_start_executor(writer) .add_edge(writer, router) - .add_edge(router, request_info) - .add_edge(request_info, router, condition=lambda msg: isinstance(msg, RequestResponse)) - .add_edge(router, finaliser, condition=lambda msg: isinstance(msg, RequestResponse)) - .add_edge(request_info, finaliser) + .add_edge(router, finaliser) 
.add_edge(finaliser, writer) # permits revision loops .build() ) @@ -252,28 +299,19 @@ def build_sub_workflow() -> WorkflowExecutor: return WorkflowExecutor(sub_workflow, id="launch_subworkflow") -def build_parent_workflow(storage: FileCheckpointStorage) -> tuple[LaunchCoordinator, Workflow]: +def build_parent_workflow(storage: FileCheckpointStorage) -> Workflow: coordinator = LaunchCoordinator() sub_executor = build_sub_workflow() - parent_request_info = RequestInfoExecutor(id="parent_review_gateway") - workflow = ( + return ( WorkflowBuilder() .set_start_executor(coordinator) .add_edge(coordinator, sub_executor) - .add_edge(sub_executor, coordinator, condition=lambda msg: isinstance(msg, FinalDraft)) - .add_edge( - sub_executor, - parent_request_info, - condition=lambda msg: isinstance(msg, RequestInfoMessage), - ) - .add_edge(parent_request_info, sub_executor) + .add_edge(sub_executor, coordinator) .with_checkpointing(storage) .build() ) - return coordinator, workflow - async def main() -> None: CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True) @@ -282,9 +320,10 @@ async def main() -> None: storage = FileCheckpointStorage(CHECKPOINT_DIR) - _, workflow = build_parent_workflow(storage) + workflow = build_parent_workflow(storage) print("\n=== Stage 1: run until sub-workflow requests human review ===") + request_id: str | None = None async for event in workflow.run_stream("Contoso Gadget Launch"): if isinstance(event, RequestInfoEvent) and request_id is None: @@ -294,52 +333,52 @@ async def main() -> None: break if request_id is None: - print("Sub-workflow completed without requesting review.") - return + raise RuntimeError("Sub-workflow completed without requesting review.") checkpoints = await storage.list_checkpoints(workflow.id) if not checkpoints: - print("No checkpoints written.") - return + raise RuntimeError("No checkpoints found.") + # Print the checkpoint to show pending requests + # We didn't handle the request above so the request is still pending the 
last checkpoint checkpoints.sort(key=lambda cp: cp.timestamp) resume_checkpoint = checkpoints[-1] print(f"Using checkpoint {resume_checkpoint.checkpoint_id} at iteration {resume_checkpoint.iteration_count}") checkpoint_path = storage.storage_path / f"{resume_checkpoint.checkpoint_id}.json" if checkpoint_path.exists(): - snapshot = json.loads(checkpoint_path.read_text()) - exec_states = snapshot.get("executor_states", {}) - sub_pending = exec_states.get("sub_review_requests", {}).get("request_events", {}) - parent_pending = exec_states.get("parent_review_gateway", {}).get("request_events", {}) - print(f"Pending review requests (sub executor snapshot): {list(sub_pending.keys())}") - print(f"Pending review requests (parent executor snapshot): {list(parent_pending.keys())}") - - print("\n=== Stage 2: resume from checkpoint and approve draft ===") + checkpoint_content_dict = json.loads(checkpoint_path.read_text()) + print(f"Pending review requests: {checkpoint_content_dict.get('pending_request_info_events', {})}") + + print("\n=== Stage 2: resume from checkpoint ===") + # Rebuild fresh instances to mimic a separate process resuming - coordinator2, workflow2 = build_parent_workflow(storage) + workflow2 = build_parent_workflow(storage) - approval_response = "approve" - final_event: WorkflowOutputEvent | None = None + request_info_event: RequestInfoEvent | None = None async for event in workflow2.run_stream_from_checkpoint( resume_checkpoint.checkpoint_id, - responses={request_id: approval_response}, ): + if isinstance(event, RequestInfoEvent): + request_info_event = event + + if request_info_event is None: + raise RuntimeError("No request_info_event captured.") + + print("\n=== Stage 3: approve draft ===") + + approval_response = "approve" + output_event: WorkflowOutputEvent | None = None + async for event in workflow2.send_responses_streaming({request_info_event.request_id: approval_response}): + if isinstance(event, WorkflowOutputEvent): - final_event = event + 
output_event = event - if final_event is None: - print("Workflow did not complete after resume.") - return + if output_event is None: + raise RuntimeError("Workflow did not complete after resume.") - final = final_event.data + output = output_event.data print("\n=== Final Draft (from resumed run) ===") - print(final) - - if coordinator2.final_result is None: - print("Coordinator did not capture final result via handler.") - else: - print("Coordinator stored final draft successfully.") + print(output) """" Sample Output: From 4939e46a79ddffe69dca5a803e1b5656d4b4c4b5 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 21 Oct 2025 13:12:34 -0700 Subject: [PATCH 07/26] Fix function executor --- .../_workflows/_function_executor.py | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_function_executor.py b/python/packages/core/agent_framework/_workflows/_function_executor.py index eaeaf15f63..afe6ccb02f 100644 --- a/python/packages/core/agent_framework/_workflows/_function_executor.py +++ b/python/packages/core/agent_framework/_workflows/_function_executor.py @@ -39,9 +39,12 @@ def __init__(self, func: Callable[..., Any], id: str | None = None): # Validate function signature and extract types message_type, ctx_annotation, output_types, workflow_output_types = _validate_function_signature(func) + # Store the original function + self._original_func = func # Determine if function has WorkflowContext parameter - has_context = ctx_annotation is not None - is_async = asyncio.iscoroutinefunction(func) + self._has_context = ctx_annotation is not None + # Determine if the function is an async function + self._is_async = asyncio.iscoroutinefunction(func) # Initialize parent WITHOUT calling _discover_handlers yet # We'll manually set up the attributes first @@ -49,25 +52,18 @@ def __init__(self, func: Callable[..., Any], id: str | None = None): kwargs = {"type": "FunctionExecutor"} 
super().__init__(id=executor_id, defer_discovery=True, **kwargs) - self._handlers = {} - self._handler_specs = [] - - # Store the original function and whether it has context - self._original_func = func - self._has_context = has_context - self._is_async = is_async # Create a wrapper function that always accepts both message and context - if has_context and is_async: + if self._has_context and self._is_async: # Async function with context - already has the right signature wrapped_func: Callable[[Any, WorkflowContext[Any]], Awaitable[Any]] = func # type: ignore - elif has_context and not is_async: + elif self._has_context and not self._is_async: # Sync function with context - wrap to make async using thread pool async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: # Call the sync function with both parameters in a thread return await asyncio.to_thread(func, message, ctx) # type: ignore - elif not has_context and is_async: + elif not self._has_context and self._is_async: # Async function without context - wrap to ignore context async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: # Call the async function with just the message @@ -91,6 +87,7 @@ async def wrapped_func(message: Any, ctx: WorkflowContext[Any]) -> Any: # Now we can safely call _discover_handlers (it won't find any class-level handlers) self._discover_handlers() + self._discover_response_handlers() if not self._handlers: raise ValueError( From 2767ee5f1a7c9873855cbd3443dd974cc1a1331e Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 22 Oct 2025 08:54:43 -0700 Subject: [PATCH 08/26] Allow sub-workflow to output directly --- .../_workflows/_workflow_executor.py | 21 +++++++++++++------ .../sub_workflow_parallel_requests.py | 13 +++++++++++- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 
9bf69d93bb..1c8721c6c0 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -109,7 +109,7 @@ class WorkflowExecutor(Executor): 1. Starts the wrapped workflow with the input message 2. Runs the sub-workflow to completion or until it needs external input 3. Processes the sub-workflow's complete event stream after execution - 4. Forwards outputs to the parent workflow's event stream + 4. Forwards outputs to the parent workflow as messages 5. Handles external requests by routing them to the parent workflow 6. Accumulates responses and resumes sub-workflow execution @@ -259,18 +259,25 @@ async def handle_request( - Concurrent executions are fully isolated and do not interfere with each other """ - def __init__(self, workflow: "Workflow", id: str, **kwargs: Any): + def __init__(self, workflow: "Workflow", id: str, allow_direct_output: bool = False, **kwargs: Any): """Initialize the WorkflowExecutor. Args: workflow: The workflow to execute as a sub-workflow. id: Unique identifier for this executor. + allow_direct_output: Whether to allow direct output from the sub-workflow. + By default, outputs from the sub-workflow are sent to + other executors in the parent workflow as messages. + When this is set to true, the outputs are yielded + directly from the WorkflowExecutor to the parent + workflow's event stream. Keyword Args: **kwargs: Additional keyword arguments passed to the parent constructor. 
""" super().__init__(id, **kwargs) self.workflow = workflow + self.allow_direct_output = allow_direct_output # Track execution contexts for concurrent sub-workflow executions self._execution_contexts: dict[str, ExecutionContext] = {} # execution_id -> ExecutionContext @@ -496,7 +503,7 @@ async def restore_state(self, state: dict[str, Any]) -> None: # The proper way would be to rehydrate the workflow from a checkpoint on a Workflow # API instead of the '_runner_context' object that should be hidden. And the sub workflow # should be rehydrated from a checkpoint object instead of from a subset of the state. - # TODO(@taochen): how to handle the case when the parent workflow has checkpointing + # TODO(@taochen#1614): how to handle the case when the parent workflow has checkpointing # set up but not the sub workflow? request_info_events = [ request_info_event @@ -551,9 +558,11 @@ async def _process_workflow_result( ) # Process outputs - for output in outputs: - # TODO(@taochen): Allow the sub-workflow to output directly - await ctx.send_message(output) + if self.allow_direct_output: + # Note that the executor is allowed to continue its own execution after yielding outputs. 
+ await asyncio.gather(*[ctx.yield_output(output) for output in outputs]) + else: + await asyncio.gather(*[ctx.send_message(output) for output in outputs]) # Process request info events for event in request_info_events: diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py index fc58f06315..b22e452d5b 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py @@ -292,8 +292,17 @@ async def main() -> None: resource_allocator = ResourceAllocator("resource_allocator") policy_engine = PolicyEngine("policy_engine") + # Create the WorkflowExecutor for the sub-workflow + # Setting allow_direct_output=True to let the sub-workflow output directly. + # This is because the sub-workflow is both the entry point and the exit
+ sub_workflow_executor = WorkflowExecutor( + sub_workflow, + "sub_workflow_executor", + allow_direct_output=True, + ) + # Build the main workflow - sub_workflow_executor = WorkflowExecutor(sub_workflow, "sub_workflow_executor") main_workflow = ( WorkflowBuilder() .set_start_executor(sub_workflow_executor) @@ -348,6 +357,8 @@ async def main() -> None: for output in outputs: # TODO(@taochen): Allow the sub-workflow to output directly print(f"- {output}") + else: + raise RuntimeError("Workflow did not produce an output.") if __name__ == "__main__": From bd82c02f74f6137d4cca8652363317121e800635 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Wed, 22 Oct 2025 13:46:33 -0700 Subject: [PATCH 09/26] Remove ReqeustInfoExecutor and related classes; Debugging checkpoint_with_human_in_the_loop --- .../agent_framework/_workflows/__init__.py | 10 - .../agent_framework/_workflows/__init__.pyi | 10 - .../_workflows/_checkpoint_summary.py | 161 +---- .../agent_framework/_workflows/_events.py | 4 +- .../agent_framework/_workflows/_executor.py | 14 +- .../agent_framework/_workflows/_magentic.py | 65 +- .../_workflows/_request_info_executor.py | 573 ------------------ .../agent_framework/_workflows/_runner.py | 35 +- .../_workflows/_runner_context.py | 3 - .../_workflows/_typing_utils.py | 31 +- .../agent_framework/_workflows/_validation.py | 6 +- .../agent_framework/_workflows/_workflow.py | 26 +- .../_workflows/_workflow_executor.py | 48 +- .../getting_started/workflows/README.md | 4 +- ...re_chat_agents_tool_calls_with_feedback.py | 186 +++--- .../workflow_as_agent_human_in_the_loop.py | 50 +- .../checkpoint_with_human_in_the_loop.py | 27 +- .../checkpoint/sub_workflow_checkpoint.py | 4 +- .../orchestration/magentic_checkpoint.py | 41 +- 19 files changed, 248 insertions(+), 1050 deletions(-) delete mode 100644 python/packages/core/agent_framework/_workflows/_request_info_executor.py diff --git a/python/packages/core/agent_framework/_workflows/__init__.py 
b/python/packages/core/agent_framework/_workflows/__init__.py index ad379d8aa8..95b00f61b2 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.py +++ b/python/packages/core/agent_framework/_workflows/__init__.py @@ -74,12 +74,6 @@ MagenticStartMessage, StandardMagenticManager, ) -from ._request_info_executor import ( - PendingRequestDetails, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, -) from ._request_info_mixin import response_handler from ._runner import Runner from ._runner_context import ( @@ -148,11 +142,7 @@ "MagenticResponseMessage", "MagenticStartMessage", "Message", - "PendingRequestDetails", "RequestInfoEvent", - "RequestInfoExecutor", - "RequestInfoMessage", - "RequestResponse", "Runner", "RunnerContext", "SequentialBuilder", diff --git a/python/packages/core/agent_framework/_workflows/__init__.pyi b/python/packages/core/agent_framework/_workflows/__init__.pyi index 0599325cbd..a449b72213 100644 --- a/python/packages/core/agent_framework/_workflows/__init__.pyi +++ b/python/packages/core/agent_framework/_workflows/__init__.pyi @@ -72,12 +72,6 @@ from ._magentic import ( MagenticStartMessage, StandardMagenticManager, ) -from ._request_info_executor import ( - PendingRequestDetails, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, -) from ._request_info_mixin import response_handler from ._runner import Runner from ._runner_context import ( @@ -146,11 +140,7 @@ __all__ = [ "MagenticResponseMessage", "MagenticStartMessage", "Message", - "PendingRequestDetails", "RequestInfoEvent", - "RequestInfoExecutor", - "RequestInfoMessage", - "RequestResponse", "Runner", "RunnerContext", "SequentialBuilder", diff --git a/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py b/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py index 539175d308..44f43a5d02 100644 --- a/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py +++ 
b/python/packages/core/agent_framework/_workflows/_checkpoint_summary.py @@ -1,14 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. import logging -from collections.abc import Iterable, Mapping from dataclasses import dataclass -from textwrap import shorten -from typing import Any from ._checkpoint import WorkflowCheckpoint -from ._checkpoint_encoding import decode_checkpoint_value -from ._request_info_executor import PendingRequestDetails, RequestInfoMessage, RequestResponse +from ._events import RequestInfoEvent logger = logging.getLogger(__name__) @@ -23,35 +19,23 @@ class WorkflowCheckpointSummary: targets: list[str] executor_ids: list[str] status: str - draft_preview: str | None - pending_requests: list[PendingRequestDetails] + pending_request_info_events: list[RequestInfoEvent] -def get_checkpoint_summary( - checkpoint: WorkflowCheckpoint, - *, - request_executor_ids: Iterable[str] | None = None, - preview_width: int = 70, -) -> WorkflowCheckpointSummary: +def get_checkpoint_summary(checkpoint: WorkflowCheckpoint) -> WorkflowCheckpointSummary: targets = sorted(checkpoint.messages.keys()) executor_ids = sorted(checkpoint.executor_states.keys()) - pending = _pending_requests_from_checkpoint(checkpoint, request_executor_ids=request_executor_ids) - - draft_preview: str | None = None - for entry in pending: - if entry.draft: - draft_preview = shorten(entry.draft, width=preview_width, placeholder="…") - break + pending_request_info_events = [ + RequestInfoEvent.from_dict(request) for request in checkpoint.pending_request_info_events.values() + ] status = "idle" - if pending: + if pending_request_info_events: status = "awaiting request response" elif not checkpoint.messages and "finalise" in executor_ids: status = "completed" elif checkpoint.messages: status = "awaiting next superstep" - elif request_executor_ids is not None and any(tid in targets for tid in request_executor_ids): - status = "awaiting request delivery" return WorkflowCheckpointSummary( 
checkpoint_id=checkpoint.checkpoint_id, @@ -60,134 +44,5 @@ def get_checkpoint_summary( targets=targets, executor_ids=executor_ids, status=status, - draft_preview=draft_preview, - pending_requests=pending, + pending_request_info_events=pending_request_info_events, ) - - -def _pending_requests_from_checkpoint( - checkpoint: WorkflowCheckpoint, - *, - request_executor_ids: Iterable[str] | None = None, -) -> list[PendingRequestDetails]: - executor_filter: set[str] | None = None - if request_executor_ids is not None: - executor_filter = {str(value) for value in request_executor_ids} - - pending: dict[str, PendingRequestDetails] = {} - - for state in checkpoint.executor_states.values(): - if not isinstance(state, Mapping): - continue - inner = state.get("pending_requests") - if isinstance(inner, Mapping): - for request_id, snapshot in inner.items(): # type: ignore[attr-defined] - _merge_snapshot(pending, str(request_id), snapshot) # type: ignore[arg-type] - - for source_id, message_list in checkpoint.messages.items(): - if executor_filter is not None and source_id not in executor_filter: - continue - if not isinstance(message_list, list): - continue - for message in message_list: - if not isinstance(message, Mapping): - continue - payload = decode_checkpoint_value(message.get("data")) - _merge_message_payload(pending, payload, message) - - return list(pending.values()) - - -def _merge_snapshot(pending: dict[str, PendingRequestDetails], request_id: str, snapshot: Any) -> None: - if not request_id or not isinstance(snapshot, Mapping): - return - - details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) - - _apply_update( - details, - prompt=snapshot.get("prompt"), # type: ignore[attr-defined] - draft=snapshot.get("draft"), # type: ignore[attr-defined] - iteration=snapshot.get("iteration"), # type: ignore[attr-defined] - source_executor_id=snapshot.get("source_executor_id"), # type: ignore[attr-defined] - ) - - extra = snapshot.get("details") 
# type: ignore[attr-defined] - if isinstance(extra, Mapping): - _apply_update( - details, - prompt=extra.get("prompt"), # type: ignore[attr-defined] - draft=extra.get("draft"), # type: ignore[attr-defined] - iteration=extra.get("iteration"), # type: ignore[attr-defined] - ) - - -def _merge_message_payload( - pending: dict[str, PendingRequestDetails], - payload: Any, - raw_message: Mapping[str, Any], -) -> None: - if isinstance(payload, RequestResponse): - request_id = payload.request_id or _get_field(payload.original_request, "request_id") # type: ignore[arg-type] - if not request_id: - return - details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) - _apply_update( - details, - prompt=_get_field(payload.original_request, "prompt"), # type: ignore[arg-type] - draft=_get_field(payload.original_request, "draft"), # type: ignore[arg-type] - iteration=_get_field(payload.original_request, "iteration"), # type: ignore[arg-type] - source_executor_id=raw_message.get("source_id"), - original_request=payload.original_request, # type: ignore[arg-type] - ) - elif isinstance(payload, RequestInfoMessage): - request_id = getattr(payload, "request_id", None) - if not request_id: - return - details = pending.setdefault(request_id, PendingRequestDetails(request_id=request_id)) - _apply_update( - details, - prompt=getattr(payload, "prompt", None), - draft=getattr(payload, "draft", None), - iteration=getattr(payload, "iteration", None), - source_executor_id=raw_message.get("source_id"), - original_request=payload, - ) - - -def _apply_update( - details: PendingRequestDetails, - *, - prompt: Any = None, - draft: Any = None, - iteration: Any = None, - source_executor_id: Any = None, - original_request: Any = None, -) -> None: - if prompt and not details.prompt: - details.prompt = str(prompt) - if draft and not details.draft: - details.draft = str(draft) - if iteration is not None and details.iteration is None: - coerced = _coerce_int(iteration) - if 
coerced is not None: - details.iteration = coerced - if source_executor_id and not details.source_executor_id: - details.source_executor_id = str(source_executor_id) - if original_request is not None and details.original_request is None: - details.original_request = original_request - - -def _get_field(obj: Any, key: str) -> Any: - if obj is None: - return None - if isinstance(obj, Mapping): - return obj.get(key) # type: ignore[attr-defined,return-value] - return getattr(obj, key, None) - - -def _coerce_int(value: Any) -> int | None: - try: - return int(value) - except (TypeError, ValueError): - return None diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 0df8e8001e..57f3f9c3d9 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -14,7 +14,7 @@ from ._typing_utils import deserialize_type, serialize_type if TYPE_CHECKING: - from ._request_info_executor import RequestInfoMessage + pass class WorkflowEventSource(str, Enum): @@ -214,7 +214,7 @@ def __init__( request_id: str, source_executor_id: str, request_type: type, - request_data: "RequestInfoMessage", + request_data: Any, response_type: type, ): """Initialize the request info event. 
diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index e9ca84d53c..635c2dd8bb 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -90,16 +90,16 @@ async def handle_text(self, message: str, ctx: WorkflowContext[str]) -> None: class ParentExecutor(Executor): @handler - async def handle_domain_request( + async def handle_subworkflow_request( self, - request: DomainRequest, # Subclass of RequestInfoMessage - ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any] | DomainRequest], + request: SubWorkflowRequestMessage, + ctx: WorkflowContext[SubWorkflowResponseMessage], ) -> None: if self.is_allowed(request.domain): - response = RequestResponse(data=True, original_request=request, request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) + response = request.create_response(data=True) + await ctx.send_message(response, target_id=request.executor_id) else: - await ctx.send_message(request) # Forward to external + await ctx.request_info(request.source_event) ## Context Types Handler methods receive different WorkflowContext variants based on their type annotations: @@ -331,8 +331,6 @@ def _discover_handlers(self) -> None: message_type = handler_spec["message_type"] # Keep full generic types for handler registration to avoid conflicts - # Different RequestResponse[T, U] specializations are distinct handler types - if self._handlers.get(message_type) is not None: raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}") diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 55a2c29a71..180f8e38e2 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ 
b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -29,7 +29,7 @@ from ._events import WorkflowEvent from ._executor import Executor, handler from ._model_utils import DictConvertible, encode_value -from ._request_info_executor import RequestInfoMessage, RequestResponse +from ._request_info_mixin import response_handler from ._workflow import Workflow, WorkflowRunResult from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext @@ -388,11 +388,10 @@ def from_dict(cls, value: dict[str, Any]) -> "MagenticResponseMessage": @dataclass -class MagenticPlanReviewRequest(RequestInfoMessage): +class MagenticPlanReviewRequest: """Human-in-the-loop request to review and optionally edit the plan before execution.""" - # Because RequestInfoMessage defines a default field (request_id), - # subclass fields must also have defaults to satisfy dataclass rules. + request_id: str = field(default_factory=lambda: str(uuid4())) task_text: str = "" facts_text: str = "" plan_text: str = "" @@ -1163,10 +1162,11 @@ async def handle_response_message( # Continue with inner loop await self._run_inner_loop(context) - @handler + @response_handler async def handle_plan_review_response( self, - response: RequestResponse[MagenticPlanReviewRequest, MagenticPlanReviewReply], + original_request: MagenticPlanReviewRequest, + response: MagenticPlanReviewReply, context: WorkflowContext[ # may broadcast ledger next, or ask for another round of review MagenticResponseMessage | MagenticRequestMessage | MagenticPlanReviewRequest, ChatMessage @@ -1178,26 +1178,21 @@ async def handle_plan_review_response( if self._context is None: return - human = response.data - if human is None: # type: ignore[unreachable] - # Defensive fallback: treat as revise with empty comments - human = MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.REVISE, comments="") - - if human.decision == MagenticPlanReviewDecision.APPROVE: + if response.decision == 
MagenticPlanReviewDecision.APPROVE: # Close the review loop on approval (no further plan review requests this run) self._require_plan_signoff = False # If the user supplied an edited plan, adopt it - if human.edited_plan_text: + if response.edited_plan_text: # Update the manager's internal ledger and rebuild the combined message mgr_ledger = getattr(self._manager, "task_ledger", None) if mgr_ledger is not None: - mgr_ledger.plan.text = human.edited_plan_text + mgr_ledger.plan.text = response.edited_plan_text team_text = _team_block(self._participants) combined = self._manager.task_ledger_full_prompt.format( task=self._context.task.text, team=team_text, facts=(mgr_ledger.facts.text if mgr_ledger else ""), - plan=human.edited_plan_text, + plan=response.edited_plan_text, ) self._task_ledger = ChatMessage( role=Role.ASSISTANT, @@ -1205,10 +1200,10 @@ async def handle_plan_review_response( author_name=MAGENTIC_MANAGER_NAME, ) # If approved with comments but no edited text, apply comments via replan and proceed (no extra review) - elif human.comments: + elif response.comments: # Record the human feedback for grounding self._context.chat_history.append( - ChatMessage(role=Role.USER, text=f"Human plan feedback: {human.comments}") + ChatMessage(role=Role.USER, text=f"Human plan feedback: {response.comments}") ) # Ask the manager to replan based on comments; proceed immediately self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) @@ -1258,26 +1253,26 @@ async def handle_plan_review_response( return # If the user provided an edited plan, adopt it directly and ask them to confirm once more - if human.edited_plan_text: + if response.edited_plan_text: mgr_ledger2 = getattr(self._manager, "task_ledger", None) if mgr_ledger2 is not None: - mgr_ledger2.plan.text = human.edited_plan_text + mgr_ledger2.plan.text = response.edited_plan_text # Rebuild combined message for preview in the next review request team_text = _team_block(self._participants) combined = 
self._manager.task_ledger_full_prompt.format( task=self._context.task.text, team=team_text, facts=(mgr_ledger2.facts.text if mgr_ledger2 else ""), - plan=human.edited_plan_text, + plan=response.edited_plan_text, ) self._task_ledger = ChatMessage(role=Role.ASSISTANT, text=combined, author_name=MAGENTIC_MANAGER_NAME) await self._send_plan_review_request(context) return # Else pass comments into the chat history and replan with the manager - if human.comments: + if response.comments: self._context.chat_history.append( - ChatMessage(role=Role.USER, text=f"Human plan feedback: {human.comments}") + ChatMessage(role=Role.USER, text=f"Human plan feedback: {response.comments}") ) # Ask the manager to replan; this only adjusts the plan stage, not a full reset @@ -1484,13 +1479,8 @@ async def _check_within_limits_or_complete( return True - async def _send_plan_review_request( - self, - context: WorkflowContext[ - MagenticResponseMessage | MagenticRequestMessage | MagenticPlanReviewRequest, ChatMessage - ], - ) -> None: - """Emit a PlanReviewRequest via RequestInfoExecutor.""" + async def _send_plan_review_request(self, context: WorkflowContext) -> None: + """Send a PlanReviewRequest.""" # If plan sign-off is disabled (e.g., ran out of review rounds), do nothing if not self._require_plan_signoff: return @@ -1505,7 +1495,7 @@ async def _send_plan_review_request( plan_text=plan_text, round_index=self._plan_review_round, ) - await context.send_message(req) + await context.request_info(req, MagenticPlanReviewRequest, MagenticPlanReviewReply) class MagenticAgentExecutor(Executor): @@ -1937,26 +1927,11 @@ async def _on_agent_delta(agent_id: str, update: AgentRunResponseUpdate, is_fina # Create workflow builder and set orchestrator as start workflow_builder = WorkflowBuilder().set_start_executor(orchestrator_executor) - if self._enable_plan_review: - from ._request_info_executor import RequestInfoExecutor - - request_info = RequestInfoExecutor(id="magentic_plan_review") - 
workflow_builder = ( - workflow_builder - # Only route plan review asks to request_info - .add_edge( - orchestrator_executor, - request_info, - condition=lambda msg: isinstance(msg, MagenticPlanReviewRequest), - ).add_edge(request_info, orchestrator_executor) - ) - def _route_to_agent(msg: object, *, agent_name: str) -> bool: """Route only messages meant for this agent. - MagenticRequestMessage -> only to the named agent - MagenticResponseMessage -> broadcast=True to all, or target_agent==agent_name - Everything else (e.g., RequestInfoMessage) -> do not route to agents. """ if isinstance(msg, MagenticRequestMessage): return msg.agent_name == agent_name diff --git a/python/packages/core/agent_framework/_workflows/_request_info_executor.py b/python/packages/core/agent_framework/_workflows/_request_info_executor.py deleted file mode 100644 index e979dc71d3..0000000000 --- a/python/packages/core/agent_framework/_workflows/_request_info_executor.py +++ /dev/null @@ -1,573 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -import contextlib -import importlib -import json -import logging -import uuid -from collections.abc import Mapping, Sequence -from dataclasses import asdict, dataclass, field, fields, is_dataclass -from typing import Any, ClassVar, Generic, TypeVar, cast - -from ._events import ( - RequestInfoEvent, # type: ignore[reportPrivateUsage] -) -from ._executor import Executor, handler -from ._workflow_context import WorkflowContext - -logger = logging.getLogger(__name__) - - -@dataclass -class PendingRequestDetails: - """Lightweight information about a pending request captured in a checkpoint.""" - - request_id: str - prompt: str | None = None - draft: str | None = None - iteration: int | None = None - source_executor_id: str | None = None - original_request: "RequestInfoMessage | dict[str, Any] | None" = None - - -@dataclass -class PendingRequestSnapshot: - """Snapshot of a pending request for internal tracking. 
- - This snapshot should be JSON-serializable and contain enough - information to reconstruct the original request if needed. - """ - - request_id: str - source_executor_id: str - request_type: str - request_as_json_safe_dict: dict[str, Any] - - -@dataclass -class RequestInfoMessage: - """Base class for all request messages in workflows. - - Any message that should be routed to the RequestInfoExecutor for external - handling must inherit from this class. This ensures type safety and makes - the request/response pattern explicit. - """ - - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - """Unique identifier for correlating requests and responses.""" - - source_executor_id: str | None = None - """ID of the executor expecting a response to this request. - May differ from the executor that sent the request if intercepted and forwarded.""" - - -TRequest = TypeVar("TRequest", bound="RequestInfoMessage") -TResponse = TypeVar("TResponse") - - -@dataclass -class RequestResponse(Generic[TRequest, TResponse]): - """Response type for request/response correlation in workflows. - - This type is used by RequestInfoExecutor to create correlated responses - that include the original request context for proper message routing. - """ - - data: TResponse - """The response data returned from handling the request.""" - - original_request: TRequest - """The original request that this response corresponds to.""" - - request_id: str - """The ID of the original request.""" - - -# endregion: Request/Response Types - - -# region Request Info Executor -class RequestInfoExecutor(Executor): - """Built-in executor that handles request/response patterns in workflows. - - This executor acts as a gateway for external information requests. When it receives - a request message, it saves the request details and emits a RequestInfoEvent. When - a response is provided externally, it emits the response as a message. 
- """ - - _PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info" - - def __init__(self, id: str): - """Initialize the RequestInfoExecutor with a unique ID. - - Args: - id: Unique ID for this RequestInfoExecutor. - """ - super().__init__(id=id) - self._request_events: dict[str, RequestInfoEvent] = {} - - # region Public Methods - - @handler - async def handle_request(self, message: RequestInfoMessage, ctx: WorkflowContext) -> None: - """Run the RequestInfoExecutor with the given message.""" - # Use source_executor_id from message if available, otherwise fall back to context - source_executor_id = message.source_executor_id or ctx.get_source_executor_id() - - event = RequestInfoEvent( - request_id=message.request_id, - source_executor_id=source_executor_id, - request_type=type(message), - request_data=message, - ) - self._request_events[message.request_id] = event - await self._record_pending_request(message, source_executor_id, ctx) - await ctx.add_event(event) - - async def handle_response( - self, - response_data: Any, - request_id: str, - ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any]], - ) -> None: - """Handle a response to a request. - - Args: - request_id: The ID of the request to which this response corresponds. - response_data: The data returned in the response. - ctx: The workflow context for sending the response. 
- """ - event = self._request_events.get(request_id) - if event is None: - event = await self._rehydrate_request_event(request_id, cast(WorkflowContext, ctx)) - if event is None: - raise ValueError(f"No request found with ID: {request_id}") - - self._request_events.pop(request_id, None) - - # Create a correlated response that includes both the response data and original request - if not isinstance(event.data, RequestInfoMessage): - raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") - correlated_response = RequestResponse(data=response_data, original_request=event.data, request_id=request_id) - await ctx.send_message(correlated_response, target_id=event.source_executor_id) - - await self._erase_pending_request(request_id, cast(WorkflowContext, ctx)) - - def snapshot_state(self) -> dict[str, Any]: - """Serialize pending requests so checkpoint restoration can resume seamlessly.""" - - def _encode_event(event: RequestInfoEvent) -> dict[str, Any] | None: - if event.data is None or not isinstance(event.data, RequestInfoMessage): - logger.warning( - f"RequestInfoExecutor {self.id} encountered invalid event data for request ID {event.request_id}: " - f"{type(event.data).__name__}. This request will be skipped in the checkpoint." 
- ) - return None - - payload = self._encode_request_payload(event.data, event.data.__class__) - - return { - "source_executor_id": event.source_executor_id, - "request_type": f"{event.request_type.__module__}:{event.request_type.__qualname__}", - "request_data": payload, - } - - return { - "request_events": { - rid: encoded - for rid, event in self._request_events.items() - if (encoded := _encode_event(event)) is not None - }, - } - - def restore_state(self, state: dict[str, Any]) -> None: - """Restore pending request bookkeeping from checkpoint state.""" - self._request_events.clear() - stored_events = state.get("request_events", {}) - - for request_id, payload in stored_events.items(): - request_type_qual = payload.get("request_type", "") - try: - request_type = _import_qualname(request_type_qual) - except Exception as exc: # pragma: no cover - defensive fallback - logger.debug( - "RequestInfoExecutor %s failed to import %s during restore: %s", - self.id, - request_type_qual, - exc, - ) - request_type = RequestInfoMessage - request_data_meta = payload.get("request_data", {}) - request_data = self._decode_request_data(request_data_meta) - event = RequestInfoEvent( - request_id=request_id, - source_executor_id=payload.get("source_executor_id", ""), - request_type=request_type, - request_data=request_data, - ) - self._request_events[request_id] = event - - async def has_pending_request(self, request_id: str, ctx: WorkflowContext) -> bool: - """Check if there is a pending request with the given ID. - - Args: - request_id: The ID of the request to check. - ctx: The workflow context for accessing state if needed. - - Returns: True if the request is pending, False otherwise. 
- """ - if request_id in self._request_events: - return True - - pending_requests = await self._retrieve_existing_pending_requests(ctx) - return request_id in pending_requests - - # endregion: Public Methods - - # region: Internal Methods - - async def _record_pending_request( - self, - message: RequestInfoMessage, - source_executor_id: str, - ctx: WorkflowContext, - ) -> None: - """Record a pending request to the executor's state for checkpointing purposes.""" - pending_request_snapshot = self._build_pending_request_snapshot(message, source_executor_id) - - existing_pending_requests = await self._retrieve_existing_pending_requests(ctx) - existing_pending_requests[message.request_id] = pending_request_snapshot - - await self._persist_to_executor_state(existing_pending_requests, ctx) - - async def _erase_pending_request(self, request_id: str, ctx: WorkflowContext) -> None: - """Erase a pending request from the executor's state after it has been handled for checkpointing purposes.""" - existing_pending_requests = await self._retrieve_existing_pending_requests(ctx) - if request_id in existing_pending_requests: - existing_pending_requests.pop(request_id) - await self._persist_to_executor_state(existing_pending_requests, ctx) - - async def _retrieve_existing_pending_requests(self, ctx: WorkflowContext) -> dict[str, PendingRequestSnapshot]: - """Retrieve existing pending requests from executor state.""" - executor_state = await ctx.get_executor_state() - if executor_state is None: - return {} - - stored_requests = executor_state.get(self._PENDING_SHARED_STATE_KEY, {}) - if not isinstance(stored_requests, dict): - raise TypeError(f"Unexpected type for pending requests: {type(stored_requests).__name__}") - - # Validate contents - for key, value in stored_requests.items(): # type: ignore - if not isinstance(key, str) or not isinstance(value, PendingRequestSnapshot): - raise TypeError( - "Invalid pending request entry in executor state. 
" - "Key must be `str` and value must be `PendingRequestSnapshot`." - ) - - return stored_requests # type: ignore - - async def _persist_to_executor_state( - self, pending: dict[str, PendingRequestSnapshot], ctx: WorkflowContext - ) -> None: - """Persist the current pending requests to the executor's state.""" - executor_state = await ctx.get_executor_state() or {} - executor_state[self._PENDING_SHARED_STATE_KEY] = pending - await ctx.set_executor_state(executor_state) - - def _build_pending_request_snapshot( - self, request: RequestInfoMessage, source_executor_id: str - ) -> PendingRequestSnapshot: - """Build a snapshot of the pending request for checkpointing.""" - request_as_json_safe_dict = self._convert_request_to_json_safe_dict(request) - - return PendingRequestSnapshot( - request_id=request.request_id, - source_executor_id=source_executor_id, - request_type=f"{type(request).__module__}:{type(request).__name__}", - request_as_json_safe_dict=request_as_json_safe_dict, - ) - - def _encode_request_payload(self, request_data: RequestInfoMessage, data_cls: type[Any]) -> dict[str, Any]: - if is_dataclass(request_data) and not isinstance(request_data, type): - dataclass_instance = cast(Any, request_data) - safe_value = _make_json_safe(asdict(dataclass_instance)) - return { - "kind": "dataclass", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - to_dict_fn = getattr(request_data, "to_dict", None) - if callable(to_dict_fn): - try: - dumped = to_dict_fn() - except TypeError: - dumped = to_dict_fn() - safe_value = _make_json_safe(dumped) - return { - "kind": "dict", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - to_json_fn = getattr(request_data, "to_json", None) - if callable(to_json_fn): - try: - dumped = to_json_fn() - except TypeError: - dumped = to_json_fn() - converted = dumped - if isinstance(dumped, (str, bytes, bytearray)): - decoded: str | bytes | bytearray - if 
isinstance(dumped, (bytes, bytearray)): - try: - decoded = dumped.decode() - except Exception: - decoded = dumped - else: - decoded = dumped - try: - converted = json.loads(decoded) - except Exception: - converted = decoded - safe_value = _make_json_safe(converted) - return { - "kind": "dict" if isinstance(converted, dict) else "json", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": safe_value, - } - - return { - "kind": "raw", - "type": f"{data_cls.__module__}:{data_cls.__qualname__}", - "value": self._convert_request_to_json_safe_dict(request_data), - } - - def _decode_request_data(self, metadata: dict[str, Any]) -> RequestInfoMessage: - kind = metadata.get("kind") - type_name = metadata.get("type", "") - value: Any = metadata.get("value", {}) - if type_name: - try: - imported = _import_qualname(type_name) - except Exception as exc: # pragma: no cover - defensive fallback - logger.debug( - "RequestInfoExecutor %s failed to import %s during decode: %s", - self.id, - type_name, - exc, - ) - imported = RequestInfoMessage - else: - imported = RequestInfoMessage - target_cls: type[RequestInfoMessage] - if isinstance(imported, type) and issubclass(imported, RequestInfoMessage): - target_cls = imported - else: - target_cls = RequestInfoMessage - - if kind == "dataclass" and isinstance(value, dict): - with contextlib.suppress(TypeError): - return target_cls(**value) # type: ignore[arg-type] - - # Backwards-compat handling for checkpoints that used to store pydantic as "dict" - if kind in {"dict", "pydantic", "json"} and isinstance(value, dict): - from_dict = getattr(target_cls, "from_dict", None) - if callable(from_dict): - with contextlib.suppress(Exception): - return cast(RequestInfoMessage, from_dict(value)) - - if kind == "json" and isinstance(value, str): - from_json = getattr(target_cls, "from_json", None) - if callable(from_json): - with contextlib.suppress(Exception): - return cast(RequestInfoMessage, from_json(value)) - with 
contextlib.suppress(Exception): - parsed = json.loads(value) - if isinstance(parsed, dict): - return self._decode_request_data({"kind": "dict", "type": type_name, "value": parsed}) - - if isinstance(value, dict): - with contextlib.suppress(TypeError): - return target_cls(**value) # type: ignore[arg-type] - instance = object.__new__(target_cls) - instance.__dict__.update(value) # type: ignore[arg-type] - return instance - - with contextlib.suppress(Exception): - return target_cls() - return RequestInfoMessage() - - def _convert_request_to_json_safe_dict(self, request: RequestInfoMessage) -> dict[str, Any]: - try: - data = _make_json_safe(asdict(request)) - if isinstance(data, dict): - return cast(dict[str, Any], data) - raise ValueError(f"Failed to convert {type(request).__name__} to dict") - except Exception as exc: - logger.error(f"RequestInfoExecutor {self.id} failed to serialize request: {exc}") - raise RuntimeError( - f"Failed to serialize request `{type(request).__name__}`: {exc}\n" - "Make sure request is a dataclass and derive from `RequestInfoMessage`." 
- ) from exc - - async def _rehydrate_request_event(self, request_id: str, ctx: WorkflowContext) -> RequestInfoEvent | None: - pending_requests = await self._retrieve_existing_pending_requests(ctx) - if (snapshot := pending_requests.get(request_id)) is None: - return None - - request = self._construct_request_from_snapshot(snapshot) - if request is None: - return None - - event = RequestInfoEvent( - request_id=request_id, - source_executor_id=snapshot.source_executor_id, - request_type=type(request), - request_data=request, - ) - self._request_events[request_id] = event - return event - - def _construct_request_from_snapshot(self, snapshot: PendingRequestSnapshot) -> RequestInfoMessage | None: - json_safe_dict = snapshot.request_as_json_safe_dict - - request_cls: type[RequestInfoMessage] = RequestInfoMessage - request_type_str = snapshot.request_type - if isinstance(request_type_str, str) and ":" in request_type_str: - module_name, class_name = request_type_str.split(":", 1) - try: - module = importlib.import_module(module_name) - candidate = getattr(module, class_name) - if isinstance(candidate, type) and issubclass(candidate, RequestInfoMessage): - request_cls = candidate - except Exception as exc: - logger.warning(f"RequestInfoExecutor {self.id} could not import {module_name}.{class_name}: {exc}") - request_cls = RequestInfoMessage - - request: RequestInfoMessage | None = self._instantiate_request(request_cls, json_safe_dict) - - if request is None and request_cls is not RequestInfoMessage: - request = self._instantiate_request(RequestInfoMessage, json_safe_dict) - - if request is None: - logger.warning( - f"RequestInfoExecutor {self.id} could not reconstruct request " - f"{request_type_str or RequestInfoMessage.__name__} from snapshot keys {sorted(json_safe_dict.keys())}" - ) - return None - - for key, value in json_safe_dict.items(): - if key == "request_id": - continue - try: - setattr(request, key, value) - except Exception as exc: - logger.debug( - 
f"RequestInfoExecutor {self.id} could not set attribute {key} on {type(request).__name__}: {exc}" - ) - continue - - snapshot_request_id = snapshot.request_id - if isinstance(snapshot_request_id, str) and snapshot_request_id: - try: - request.request_id = snapshot_request_id - except Exception as exc: - logger.debug( - f"RequestInfoExecutor {self.id} could not apply snapshot " - f"request_id to {type(request).__name__}: {exc}" - ) - - return request - - def _instantiate_request( - self, - request_cls: type[RequestInfoMessage], - details: dict[str, Any], - ) -> RequestInfoMessage | None: - try: - from_dict = getattr(request_cls, "from_dict", None) - if callable(from_dict): - return cast(RequestInfoMessage, from_dict(details)) - except (TypeError, ValueError) as exc: - logger.debug(f"RequestInfoExecutor {self.id} failed to hydrate {request_cls.__name__} via from_dict: {exc}") - except Exception as exc: - logger.warning( - f"RequestInfoExecutor {self.id} encountered unexpected error during " - f"{request_cls.__name__}.from_dict: {exc}" - ) - - if is_dataclass(request_cls): - try: - field_names = {f.name for f in fields(request_cls)} - ctor_kwargs = {name: details[name] for name in field_names if name in details} - return request_cls(**ctor_kwargs) - except (TypeError, ValueError) as exc: - logger.debug( - f"RequestInfoExecutor {self.id} could not instantiate dataclass " - f"{request_cls.__name__} with snapshot data: {exc}" - ) - except Exception as exc: - logger.warning( - f"RequestInfoExecutor {self.id} encountered unexpected error " - f"constructing dataclass {request_cls.__name__}: {exc}" - ) - - try: - instance = request_cls() - except Exception as exc: - logger.warning( - f"RequestInfoExecutor {self.id} could not instantiate {request_cls.__name__} without arguments: {exc}" - ) - return None - - for key, value in details.items(): - if key == "request_id": - continue - try: - setattr(instance, key, value) - except Exception as exc: - logger.debug( - 
f"RequestInfoExecutor {self.id} could not set attribute {key} on " - f"{request_cls.__name__} during instantiation: {exc}" - ) - continue - - return instance - - # endregion: Internal Methods - - -# region: Utility Functions - - -def _make_json_safe(value: Any) -> Any: - """Recursively convert a value to a JSON-safe representation.""" - if value is None or isinstance(value, (str, int, float, bool)): - return value - if isinstance(value, Mapping): - safe_dict: dict[str, Any] = {} - for key, val in value.items(): # type: ignore[attr-defined] - safe_dict[str(key)] = _make_json_safe(val) # type: ignore[arg-type] - return safe_dict - if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): - return [_make_json_safe(item) for item in value] # type: ignore[misc] - return repr(value) - - -def _import_qualname(qualname: str) -> type[Any]: - """Import a type given its qualified name in the format 'module:TypeName'.""" - module_name, _, type_name = qualname.partition(":") - if not module_name or not type_name: - raise ValueError(f"Invalid qualified name: {qualname}") - module = importlib.import_module(module_name) - attr: Any = module - for part in type_name.split("."): - attr = getattr(attr, part) - if not isinstance(attr, type): - raise TypeError(f"Resolved object is not a type: {qualname}") - return attr - - -# endregion: Utility Functions diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index f90a088633..83c5229e1f 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -20,7 +20,7 @@ from ._shared_state import SharedState if TYPE_CHECKING: - from ._request_info_executor import RequestInfoExecutor + pass logger = logging.getLogger(__name__) @@ -387,39 +387,6 @@ def _parse_edge_runners(self, edge_runners: list[EdgeRunner]) -> dict[str, list[ return parsed - def 
_find_request_info_executor(self) -> "RequestInfoExecutor | None": - """Find the RequestInfoExecutor instance in this workflow. - - Returns: - The RequestInfoExecutor instance if found, None otherwise. - """ - from ._request_info_executor import RequestInfoExecutor - - for executor in self._executors.values(): - if isinstance(executor, RequestInfoExecutor): - return executor - return None - - def _is_message_to_request_info_executor(self, msg: "Message") -> bool: - """Check if message targets any RequestInfoExecutor in this workflow. - - Args: - msg: The message to check. - - Returns: - True if the message targets a RequestInfoExecutor, False otherwise. - """ - from ._request_info_executor import RequestInfoExecutor - - if not msg.target_id: - return False - - # Check all executors to see if target_id matches a RequestInfoExecutor - for executor in self._executors.values(): - if executor.id == msg.target_id and isinstance(executor, RequestInfoExecutor): - return True - return False - def _convert_checkpoint_to_workflow_state(checkpoint: WorkflowCheckpoint) -> WorkflowState: """Helper function to convert a WorkflowCheckpoint to a WorkflowState.""" diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index fdf4ae3426..a1c82cfd30 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -498,9 +498,6 @@ async def send_request_info_response(self, request_id: str, response: Any) -> No await self.send_message(response_msg) - # Clear the event from pending requests - self._pending_request_info_events.pop(request_id, None) - async def get_pending_request_info_events(self) -> dict[str, RequestInfoEvent]: """Get the mapping of request IDs to their corresponding RequestInfoEvent. 
diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 8be339ab94..4212a7f5d3 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -1,7 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. import logging -from collections.abc import Mapping from dataclasses import fields, is_dataclass from types import UnionType from typing import Any, Union, get_args, get_origin @@ -117,35 +116,7 @@ def is_instance_of(data: Any, target_type: type | UnionType | Any) -> bool: ) ) - # Case 6: target_type is RequestResponse[T, U] - validate generic parameters - if origin and hasattr(origin, "__name__") and origin.__name__ == "RequestResponse": - if not isinstance(data, origin): - return False - # Validate generic parameters for RequestResponse[TRequest, TResponse] - if len(args) >= 2: - request_type, response_type = args[0], args[1] - # Check if the original_request matches TRequest and data matches TResponse - if ( - hasattr(data, "original_request") - and data.original_request is not None - and not is_instance_of(data.original_request, request_type) - ): - # Checkpoint decoding can leave original_request as a plain mapping. In that - # case we coerce it back into the expected request type so downstream handlers - # and validators still receive a fully typed RequestResponse instance. 
- original_request = data.original_request - if isinstance(original_request, Mapping): - coerced = _coerce_to_type(dict(original_request), request_type) - if coerced is None or not isinstance(coerced, request_type): - return False - data.original_request = coerced - else: - return False - if hasattr(data, "data") and data.data is not None and not is_instance_of(data.data, response_type): - return False - return True - - # Case 7: Other custom generic classes - check origin type only + # Case 6: Other custom generic classes - check origin type only # For generic classes, we check if data is an instance of the origin type # We don't validate the generic parameters at runtime since that's handled by type system if origin and hasattr(origin, "__name__"): diff --git a/python/packages/core/agent_framework/_workflows/_validation.py b/python/packages/core/agent_framework/_workflows/_validation.py index 6c2ef35681..da7f123908 100644 --- a/python/packages/core/agent_framework/_workflows/_validation.py +++ b/python/packages/core/agent_framework/_workflows/_validation.py @@ -10,7 +10,6 @@ from ._edge import Edge, EdgeGroup, FanInEdgeGroup, InternalEdgeGroup from ._executor import Executor -from ._request_info_executor import RequestInfoExecutor logger = logging.getLogger(__name__) @@ -247,14 +246,13 @@ def _validate_edge_type_compatibility(self, edge: Edge, edge_group: EdgeGroup) - # If either executor has no type information, log warning and skip validation # This allows for dynamic typing scenarios but warns about reduced validation coverage if not source_output_types or not target_input_types: - # Suppress warnings for RequestInfoExecutor where dynamic typing is expected - if not source_output_types and not isinstance(source_executor, RequestInfoExecutor): + if not source_output_types: logger.warning( f"Executor '{source_executor.id}' has no output type annotations. " f"Type compatibility validation will be skipped for edges from this executor. 
" f"Consider adding WorkflowContext[T] generics in handlers for better validation." ) - if not target_input_types and not isinstance(target_executor, RequestInfoExecutor): + if not target_input_types: logger.warning( f"Executor '{target_executor.id}' has no input type annotations. " f"Type compatibility validation will be skipped for edges to this executor. " diff --git a/python/packages/core/agent_framework/_workflows/_workflow.py b/python/packages/core/agent_framework/_workflows/_workflow.py index 148b6e20d8..e3b01f6d7c 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow.py +++ b/python/packages/core/agent_framework/_workflows/_workflow.py @@ -31,7 +31,6 @@ ) from ._executor import Executor from ._model_utils import DictConvertible -from ._request_info_executor import RequestInfoExecutor from ._runner import Runner from ._runner_context import RunnerContext from ._shared_state import SharedState @@ -138,11 +137,13 @@ class Workflow(DictConvertible): - run_stream_from_checkpoint(): Resume from checkpoint with streaming ## External Input Requests - Workflows can request external input using a RequestInfoExecutor: - 1. Executor connects to RequestInfoExecutor via edge group and back to itself - 2. Executor sends RequestInfoMessage to RequestInfoExecutor - 3. RequestInfoExecutor emits RequestInfoEvent and workflow enters IDLE_WITH_PENDING_REQUESTS - 4. Caller handles requests and uses send_responses()/send_responses_streaming() to continue + Executors within a workflow can request external input using `ctx.request_info()`: + 1. Executor calls `ctx.request_info()` to request input + 2. Executor implements `response_handler()` to process the response + 3. Requests are emitted as RequestInfoEvent instances in the event stream + 4. Workflow enters IDLE_WITH_PENDING_REQUESTS state + 5. Caller handles requests and uses send_responses()/send_responses_streaming() to continue + 6. 
Responses are routed back to the requesting executors and response handlers are invoked ## Checkpointing When enabled, checkpoints are created at the end of each superstep, capturing: @@ -619,19 +620,6 @@ def _get_executor_by_id(self, executor_id: str) -> Executor: raise ValueError(f"Executor with ID {executor_id} not found.") return self.executors[executor_id] - def _find_request_info_executor(self) -> RequestInfoExecutor | None: - """Find the RequestInfoExecutor instance in this workflow. - - Returns: - The RequestInfoExecutor instance if found, None otherwise. - """ - from ._request_info_executor import RequestInfoExecutor - - for executor in self.executors.values(): - if isinstance(executor, RequestInfoExecutor): - return executor - return None - # Graph signature helpers def _compute_graph_signature(self) -> dict[str, Any]: diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 1c8721c6c0..4b12d819f2 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -53,7 +53,7 @@ class ExecutionContext: class SubWorkflowResponseMessage: """Message sent from a parent workflow to a sub-workflow via WorkflowExecutor to provide requested information. - This message wraps a RequestResponse emitted by the parent workflow. + This message wraps the response data along with the original RequestInfoEvent emitted by the sub-workflow executor. Attributes: data: The response data to the original request. @@ -119,28 +119,41 @@ class WorkflowExecutor(Executor): ### Output Forwarding All outputs from the sub-workflow are automatically forwarded to the parent: + #### When `allow_direct_output` is False (default): + .. 
code-block:: python - # Sub-workflow yields outputs + # An executor in the sub-workflow yields outputs await ctx.yield_output("sub-workflow result") # WorkflowExecutor forwards to parent via ctx.send_message() # Parent receives the output as a regular message + #### When `allow_direct_output` is True: + + .. code-block:: python + # An executor in the sub-workflow yields outputs + await ctx.yield_output("sub-workflow result") + + # WorkflowExecutor yields output directly to parent workflow's event stream + # The output of the sub-workflow is considered the output of the parent workflow + # Caller of the parent workflow receives the output directly + ### Request/Response Coordination When sub-workflows need external information: .. code-block:: python - # Sub-workflow makes request + # An executor in the sub-workflow makes request request = MyDataRequest(query="user info") - # RequestInfoExecutor emits RequestInfoEvent - # WorkflowExecutor sets source_executor_id and forwards to parent - request.source_executor_id = "child_workflow_executor_id" - # Parent workflow can handle via @handler for RequestInfoMessage subclasses, - # or directly forward to external source via a RequestInfoExecutor in the parent - # workflow. + # WorkflowExecutor captures RequestInfoEvent and wraps it in a SubWorkflowRequestMessage + # then send it to the receiving executor in parent workflow. The executor in parent workflow + # can handle the request locally or forward it to an external source. + # The WorkflowExecutor tracks the pending request, and implements a response handler. + # When the response is received, it executes the response handler to accumulate responses + # and resume the sub-workflow when all expected responses are received. + # The response handler expects a SubWorkflowResponseMessage wrapping the response data. ### State Management WorkflowExecutor maintains execution state across request/response cycles: @@ -167,8 +180,8 @@ class WorkflowExecutor(Executor): .. 
code-block:: python # Includes all sub-workflow output types - # Plus RequestInfoMessage if sub-workflow can make requests - output_types = workflow.output_types + [RequestInfoMessage] # if applicable + # Plus SubWorkflowRequestMessage if sub-workflow can make requests + output_types = workflow.output_types + [SubWorkflowRequestMessage] # if applicable ``` ## Error Handling @@ -236,19 +249,19 @@ def __init__(self): ```python class ParentExecutor(Executor): @handler - async def handle_request( + async def handle_subworkflow_request( self, - request: MyRequestType, # Subclass of RequestInfoMessage - ctx: WorkflowContext[RequestResponse[RequestInfoMessage, Any] | RequestInfoMessage], + request: SubWorkflowRequestMessage, + ctx: WorkflowContext[SubWorkflowResponseMessage], ) -> None: # Handle request locally or forward to external source if self.can_handle_locally(request): # Send response back to sub-workflow - response = RequestResponse(data="local result", original_request=request, request_id=request.request_id) + response = request.create_response(data="local response data") await ctx.send_message(response, target_id=request.source_executor_id) else: # Forward to external handler - await ctx.send_message(request) + await ctx.request_info(request.source_event) ``` ## Implementation Notes @@ -306,7 +319,8 @@ def output_types(self) -> list[type[Any]]: Returns: A list of output types that the underlying workflow can produce. - Includes specific RequestInfoMessage subtypes if the sub-workflow contains RequestInfoExecutor. + Includes the SubWorkflowRequestMessage type if any executor in the + sub-workflow is request-response capable. 
""" output_types = list(self.workflow.output_types) diff --git a/python/samples/getting_started/workflows/README.md b/python/samples/getting_started/workflows/README.md index 17780a7aac..f20f3c740b 100644 --- a/python/samples/getting_started/workflows/README.md +++ b/python/samples/getting_started/workflows/README.md @@ -37,7 +37,7 @@ Once comfortable with these, explore the rest of the samples below. | Azure Chat Agents (Streaming) | [agents/azure_chat_agents_streaming.py](./agents/azure_chat_agents_streaming.py) | Add Azure Chat agents as edges and handle streaming events | | Azure AI Chat Agents (Streaming) | [agents/azure_ai_agents_streaming.py](./agents/azure_ai_agents_streaming.py) | Add Azure AI agents as edges and handle streaming events | | Azure Chat Agents (Function Bridge) | [agents/azure_chat_agents_function_bridge.py](./agents/azure_chat_agents_function_bridge.py) | Chain two agents with a function executor that injects external context | -| Azure Chat Agents (Tools + HITL) | [agents/azure_chat_agents_tool_calls_with_feedback.py](./agents/azure_chat_agents_tool_calls_with_feedback.py) | Tool-enabled writer/editor pipeline with human feedback gating via RequestInfoExecutor | +| Azure Chat Agents (Tools + HITL) | [agents/azure_chat_agents_tool_calls_with_feedback.py](./agents/azure_chat_agents_tool_calls_with_feedback.py) | Tool-enabled writer/editor pipeline with human feedback gating | | Custom Agent Executors | [agents/custom_agent_executors.py](./agents/custom_agent_executors.py) | Create executors to handle agent run methods | | Workflow as Agent (Reflection Pattern) | [agents/workflow_as_agent_reflection_pattern.py](./agents/workflow_as_agent_reflection_pattern.py) | Wrap a workflow so it can behave like an agent (reflection pattern) | | Workflow as Agent + HITL | [agents/workflow_as_agent_human_in_the_loop.py](./agents/workflow_as_agent_human_in_the_loop.py) | Extend workflow-as-agent with human-in-the-loop capability | @@ -55,7 +55,7 @@ Once 
comfortable with these, explore the rest of the samples below. | Sample | File | Concepts | |---|---|---| | Sub-Workflow (Basics) | [composition/sub_workflow_basics.py](./composition/sub_workflow_basics.py) | Wrap a workflow as an executor and orchestrate sub-workflows | -| Sub-Workflow: Request Interception | [composition/sub_workflow_request_interception.py](./composition/sub_workflow_request_interception.py) | Intercept and forward sub-workflow requests using @handler for RequestInfoMessage subclasses | +| Sub-Workflow: Request Interception | [composition/sub_workflow_request_interception.py](./composition/sub_workflow_request_interception.py) | Intercept and forward sub-workflow requests using @handler for SubWorkflowRequestMessage | | Sub-Workflow: Parallel Requests | [composition/sub_workflow_parallel_requests.py](./composition/sub_workflow_parallel_requests.py) | Multiple specialized interceptors handling different request types from same sub-workflow | ### control-flow diff --git a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py index aa8efa9cc1..a6c87a5899 100644 --- a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py +++ b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py @@ -3,26 +3,25 @@ import asyncio import json from dataclasses import dataclass, field -from typing import Annotated +from typing import Annotated, Never from agent_framework import ( AgentExecutorRequest, AgentExecutorResponse, + AgentRunResponse, AgentRunUpdateEvent, ChatMessage, Executor, FunctionCallContent, FunctionResultContent, RequestInfoEvent, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, Role, ToolMode, WorkflowBuilder, WorkflowContext, WorkflowOutputEvent, handler, + response_handler, ) from agent_framework.azure import 
AzureOpenAIChatClient from azure.identity import AzureCliCredential @@ -32,8 +31,8 @@ Sample: Tool-enabled agents with human feedback Pipeline layout: -writer_agent (uses Azure OpenAI tools) -> DraftFeedbackCoordinator -> RequestInfoExecutor --> DraftFeedbackCoordinator -> final_editor_agent +writer_agent (uses Azure OpenAI tools) -> Coordinator -> writer_agent +-> Coordinator -> final_editor_agent -> Coordinator -> output The writer agent calls tools to gather product facts before drafting copy. A custom executor packages the draft and emits a RequestInfoEvent so a human can comment, then replays the human @@ -41,7 +40,7 @@ Demonstrates: - Attaching Python function tools to an agent inside a workflow. -- Capturing the writer's output and routing it through RequestInfoExecutor for human review. +- Capturing the writer's output for human review. - Streaming AgentRunUpdateEvent updates alongside human-in-the-loop pauses. Prerequisites: @@ -82,27 +81,37 @@ def get_brand_voice_profile( @dataclass -class DraftFeedbackRequest(RequestInfoMessage): - """Payload sent to RequestInfoExecutor for human review.""" +class DraftFeedbackRequest: + """Payload sent for human review.""" prompt: str = "" draft_text: str = "" conversation: list[ChatMessage] = field(default_factory=list) # type: ignore[reportUnknownVariableType] -class DraftFeedbackCoordinator(Executor): +class Coordinator(Executor): """Bridge between the writer agent, human feedback, and final editor.""" - def __init__(self, *, id: str = "draft_feedback_coordinator") -> None: + def __init__(self, id: str, writer_id: str, final_editor_id: str) -> None: super().__init__(id) + self.writer_id = writer_id + self.final_editor_id = final_editor_id @handler async def on_writer_response( self, draft: AgentExecutorResponse, - ctx: WorkflowContext[DraftFeedbackRequest], + ctx: WorkflowContext[Never, AgentRunResponse], ) -> None: - # Preserve the full conversation so the final editor can see tool traces and the initial prompt. 
+ """Handle responses from the other two agents in the workflow.""" + if draft.executor_id == self.final_editor_id: + # Final editor response; yield output directly. + await ctx.yield_output(draft.agent_run_response) + return + + # Writer agent response; request human feedback. + # Preserve the full conversation so the final editor + # can see tool traces and the initial prompt. conversation: list[ChatMessage] if draft.full_conversation is not None: conversation = list(draft.full_conversation) @@ -117,18 +126,34 @@ async def on_writer_response( "(tone tweaks, must-have detail, target audience, etc.). " "Keep it under 30 words." ) - await ctx.send_message(DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation)) + await ctx.request_info( + DraftFeedbackRequest(prompt=prompt, draft_text=draft_text, conversation=conversation), + DraftFeedbackRequest, + str, + ) - @handler + @response_handler async def on_human_feedback( self, - feedback: RequestResponse[DraftFeedbackRequest, str], + original_request: DraftFeedbackRequest, + feedback: str, ctx: WorkflowContext[AgentExecutorRequest], ) -> None: - note = (feedback.data or "").strip() - request = feedback.original_request + note = feedback.strip() + if note.lower() == "approve": + # Human approved the draft as-is; forward it unchanged. + await ctx.send_message( + AgentExecutorRequest( + messages=original_request.conversation + + [ChatMessage(Role.USER, text="The draft is approved as-is.")], + should_respond=True, + ), + target_id=self.final_editor_id, + ) + return - conversation: list[ChatMessage] = list(request.conversation) + # Human provided feedback; prompt the writer to revise. 
+ conversation: list[ChatMessage] = list(original_request.conversation) instruction = ( "A human reviewer shared the following guidance:\n" f"{note or 'No specific guidance provided.'}\n\n" @@ -136,11 +161,57 @@ async def on_human_feedback( "Keep the response under 120 words and reflect any requested tone adjustments." ) conversation.append(ChatMessage(Role.USER, text=instruction)) - await ctx.send_message(AgentExecutorRequest(messages=conversation, should_respond=True)) + await ctx.send_message( + AgentExecutorRequest(messages=conversation, should_respond=True), target_id=self.writer_id + ) + + +def display_agent_run_update(event: AgentRunUpdateEvent, last_executor: str | None) -> None: + """Display an AgentRunUpdateEvent in a readable format.""" + printed_tool_calls: set[str] = set() + printed_tool_results: set[str] = set() + executor_id = event.executor_id + update = event.data + # Extract and print any new tool calls or results from the update. + function_calls = [c for c in update.contents if isinstance(c, FunctionCallContent)] # type: ignore[union-attr] + function_results = [c for c in update.contents if isinstance(c, FunctionResultContent)] # type: ignore[union-attr] + if executor_id != last_executor: + if last_executor is not None: + print() + print(f"{executor_id}:", end=" ", flush=True) + last_executor = executor_id + # Print any new tool calls before the text update. + for call in function_calls: + if call.call_id in printed_tool_calls: + continue + printed_tool_calls.add(call.call_id) + args = call.arguments + args_preview = json.dumps(args, ensure_ascii=False) if isinstance(args, dict) else (args or "").strip() + print( + f"\n{executor_id} [tool-call] {call.name}({args_preview})", + flush=True, + ) + print(f"{executor_id}:", end=" ", flush=True) + # Print any new tool results before the text update. 
+ for result in function_results: + if result.call_id in printed_tool_results: + continue + printed_tool_results.add(result.call_id) + result_text = result.result + if not isinstance(result_text, str): + result_text = json.dumps(result_text, ensure_ascii=False) + print( + f"\n{executor_id} [tool-result] {result.call_id}: {result_text}", + flush=True, + ) + print(f"{executor_id}:", end=" ", flush=True) + # Finally, print the text update. + print(update, end="", flush=True) async def main() -> None: """Run the workflow and bridge human feedback between two agents.""" + # Create agents with tools and instructions. chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) writer_agent = chat_client.create_agent( @@ -157,33 +228,39 @@ async def main() -> None: final_editor_agent = chat_client.create_agent( name="final_editor_agent", instructions=( - "You are an editor who polishes marketing copy using human guidance. " - "Respect factual details from the prior messages while applying the feedback." + "You are an editor who polishes marketing copy after human approval. " + "Correct any legal or factual issues. Return the final version even if no changes are made. " ), ) - feedback_coordinator = DraftFeedbackCoordinator() - request_info_executor = RequestInfoExecutor(id="human_feedback") + coordinator = Coordinator( + id="coordinator", + writer_id="writer_agent", + final_editor_id="final_editor_agent", + ) + # Build the workflow. workflow = ( WorkflowBuilder() .set_start_executor(writer_agent) - .add_edge(writer_agent, feedback_coordinator) - .add_edge(feedback_coordinator, request_info_executor) - .add_edge(request_info_executor, feedback_coordinator) - .add_edge(feedback_coordinator, final_editor_agent) + .add_edge(writer_agent, coordinator) + .add_edge(coordinator, writer_agent) + .add_edge(final_editor_agent, coordinator) + .add_edge(coordinator, final_editor_agent) .build() ) + # Switch to turn on agent run update display. 
+ # By default this is off to reduce clutter during human input. + display_agent_run_update_switch = False + print( - "Interactive mode. When prompted, provide a short feedback note for the editor (type 'exit' to quit).", + "Interactive mode. When prompted, provide a short feedback note for the editor.", flush=True, ) pending_responses: dict[str, str] | None = None completed = False - printed_tool_calls: set[str] = set() - printed_tool_results: set[str] = set() while not completed: last_executor: str | None = None @@ -198,48 +275,9 @@ async def main() -> None: requests: list[tuple[str, DraftFeedbackRequest]] = [] async for event in stream: - if isinstance(event, AgentRunUpdateEvent): - executor_id = event.executor_id - update = event.data - # Extract and print any new tool calls or results from the update. - function_calls = [c for c in update.contents if isinstance(c, FunctionCallContent)] # type: ignore[union-attr] - function_results = [c for c in update.contents if isinstance(c, FunctionResultContent)] # type: ignore[union-attr] - if executor_id != last_executor: - if last_executor is not None: - print() - print(f"{executor_id}:", end=" ", flush=True) - last_executor = executor_id - # Print any new tool calls before the text update. - for call in function_calls: - if call.call_id in printed_tool_calls: - continue - printed_tool_calls.add(call.call_id) - args = call.arguments - if isinstance(args, dict): - args_preview = json.dumps(args, ensure_ascii=False) - else: - args_preview = (args or "").strip() - print( - f"\n{executor_id} [tool-call] {call.name}({args_preview})", - flush=True, - ) - print(f"{executor_id}:", end=" ", flush=True) - # Print any new tool results before the text update. 
- for result in function_results: - if result.call_id in printed_tool_results: - continue - printed_tool_results.add(result.call_id) - result_text = result.result - if not isinstance(result_text, str): - result_text = json.dumps(result_text, ensure_ascii=False) - print( - f"\n{executor_id} [tool-result] {result.call_id}: {result_text}", - flush=True, - ) - print(f"{executor_id}:", end=" ", flush=True) - # Finally, print the text update. - print(update, end="", flush=True) - elif isinstance(event, RequestInfoEvent) and isinstance(event.data, DraftFeedbackRequest): + if isinstance(event, AgentRunUpdateEvent) and display_agent_run_update_switch: + display_agent_run_update(event, last_executor) + if isinstance(event, RequestInfoEvent) and isinstance(event.data, DraftFeedbackRequest): # Stash the request so we can prompt the human after the stream completes. requests.append((event.request_id, event.data)) last_executor = None @@ -256,7 +294,7 @@ async def main() -> None: for request_id, request in requests: print("\n----- Writer draft -----") print(request.draft_text.strip()) - print("\nProvide guidance for the editor (or press Enter to accept the draft).") + print("\nProvide guidance for the editor (or 'approve' to accept the draft).") answer = input("Human feedback: ").strip() # noqa: ASYNC250 if answer.lower() == "exit": print("Exiting...") diff --git a/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py b/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py index e4fba82495..7bdf3e922a 100644 --- a/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py @@ -7,6 +7,9 @@ from pathlib import Path from typing import Any +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential + # Ensure local getting_started package can be imported 
when running as a script. _SAMPLES_ROOT = Path(__file__).resolve().parents[3] if str(_SAMPLES_ROOT) not in sys.path: @@ -17,16 +20,13 @@ Executor, FunctionCallContent, FunctionResultContent, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, Role, WorkflowAgent, WorkflowBuilder, WorkflowContext, handler, + response_handler, ) -from agent_framework.openai import OpenAIChatClient # noqa: E402 from getting_started.workflows.agents.workflow_as_agent_reflection_pattern import ( # noqa: E402 ReviewRequest, ReviewResponse, @@ -40,20 +40,20 @@ This sample demonstrates how to build a workflow agent that escalates uncertain decisions to a human manager. A Worker generates results, while a Reviewer evaluates them. When the Reviewer is not confident, it escalates the decision -to a human via RequestInfoExecutor, receives the human response, and then -forwards that response back to the Worker. The workflow completes when idle. +to a human, receives the human response, and then forwards that response back +to the Worker. The workflow completes when idle. Prerequisites: - OpenAI account configured and accessible for OpenAIChatClient. - Familiarity with WorkflowBuilder, Executor, and WorkflowContext from agent_framework. -- Understanding of request-response message handling (RequestInfoMessage, RequestResponse). +- Understanding of request-response message handling in executors. - (Optional) Review of reflection and escalation patterns, such as those in workflow_as_agent_reflection.py. 
""" @dataclass -class HumanReviewRequest(RequestInfoMessage): +class HumanReviewRequest: """A request message type for escalation to a human reviewer.""" agent_request: ReviewRequest | None = None @@ -62,14 +62,13 @@ class HumanReviewRequest(RequestInfoMessage): class ReviewerWithHumanInTheLoop(Executor): """Executor that always escalates reviews to a human manager.""" - def __init__(self, worker_id: str, request_info_id: str, reviewer_id: str | None = None) -> None: + def __init__(self, worker_id: str, reviewer_id: str | None = None) -> None: unique_id = reviewer_id or f"{worker_id}-reviewer" super().__init__(id=unique_id) self._worker_id = worker_id - self._request_info_id = request_info_id @handler - async def review(self, request: ReviewRequest, ctx: WorkflowContext[ReviewResponse | HumanReviewRequest]) -> None: + async def review(self, request: ReviewRequest, ctx: WorkflowContext) -> None: # In this simplified example, we always escalate to a human manager. # See workflow_as_agent_reflection.py for an implementation # using an automated agent to make the review decision. @@ -77,23 +76,21 @@ async def review(self, request: ReviewRequest, ctx: WorkflowContext[ReviewRespon print("Reviewer: Escalating to human manager...") # Forward the request to a human manager by sending a HumanReviewRequest. - await ctx.send_message( - HumanReviewRequest(agent_request=request), - target_id=self._request_info_id, - ) + await ctx.request_info(HumanReviewRequest(agent_request=request), HumanReviewRequest, ReviewResponse) - @handler + @response_handler async def accept_human_review( - self, response: RequestResponse[HumanReviewRequest, ReviewResponse], ctx: WorkflowContext[ReviewResponse] + self, + original_request: ReviewRequest, + response: ReviewResponse, + ctx: WorkflowContext[ReviewResponse], ) -> None: # Accept the human review response and forward it back to the Worker. 
- human_response = response.data - assert isinstance(human_response, ReviewResponse) - print(f"Reviewer: Accepting human review for request {human_response.request_id[:8]}...") - print(f"Reviewer: Human feedback: {human_response.feedback}") - print(f"Reviewer: Human approved: {human_response.approved}") + print(f"Reviewer: Accepting human review for request {response.request_id[:8]}...") + print(f"Reviewer: Human feedback: {response.feedback}") + print(f"Reviewer: Human approved: {response.approved}") print("Reviewer: Forwarding human review back to worker...") - await ctx.send_message(human_response, target_id=self._worker_id) + await ctx.send_message(response, target_id=self._worker_id) async def main() -> None: @@ -102,10 +99,9 @@ async def main() -> None: # Create executors for the workflow. print("Creating chat client and executors...") - mini_chat_client = OpenAIChatClient(model_id="gpt-4.1-nano") + mini_chat_client = AzureOpenAIChatClient(credential=AzureCliCredential()) worker = Worker(id="sub-worker", chat_client=mini_chat_client) - request_info_executor = RequestInfoExecutor(id="request_info") - reviewer = ReviewerWithHumanInTheLoop(worker_id=worker.id, request_info_id=request_info_executor.id) + reviewer = ReviewerWithHumanInTheLoop(worker_id=worker.id) print("Building workflow with Worker ↔ Reviewer cycle...") # Build a workflow with bidirectional communication between Worker and Reviewer, @@ -114,8 +110,6 @@ async def main() -> None: WorkflowBuilder() .add_edge(worker, reviewer) # Worker sends requests to Reviewer .add_edge(reviewer, worker) # Reviewer sends feedback to Worker - .add_edge(reviewer, request_info_executor) # Reviewer requests human input - .add_edge(request_info_executor, reviewer) # Human input forwarded back to Reviewer .set_start_executor(worker) .build() .as_agent() # Convert workflow into an agent interface diff --git a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py 
b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py index 737512b42d..a70bebfba2 100644 --- a/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/checkpoint/checkpoint_with_human_in_the_loop.py @@ -37,17 +37,14 @@ 1. A brief is turned into a consistent prompt for an AI copywriter. 2. The copywriter (an `AgentExecutor`) drafts release notes. -3. A reviewer gateway routes every draft through `RequestInfoExecutor` so a human - can approve or request tweaks. +3. A reviewer gateway sends a request for approval for every draft. 4. The workflow records checkpoints between each superstep so you can stop the program, restart later, and optionally pre-supply human answers on resume. Key concepts demonstrated ------------------------- - Minimal executor pipeline with checkpoint persistence. -- Human-in-the-loop pause/resume by pairing `RequestInfoExecutor` with - checkpoint restoration. -- Supplying responses at restore time (`run_stream_from_checkpoint(..., responses=...)`). +- Human-in-the-loop pause/resume with checkpoint restoration. Typical pause/resume flow ------------------------- @@ -103,7 +100,7 @@ async def prepare(self, brief: str, ctx: WorkflowContext[AgentExecutorRequest, s @dataclass class HumanApprovalRequest: - """Message sent to the human reviewer via RequestInfoExecutor.""" + """Request sent to the human reviewer.""" # These fields are intentionally simple because they are serialised into # checkpoints. Keeping them primitive types guarantees the new @@ -122,15 +119,11 @@ def __init__(self, id: str, writer_id: str) -> None: @handler async def on_agent_response(self, response: AgentExecutorResponse, ctx: WorkflowContext) -> None: - # Capture the agent output so we can surface it to the reviewer and - # persist iterations. The `RequestInfoExecutor` relies on this state to - # rehydrate when checkpoints are restored. 
+ # Capture the agent output so we can surface it to the reviewer and persist iterations. draft = response.agent_run_response.text or "" iteration = int((await ctx.get_executor_state() or {}).get("iteration", 0)) + 1 await ctx.set_executor_state({"iteration": iteration, "last_draft": draft}) - # Emit a human approval request. Because this flows through - # RequestInfoExecutor it will pause the workflow until an answer is - # supplied either interactively or via pre-supplied responses. + # Emit a human approval request. await ctx.request_info( HumanApprovalRequest( prompt="Review the draft. Reply 'approve' or provide edit instructions.", @@ -148,8 +141,7 @@ async def on_human_feedback( feedback: str, ctx: WorkflowContext[AgentExecutorRequest | str, str], ) -> None: - # The RequestResponse wrapper gives us both the human data and the - # original request message, even when resuming from checkpoints. + # The `original_request` is the request we sent earlier that is now being answered. reply = feedback.strip() state = await ctx.get_executor_state() or {} draft = state.get("last_draft") or (original_request.draft or "") @@ -214,10 +206,8 @@ def render_checkpoint_summary(checkpoints: list["WorkflowCheckpoint"]) -> None: ) if summary.status: line += f" | status={summary.status}" - if summary.draft_preview: - line += f" | draft_preview={summary.draft_preview}" - if summary.pending_requests: - line += f" | pending_request_id={summary.pending_requests[0].request_id}" + if summary.pending_request_info_events: + line += f" | pending_request_id={summary.pending_request_info_events[0].request_id}" print(line) @@ -254,6 +244,7 @@ async def run_interactive_session( if responses: event_stream = workflow.send_responses_streaming(responses) requests.clear() + responses = None else: if initial_message: print(f"\nStarting workflow with brief: {initial_message}\n") diff --git a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py 
b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py index 21e385e3c6..005158bb3a 100644 --- a/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py +++ b/python/samples/getting_started/workflows/checkpoint/sub_workflow_checkpoint.py @@ -31,7 +31,7 @@ Sample: Checkpointing for workflows that embed sub-workflows. This sample shows how a parent workflow that wraps a sub-workflow can: -- run until the sub-workflow emits a human approval request via RequestInfoExecutor +- run until the sub-workflow emits a human approval request - persist a checkpoint that captures the pending request (including complex payloads) - resume later, supplying the human decision directly at restore time @@ -80,7 +80,7 @@ class FinalDraft: @dataclass class ReviewRequest: - """Human approval request surfaced via RequestInfoExecutor.""" + """Human approval request surfaced via `request_info`.""" id: str = str(uuid.uuid4()) topic: str = "" diff --git a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py index 2bec4c0f7d..81667a9931 100644 --- a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py +++ b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py @@ -17,7 +17,8 @@ WorkflowRunState, WorkflowStatusEvent, ) -from agent_framework.openai import OpenAIChatClient +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity._credentials import AzureCliCredential """ Sample: Magentic Orchestration + Checkpointing @@ -29,8 +30,8 @@ Concepts highlighted here: 1. **Deterministic executor IDs** - the orchestrator and plan-review request executor must keep stable IDs so the checkpoint state aligns when we rebuild the graph. -2. 
**Executor snapshotting** - checkpoints capture the `RequestInfoExecutor` state, - specifically the pending plan-review request map, at superstep boundaries. +2. **Executor snapshotting** - checkpoints capture the pending plan-review request + map, at superstep boundaries. 3. **Resume with responses** - `Workflow.run_stream_from_checkpoint` accepts a `responses` mapping so we can inject the stored human reply during restoration. @@ -58,14 +59,14 @@ def build_workflow(checkpoint_storage: FileCheckpointStorage): name="ResearcherAgent", description="Collects background facts and references for the project.", instructions=("You are the research lead. Gather crisp bullet points the team should know."), - chat_client=OpenAIChatClient(), + chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()), ) writer = ChatAgent( name="WriterAgent", description="Synthesizes the final brief for stakeholders.", instructions=("You convert the research notes into a structured brief with milestones and risks."), - chat_client=OpenAIChatClient(), + chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()), ) # The builder wires in the Magentic orchestrator, sets the plan review path, and @@ -75,7 +76,7 @@ def build_workflow(checkpoint_storage: FileCheckpointStorage): .participants(researcher=researcher, writer=writer) .with_plan_review() .with_standard_manager( - chat_client=OpenAIChatClient(), + chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()), max_round_count=10, max_stall_count=3, ) @@ -135,16 +136,23 @@ async def main() -> None: print("\n=== Stage 2: resume from checkpoint and approve plan ===") resumed_workflow = build_workflow(checkpoint_storage) + # Construct an approval reply to supply when the plan review request is re-emitted. approval = MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE) - # Resume execution and supply the recorded approval in a single call. 
- # `run_stream_from_checkpoint` rebuilds executor state, applies the provided responses, - # and then continues the workflow. Because we only captured the initial plan review - # checkpoint, the resumed run should complete almost immediately. + + # Resume execution and capture the re-emitted plan review request. + request_info_event: RequestInfoEvent | None = None + async for event in resumed_workflow.workflow.run_stream_from_checkpoint(resume_checkpoint.checkpoint_id): + if isinstance(event, RequestInfoEvent) and isinstance(event, MagenticPlanReviewRequest): + request_info_event = event + + if request_info_event is None: + print("No plan review request re-emitted on resume; cannot approve.") + return + print(f"Resumed plan review request: {request_info_event.request_id}") + + # Supply the approval and continue to run to completion. final_event: WorkflowOutputEvent | None = None - async for event in resumed_workflow.workflow.run_stream_from_checkpoint( - resume_checkpoint.checkpoint_id, - responses={plan_review_request_id: approval}, - ): + async for event in resumed_workflow.workflow.send_responses_streaming({request_info_event.request_id: approval}): if isinstance(event, WorkflowOutputEvent): final_event = event @@ -204,10 +212,7 @@ def _pending_message_count(cp: WorkflowCheckpoint) -> int: final_event_post: WorkflowOutputEvent | None = None post_emitted_events = False post_plan_workflow = build_workflow(checkpoint_storage) - async for event in post_plan_workflow.workflow.run_stream_from_checkpoint( - post_plan_checkpoint.checkpoint_id, - responses={}, - ): + async for event in post_plan_workflow.workflow.run_stream_from_checkpoint(post_plan_checkpoint.checkpoint_id): post_emitted_events = True if isinstance(event, WorkflowOutputEvent): final_event_post = event From d3e83437faed272417cb5cc035e531cbbed2a644 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 24 Oct 2025 10:03:07 -0700 Subject: [PATCH 10/26] Fix Handoff and sample --- 
.../agent_framework/_workflows/_handoff.py | 40 +++++++------------ .../_workflows/_runner_context.py | 12 ++++++ 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index 725e50cb25..dbb7acc9ef 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -38,7 +38,7 @@ from ._checkpoint import CheckpointStorage from ._conversation_state import decode_chat_messages, encode_chat_messages from ._executor import Executor, handler -from ._request_info_executor import RequestInfoExecutor, RequestInfoMessage, RequestResponse +from ._request_info_mixin import response_handler from ._workflow import Workflow from ._workflow_builder import WorkflowBuilder from ._workflow_context import WorkflowContext @@ -112,12 +112,13 @@ def _clone_chat_agent(agent: ChatAgent) -> ChatAgent: @dataclass -class HandoffUserInputRequest(RequestInfoMessage): +class HandoffUserInputRequest: """Request message emitted when the workflow needs fresh user input.""" - conversation: list[ChatMessage] = field(default_factory=lambda: []) # type: ignore[misc] - awaiting_agent_id: str | None = None - prompt: str | None = None + conversation: list[ChatMessage] + awaiting_agent_id: str + prompt: str + source_executor_id: str @dataclass @@ -514,28 +515,22 @@ def _apply_response_metadata(self, conversation: list[ChatMessage], agent_respon class _UserInputGateway(Executor): - """Bridges conversation context with RequestInfoExecutor and re-enters the loop.""" + """Bridges conversation context with the request & response cycle and re-enters the loop.""" def __init__( self, *, - request_executor_id: str, starting_agent_id: str, prompt: str | None, id: str, ) -> None: """Initialise the gateway that requests user input and forwards responses.""" super().__init__(id) - self._request_executor_id = 
request_executor_id self._starting_agent_id = starting_agent_id self._prompt = prompt or "Provide your next input for the conversation." @handler - async def request_input( - self, - conversation: list[ChatMessage], - ctx: WorkflowContext[HandoffUserInputRequest], - ) -> None: + async def request_input(self, conversation: list[ChatMessage], ctx: WorkflowContext) -> None: """Emit a `HandoffUserInputRequest` capturing the conversation snapshot.""" if not conversation: raise ValueError("Handoff workflow requires non-empty conversation before requesting user input.") @@ -543,27 +538,26 @@ async def request_input( conversation=list(conversation), awaiting_agent_id=self._starting_agent_id, prompt=self._prompt, + source_executor_id=self.id, ) - request.source_executor_id = self.id - await ctx.send_message(request, target_id=self._request_executor_id) + await ctx.request_info(request, HandoffUserInputRequest, Any) - @handler + @response_handler async def resume_from_user( self, - response: RequestResponse[HandoffUserInputRequest, Any], + original_request: HandoffUserInputRequest, + response: Any, ctx: WorkflowContext[_ConversationWithUserInput], ) -> None: """Convert user input responses back into chat messages and resume the workflow.""" # Reconstruct full conversation with new user input - conversation = list(response.original_request.conversation) - user_messages = _as_user_messages(response.data) + conversation = list(original_request.conversation) + user_messages = _as_user_messages(response) conversation.extend(user_messages) # Send full conversation back to coordinator (not trimmed) # Coordinator will update its authoritative history and trim for agent message = _ConversationWithUserInput(full_conversation=conversation) - # CRITICAL: Must specify target to avoid broadcasting to all connected executors - # Gateway is connected to both request_info and coordinator, we want coordinator only await ctx.send_message(message, target_id="handoff-coordinator") @@ -1309,9 
+1303,7 @@ def build(self) -> Workflow: logger.warning("Handoff workflow has no specialist agents; the coordinator will loop with the user.") input_node = _InputToConversation(id="input-conversation") - request_info = RequestInfoExecutor(id=f"{starting_executor.id}_handoff_requests") user_gateway = _UserInputGateway( - request_executor_id=request_info.id, starting_agent_id=starting_executor.id, prompt=self._request_prompt, id="handoff-user-input", @@ -1335,8 +1327,6 @@ def build(self) -> Workflow: builder.add_edge(specialist, coordinator) builder.add_edge(coordinator, user_gateway) - builder.add_edge(user_gateway, request_info) - builder.add_edge(request_info, user_gateway) builder.add_edge(user_gateway, coordinator) # Route back to coordinator, not directly to agent builder.add_edge(coordinator, starting_executor) # Coordinator sends trimmed request to agent diff --git a/python/packages/core/agent_framework/_workflows/_runner_context.py b/python/packages/core/agent_framework/_workflows/_runner_context.py index d4d94847af..6decfc592b 100644 --- a/python/packages/core/agent_framework/_workflows/_runner_context.py +++ b/python/packages/core/agent_framework/_workflows/_runner_context.py @@ -348,6 +348,7 @@ async def create_checkpoint( workflow_id=self._workflow_id, messages=state["messages"], shared_state=state["shared_state"], + pending_request_info_events=state["pending_request_info_events"], iteration_count=state["iteration_count"], metadata=metadata or {}, ) @@ -371,6 +372,8 @@ def reset_for_new_run(self) -> None: self._streaming = False # Reset streaming flag async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: + """Apply a checkpoint to the current context, mutating its state.""" + # Restore messages self._messages.clear() messages_data = checkpoint.messages for source_id, message_list in messages_data.items(): @@ -385,6 +388,15 @@ async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: for msg in message_list ] + # Restore 
pending request info events + self._pending_request_info_events.clear() + pending_requests_data = checkpoint.pending_request_info_events + for request_id, request_data in pending_requests_data.items(): + request_info_event = RequestInfoEvent.from_dict(request_data) + self._pending_request_info_events[request_id] = request_info_event + await self.add_event(request_info_event) + + # Restore workflow ID self._workflow_id = checkpoint.workflow_id # endregion Checkpointing From dc24afa6cab3aad641461465f0fb28c79938b47b Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 24 Oct 2025 11:32:32 -0700 Subject: [PATCH 11/26] fix pending requests in checkpoint --- .../agent_framework/_workflows/_runner.py | 2 +- python/samples/_run_all_samples.py | 133 +++++++++++++----- .../workflow_as_agent_human_in_the_loop.py | 2 +- 3 files changed, 103 insertions(+), 34 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index b4a00437ce..3fe6f1ca7a 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -287,7 +287,7 @@ async def restore_from_checkpoint( self._workflow_id = checkpoint.workflow_id # Restore shared state - await self._shared_state.import_state(checkpoint.shared_state) + await self._shared_state.import_state(decode_checkpoint_value(checkpoint.shared_state)) # Restore executor states using the restored shared state await self._restore_executor_states() # Apply the checkpoint to the context diff --git a/python/samples/_run_all_samples.py b/python/samples/_run_all_samples.py index ba5f21d8a5..7d1a226e5c 100644 --- a/python/samples/_run_all_samples.py +++ b/python/samples/_run_all_samples.py @@ -56,9 +56,9 @@ def run_sample( sample_path: Path, use_uv: bool = True, python_root: Path | None = None, -) -> tuple[bool, str, str]: +) -> tuple[bool, str, str, str]: """ - Run a single sample file using subprocess 
and return (success, output, error_info). + Run a single sample file using subprocess and return (success, output, error_info, error_type). Args: sample_path: Path to the sample file @@ -66,7 +66,8 @@ def run_sample( python_root: Root directory for uv run Returns: - Tuple of (success, output, error_info) + Tuple of (success, output, error_info, error_type) + error_type can be: "timeout", "input_hang", "execution_error", "exception" """ if use_uv and python_root: cmd = ["uv", "run", "python", str(sample_path)] @@ -75,29 +76,69 @@ def run_sample( cmd = [sys.executable, sample_path.name] cwd = sample_path.parent + # Set environment variables to handle Unicode properly + env = os.environ.copy() + env["PYTHONIOENCODING"] = "utf-8" # Force Python to use UTF-8 for I/O + env["PYTHONUTF8"] = "1" # Enable UTF-8 mode in Python 3.7+ + try: - result = subprocess.run( - cmd, - cwd=cwd, - capture_output=True, - text=True, - timeout=60, # 60 second timeout + # Use Popen for better timeout handling with stdin for samples that may wait for input + # Popen gives us more control over process lifecycle compared to subprocess.run() + process = subprocess.Popen( + cmd, # Command to execute as a list [program, arg1, arg2, ...] 
+ cwd=cwd, # Working directory for the subprocess + stdout=subprocess.PIPE, # Capture stdout so we can read the output + stderr=subprocess.PIPE, # Capture stderr so we can read error messages + stdin=subprocess.PIPE, # Create a pipe for stdin so we can send input + text=True, # Handle input/output as text strings (not bytes) + encoding="utf-8", # Use UTF-8 encoding to handle Unicode characters like emojis + errors="replace", # Replace problematic characters instead of failing + env=env, # Pass environment variables for proper Unicode handling ) - if result.returncode == 0: - output = result.stdout.strip() if result.stdout.strip() else "No output" - return True, output, "" - - error_info = f"Exit code: {result.returncode}" - if result.stderr.strip(): - error_info += f"\nSTDERR: {result.stderr}" - - return False, result.stdout.strip() if result.stdout.strip() else "", error_info - - except subprocess.TimeoutExpired: - return False, "", f"TIMEOUT: {sample_path.name} (exceeded 60 seconds)" + try: + # communicate() sends input to stdin and waits for process to complete + # input="" sends an empty string to stdin, which causes input() calls to + # immediately receive EOFError (End Of File) since there's no data to read. + # This prevents the process from hanging indefinitely waiting for user input. + stdout, stderr = process.communicate(input="", timeout=60) + except subprocess.TimeoutExpired: + # If the process doesn't complete within the timeout period, we need to + # forcibly terminate it. This is especially important for processes that + # ignore EOFError and continue to hang on input() calls. 
+ + # First attempt: Send SIGKILL (immediate termination) on Unix or TerminateProcess on Windows + process.kill() + try: + # Give the process a few seconds to clean up after being killed + stdout, stderr = process.communicate(timeout=5) + except subprocess.TimeoutExpired: + # If the process is still alive after kill(), use terminate() as a last resort + # terminate() sends SIGTERM (graceful termination request) which may work + # when kill() doesn't on some systems + process.terminate() + stdout, stderr = "", "Process forcibly terminated" + return False, "", f"TIMEOUT: {sample_path.name} (exceeded 60 seconds)", "timeout" + + if process.returncode == 0: + output = stdout.strip() if stdout.strip() else "No output" + return True, output, "", "success" + + error_info = f"Exit code: {process.returncode}" + if stderr.strip(): + error_info += f"\nSTDERR: {stderr}" + + # Check if this looks like an input/interaction related error + error_type = "execution_error" + stderr_safe = stderr.encode("utf-8", errors="replace").decode("utf-8") if stderr else "" + if "EOFError" in stderr_safe or "input" in stderr_safe.lower() or "stdin" in stderr_safe.lower(): + error_type = "input_hang" + elif "UnicodeEncodeError" in stderr_safe and ("charmap" in stderr_safe or "codec can't encode" in stderr_safe): + error_type = "input_hang" # Unicode errors often indicate interactive samples with emojis + + return False, stdout.strip() if stdout.strip() else "", error_info, error_type except Exception as e: - return False, "", f"ERROR: {sample_path.name} - Exception: {str(e)}" + return False, "", f"ERROR: {sample_path.name} - Exception: {str(e)}", "exception" def parse_arguments() -> argparse.Namespace: @@ -161,7 +202,7 @@ def main() -> None: print(f"Found {len(sample_files)} Python sample files") # Run samples concurrently - results: list[tuple[Path, bool, str, str]] = [] + results: list[tuple[Path, bool, str, str, str]] = [] with ThreadPoolExecutor(max_workers=args.max_workers) as executor: # 
Submit all tasks @@ -174,53 +215,81 @@ def main() -> None: for future in as_completed(future_to_sample): sample_path = future_to_sample[future] try: - success, output, error_info = future.result() - results.append((sample_path, success, output, error_info)) + success, output, error_info, error_type = future.result() + results.append((sample_path, success, output, error_info, error_type)) # Print progress - show relative path from samples directory relative_path = sample_path.relative_to(samples_dir) if success: print(f"✅ {relative_path}") else: - print(f"❌ {relative_path} - {error_info.split(':', 1)[0]}") + # Show error type in progress display + error_display = f"{error_type.upper()}" if error_type != "execution_error" else "ERROR" + print(f"❌ {relative_path} - {error_display}") except Exception as e: error_info = f"Future exception: {str(e)}" - results.append((sample_path, False, "", error_info)) + results.append((sample_path, False, "", error_info, "exception")) relative_path = sample_path.relative_to(samples_dir) - print(f"❌ {relative_path} - {error_info}") + print(f"❌ {relative_path} - EXCEPTION") # Sort results by original file order for consistent reporting sample_to_index = {path: i for i, path in enumerate(sample_files)} results.sort(key=lambda x: sample_to_index[x[0]]) - successful_runs = sum(1 for _, success, _, _ in results if success) + successful_runs = sum(1 for _, success, _, _, _ in results if success) failed_runs = len(results) - successful_runs + # Categorize failures by type + timeout_failures = [r for r in results if not r[1] and r[4] == "timeout"] + input_hang_failures = [r for r in results if not r[1] and r[4] == "input_hang"] + execution_errors = [r for r in results if not r[1] and r[4] == "execution_error"] + exceptions = [r for r in results if not r[1] and r[4] == "exception"] + # Print detailed results print(f"\n{'=' * 80}") print("DETAILED RESULTS:") print(f"{'=' * 80}") - for sample_path, success, output, error_info in results: + for 
sample_path, success, output, error_info, error_type in results: relative_path = sample_path.relative_to(samples_dir) if success: print(f"✅ {relative_path}") if output and output != "No output": print(f" Output preview: {output[:100]}{'...' if len(output) > 100 else ''}") else: - print(f"❌ {relative_path}") + # Display error with type indicator + if error_type == "timeout": + print(f"⏱️ {relative_path} - TIMEOUT (likely waiting for input)") + elif error_type == "input_hang": + print(f"⌨️ {relative_path} - INPUT ERROR (interactive sample)") + elif error_type == "exception": + print(f"💥 {relative_path} - EXCEPTION") + else: + print(f"❌ {relative_path} - EXECUTION ERROR") print(f" Error: {error_info}") - # Print summary + # Print categorized summary print(f"\n{'=' * 80}") if failed_runs == 0: print("🎉 ALL SAMPLES COMPLETED SUCCESSFULLY!") else: print(f"❌ {failed_runs} SAMPLE(S) FAILED!") + print(f"Successful runs: {successful_runs}") print(f"Failed runs: {failed_runs}") + if failed_runs > 0: + print("\nFailure breakdown:") + if len(timeout_failures) > 0: + print(f" ⏱️ Timeouts (likely interactive): {len(timeout_failures)}") + if len(input_hang_failures) > 0: + print(f" ⌨️ Input errors (interactive): {len(input_hang_failures)}") + if len(execution_errors) > 0: + print(f" ❌ Execution errors: {len(execution_errors)}") + if len(exceptions) > 0: + print(f" 💥 Exceptions: {len(exceptions)}") + if args.subdir: print(f"Subdirectory filter: {args.subdir}") diff --git a/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py b/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py index 7bdf3e922a..79028cc325 100644 --- a/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py +++ b/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py @@ -103,7 +103,7 @@ async def main() -> None: worker = Worker(id="sub-worker", chat_client=mini_chat_client) reviewer = 
ReviewerWithHumanInTheLoop(worker_id=worker.id) - print("Building workflow with Worker ↔ Reviewer cycle...") + print("Building workflow with Worker-Reviewer cycle...") # Build a workflow with bidirectional communication between Worker and Reviewer, # and escalation paths for human review. agent = ( From 042f61041a9f54cde8b5897b495e97b97c691381 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 24 Oct 2025 16:58:38 -0700 Subject: [PATCH 12/26] Fix unit tests --- .../core/agent_framework/_workflows/_edge.py | 2 + .../agent_framework/_workflows/_handoff.py | 4 +- .../agent_framework/_workflows/_magentic.py | 6 +- .../_workflows/_typing_utils.py | 2 +- .../tests/workflow/test_checkpoint_decode.py | 93 +++++- .../core/tests/workflow/test_executor.py | 8 +- .../tests/workflow/test_function_executor.py | 29 +- .../core/tests/workflow/test_magentic.py | 11 +- .../test_request_info_event_rehydrate.py | 169 +++++++++++ .../test_request_info_executor_rehydrate.py | 285 ------------------ .../core/tests/workflow/test_serialization.py | 57 +++- .../core/tests/workflow/test_sub_workflow.py | 271 +++++++++-------- .../core/tests/workflow/test_typing_utils.py | 17 -- .../core/tests/workflow/test_validation.py | 4 +- .../core/tests/workflow/test_workflow.py | 74 +++-- .../tests/workflow/test_workflow_agent.py | 22 +- .../tests/workflow/test_workflow_builder.py | 2 +- .../tests/workflow/test_workflow_states.py | 54 ++-- 18 files changed, 556 insertions(+), 554 deletions(-) create mode 100644 python/packages/core/tests/workflow/test_request_info_event_rehydrate.py delete mode 100644 python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 22bd0255ff..70eafc7d38 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -868,6 +868,8 @@ def to_dict(self) -> 
dict[str, Any]: return payload +@EdgeGroup.register +@dataclass(init=False) class InternalEdgeGroup(EdgeGroup): """Special edge group used to route internal messages to executors. diff --git a/python/packages/core/agent_framework/_workflows/_handoff.py b/python/packages/core/agent_framework/_workflows/_handoff.py index dbb7acc9ef..0db50b1426 100644 --- a/python/packages/core/agent_framework/_workflows/_handoff.py +++ b/python/packages/core/agent_framework/_workflows/_handoff.py @@ -540,13 +540,13 @@ async def request_input(self, conversation: list[ChatMessage], ctx: WorkflowCont prompt=self._prompt, source_executor_id=self.id, ) - await ctx.request_info(request, HandoffUserInputRequest, Any) + await ctx.request_info(request, HandoffUserInputRequest, object) @response_handler async def resume_from_user( self, original_request: HandoffUserInputRequest, - response: Any, + response: object, ctx: WorkflowContext[_ConversationWithUserInput], ) -> None: """Convert user input responses back into chat messages and resume the workflow.""" diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index acb41b2cef..9130fdbc3b 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -2149,11 +2149,10 @@ async def run_stream_from_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None = None, - responses: dict[str, Any] | None = None, ) -> AsyncIterable[WorkflowEvent]: """Resume orchestration from a checkpoint and stream resulting events.""" await self._validate_checkpoint_participants(checkpoint_id, checkpoint_storage) - async for event in self._workflow.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage, responses): + async for event in self._workflow.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage): yield event async def run_with_string(self, task_text: str) -> 
WorkflowRunResult: @@ -2203,11 +2202,10 @@ async def run_from_checkpoint( self, checkpoint_id: str, checkpoint_storage: CheckpointStorage | None = None, - responses: dict[str, Any] | None = None, ) -> WorkflowRunResult: """Resume orchestration from a checkpoint and collect all resulting events.""" events: list[WorkflowEvent] = [] - async for event in self.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage, responses): + async for event in self.run_stream_from_checkpoint(checkpoint_id, checkpoint_storage): events.append(event) return WorkflowRunResult(events) diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index 4212a7f5d3..f2e355d5b3 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -145,7 +145,7 @@ def deserialize_type(serialized_type_string: str) -> type: """ import importlib - module_name, _, type_name = serialized_type_string.partition(".") + module_name, _, type_name = serialized_type_string.rpartition(".") module = importlib.import_module(module_name) return getattr(module, type_name) diff --git a/python/packages/core/tests/workflow/test_checkpoint_decode.py b/python/packages/core/tests/workflow/test_checkpoint_decode.py index 16d7a17c7a..b126eafacf 100644 --- a/python/packages/core/tests/workflow/test_checkpoint_decode.py +++ b/python/packages/core/tests/workflow/test_checkpoint_decode.py @@ -3,7 +3,6 @@ from dataclasses import dataclass # noqa: I001 from typing import Any, cast -from agent_framework._workflows._request_info_executor import RequestInfoMessage, RequestResponse from agent_framework._workflows._checkpoint_encoding import ( decode_checkpoint_value, encode_checkpoint_value, @@ -11,30 +10,45 @@ from agent_framework._workflows._typing_utils import is_instance_of -@dataclass(kw_only=True) -class SampleRequest(RequestInfoMessage): +@dataclass 
+class SampleRequest: + """Sample request message for testing checkpoint encoding/decoding.""" + + request_id: str prompt: str +@dataclass +class SampleResponse: + """Sample response message for testing checkpoint encoding/decoding.""" + + data: str + original_request: SampleRequest + request_id: str + + def test_decode_dataclass_with_nested_request() -> None: - original = RequestResponse[SampleRequest, str]( + """Test that dataclass with nested dataclass fields can be encoded and decoded correctly.""" + original = SampleResponse( data="approve", original_request=SampleRequest(request_id="abc", prompt="prompt"), request_id="abc", ) encoded = encode_checkpoint_value(original) - decoded = cast(RequestResponse[SampleRequest, str], decode_checkpoint_value(encoded)) + decoded = cast(SampleResponse, decode_checkpoint_value(encoded)) - assert isinstance(decoded, RequestResponse) + assert isinstance(decoded, SampleResponse) assert decoded.data == "approve" assert decoded.request_id == "abc" assert isinstance(decoded.original_request, SampleRequest) assert decoded.original_request.prompt == "prompt" + assert decoded.original_request.request_id == "abc" -def test_is_instance_of_coerces_request_response_original_request_dict() -> None: - response = RequestResponse[SampleRequest, str]( +def test_is_instance_of_coerces_nested_dataclass_dict() -> None: + """Test that is_instance_of can handle nested structures with dict conversion.""" + response = SampleResponse( data="approve", original_request=SampleRequest(request_id="req-1", prompt="prompt"), request_id="req-1", @@ -49,5 +63,66 @@ def test_is_instance_of_coerces_request_response_original_request_dict() -> None }, ) - assert is_instance_of(response, RequestResponse[SampleRequest, str]) + assert is_instance_of(response, SampleResponse) + assert isinstance(response.original_request, dict) + + # Verify the dict contains expected values + dict_request = cast(dict[str, Any], response.original_request) + assert 
dict_request["request_id"] == "req-1" + assert dict_request["prompt"] == "prompt" + + +def test_encode_decode_simple_dataclass() -> None: + """Test encoding and decoding of a simple dataclass.""" + original = SampleRequest(request_id="test-123", prompt="test prompt") + + encoded = encode_checkpoint_value(original) + decoded = cast(SampleRequest, decode_checkpoint_value(encoded)) + + assert isinstance(decoded, SampleRequest) + assert decoded.request_id == "test-123" + assert decoded.prompt == "test prompt" + + +def test_encode_decode_nested_structures() -> None: + """Test encoding and decoding of complex nested structures.""" + nested_data = { + "requests": [ + SampleRequest(request_id="req-1", prompt="first prompt"), + SampleRequest(request_id="req-2", prompt="second prompt"), + ], + "responses": { + "req-1": SampleResponse( + data="first response", + original_request=SampleRequest(request_id="req-1", prompt="first prompt"), + request_id="req-1", + ), + }, + } + + encoded = encode_checkpoint_value(nested_data) + decoded = decode_checkpoint_value(encoded) + + assert isinstance(decoded, dict) + assert "requests" in decoded + assert "responses" in decoded + + # Check the requests list + requests = cast(list[Any], decoded["requests"]) + assert isinstance(requests, list) + assert len(requests) == 2 + assert all(isinstance(req, SampleRequest) for req in requests) + first_request = cast(SampleRequest, requests[0]) + second_request = cast(SampleRequest, requests[1]) + assert first_request.request_id == "req-1" + assert second_request.request_id == "req-2" + + # Check the responses dict + responses = cast(dict[str, Any], decoded["responses"]) + assert isinstance(responses, dict) + assert "req-1" in responses + response = cast(SampleResponse, responses["req-1"]) + assert isinstance(response, SampleResponse) + assert response.data == "first response" assert isinstance(response.original_request, SampleRequest) + assert response.original_request.request_id == "req-1" diff --git 
a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 7f3d5bfc3e..952d6bab60 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -2,7 +2,7 @@ import pytest -from agent_framework import Executor, WorkflowContext, handler +from agent_framework import Executor, Message, WorkflowContext, handler def test_executor_without_id(): @@ -64,9 +64,9 @@ async def handle_number(self, number: int, ctx: WorkflowContext) -> None: # typ executor = MockExecutorWithValidHandlers(id="test") assert executor.id is not None assert len(executor._handlers) == 2 # type: ignore - assert executor.can_handle("text") is True - assert executor.can_handle(42) is True - assert executor.can_handle(3.14) is False + assert executor.can_handle(Message(data="text", source_id="mock")) is True + assert executor.can_handle(Message(data=42, source_id="mock")) is True + assert executor.can_handle(Message(data=3.14, source_id="mock")) is False def test_executor_handlers_with_output_types(): diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index 229e399b09..fcf82ebca4 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -7,6 +7,7 @@ from agent_framework import ( FunctionExecutor, + Message, WorkflowBuilder, WorkflowContext, executor, @@ -230,9 +231,9 @@ def test_can_handle_method(self): async def string_processor(text: str, ctx: WorkflowContext[str]) -> None: await ctx.send_message(text) - assert string_processor.can_handle("hello") - assert not string_processor.can_handle(123) - assert not string_processor.can_handle([]) + assert string_processor.can_handle(Message(data="hello", source_id="Mock")) + assert not string_processor.can_handle(Message(data=123, source_id="Mock")) + assert not 
string_processor.can_handle(Message(data=[], source_id="Mock")) def test_duplicate_handler_registration(self): """Test that registering duplicate handlers raises an error.""" @@ -309,9 +310,9 @@ def test_single_parameter_can_handle(self): async def int_processor(value: int): return value * 2 - assert int_processor.can_handle(42) - assert not int_processor.can_handle("hello") - assert not int_processor.can_handle([]) + assert int_processor.can_handle(Message(data=42, source_id="mock")) + assert not int_processor.can_handle(Message(data="hello", source_id="mock")) + assert not int_processor.can_handle(Message(data=[], source_id="mock")) async def test_single_parameter_execution(self): """Test that single-parameter functions can be executed properly.""" @@ -325,7 +326,7 @@ async def double_value(value: int): WorkflowBuilder().set_start_executor(double_value).build() # For testing purposes, we can check that the handler is registered correctly - assert double_value.can_handle(5) + assert double_value.can_handle(Message(data=5, source_id="mock")) assert int in double_value._handlers def test_sync_function_basic(self): @@ -369,9 +370,9 @@ def test_sync_function_can_handle(self): def string_handler(text: str): return text.strip() - assert string_handler.can_handle("hello") - assert not string_handler.can_handle(123) - assert not string_handler.can_handle([]) + assert string_handler.can_handle(Message(data="hello", source_id="mock")) + assert not string_handler.can_handle(Message(data=123, source_id="mock")) + assert not string_handler.can_handle(Message(data=[], source_id="mock")) def test_sync_function_validation(self): """Test validation for synchronous functions.""" @@ -413,8 +414,8 @@ async def async_func(data: str): assert isinstance(async_func, FunctionExecutor) # Both should handle strings - assert sync_func.can_handle("test") - assert async_func.can_handle("test") + assert sync_func.can_handle(Message(data="test", source_id="mock")) + assert 
async_func.can_handle(Message(data="test", source_id="mock")) # Both should be different instances assert sync_func is not async_func @@ -443,8 +444,8 @@ async def reverse_async(text: str, ctx: WorkflowContext[Any, str]): assert async_spec["workflow_output_types"] == [str] # Second parameter is str # Verify the executors can handle their input types - assert to_upper_sync.can_handle("hello") - assert reverse_async.can_handle("HELLO") + assert to_upper_sync.can_handle(Message(data="hello", source_id="mock")) + assert reverse_async.can_handle(Message(data="HELLO", source_id="mock")) # For integration testing, we mainly verify that the handlers are properly registered # and the functions are wrapped correctly diff --git a/python/packages/core/tests/workflow/test_magentic.py b/python/packages/core/tests/workflow/test_magentic.py index b52449a928..7f6f71df13 100644 --- a/python/packages/core/tests/workflow/test_magentic.py +++ b/python/packages/core/tests/workflow/test_magentic.py @@ -310,7 +310,6 @@ async def test_magentic_checkpoint_resume_round_trip(): async for ev in wf.run_stream(task_text): if isinstance(ev, RequestInfoEvent) and ev.request_type is MagenticPlanReviewRequest: req_event = ev - break assert req_event is not None checkpoints = await storage.list_checkpoints() @@ -334,10 +333,16 @@ async def test_magentic_checkpoint_resume_round_trip(): reply = MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE) completed: WorkflowOutputEvent | None = None + req_event = None async for event in wf_resume.workflow.run_stream_from_checkpoint( resume_checkpoint.checkpoint_id, - responses={req_event.request_id: reply}, ): + if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: + req_event = event + assert req_event is not None + + responses = {req_event.request_id: reply} + async for event in wf_resume.workflow.send_responses_streaming(responses=responses): if isinstance(event, WorkflowOutputEvent): completed = event 
assert completed is not None @@ -669,7 +674,6 @@ async def test_magentic_checkpoint_resume_rejects_participant_renames(): async for event in workflow.run_stream("task"): if isinstance(event, RequestInfoEvent) and event.request_type is MagenticPlanReviewRequest: req_event = event - break assert req_event is not None @@ -688,7 +692,6 @@ async def test_magentic_checkpoint_resume_rejects_participant_renames(): with pytest.raises(RuntimeError, match="participant names do not match"): async for _ in renamed_workflow.run_stream_from_checkpoint( target_checkpoint.checkpoint_id, # type: ignore[reportUnknownMemberType] - responses={req_event.request_id: MagenticPlanReviewReply(decision=MagenticPlanReviewDecision.APPROVE)}, ): pass diff --git a/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py new file mode 100644 index 0000000000..80dc7e3004 --- /dev/null +++ b/python/packages/core/tests/workflow/test_request_info_event_rehydrate.py @@ -0,0 +1,169 @@ +# Copyright (c) Microsoft. All rights reserved. + +import json +from dataclasses import dataclass, field +from datetime import datetime, timezone + +import pytest + +from agent_framework import InMemoryCheckpointStorage, InProcRunnerContext +from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value +from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary +from agent_framework._workflows._events import RequestInfoEvent +from agent_framework._workflows._shared_state import SharedState + + +@dataclass +class MockRequest: ... 
+ + +@dataclass(kw_only=True) +class SimpleApproval: + prompt: str = "" + draft: str = "" + iteration: int = 0 + + +@dataclass(slots=True) +class SlottedApproval: + note: str = "" + + +@dataclass +class TimedApproval: + issued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + + +async def test_rehydrate_request_info_event() -> None: + """Rehydration should succeed for valid request info events.""" + request_info_event = RequestInfoEvent( + request_id="request-123", + source_executor_id="review_gateway", + request_type=MockRequest, + request_data=MockRequest(), + response_type=bool, + ) + + runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) + await runner_context.add_request_info_event(request_info_event) + + checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint = await runner_context.load_checkpoint(checkpoint_id) + + assert checkpoint is not None + assert checkpoint.pending_request_info_events + assert "request-123" in checkpoint.pending_request_info_events + assert "request_type" in checkpoint.pending_request_info_events["request-123"] + + # Rehydrate the context + await runner_context.apply_checkpoint(checkpoint) + + pending_requests = await runner_context.get_pending_request_info_events() + assert "request-123" in pending_requests + rehydrated_event = pending_requests["request-123"] + assert rehydrated_event.request_id == "request-123" + assert rehydrated_event.source_executor_id == "review_gateway" + assert rehydrated_event.request_type is MockRequest + assert rehydrated_event.response_type is bool + assert isinstance(rehydrated_event.data, MockRequest) + + +async def test_rehydrate_fails_when_request_type_missing() -> None: + """Rehydration should fail is the request type is missing or fails to import.""" + request_info_event = RequestInfoEvent( + request_id="request-123", + source_executor_id="review_gateway", + request_type=MockRequest, + request_data=MockRequest(), + 
response_type=bool, + ) + + runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) + await runner_context.add_request_info_event(request_info_event) + + checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint = await runner_context.load_checkpoint(checkpoint_id) + + assert checkpoint is not None + assert checkpoint.pending_request_info_events + assert "request-123" in checkpoint.pending_request_info_events + assert "request_type" in checkpoint.pending_request_info_events["request-123"] + + # Modify the checkpoint to simulate missing request type + checkpoint.pending_request_info_events["request-123"]["request_type"] = "nonexistent.module:MissingRequest" + + # Rehydrate the context + with pytest.raises(ImportError): + await runner_context.apply_checkpoint(checkpoint) + + +async def test_pending_requests_in_summary() -> None: + """Test that pending requests are correctly summarized in the checkpoint summary.""" + request_info_event = RequestInfoEvent( + request_id="request-123", + source_executor_id="review_gateway", + request_type=MockRequest, + request_data=MockRequest(), + response_type=bool, + ) + + runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) + await runner_context.add_request_info_event(request_info_event) + + checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint = await runner_context.load_checkpoint(checkpoint_id) + + assert checkpoint is not None + summary = get_checkpoint_summary(checkpoint) + + assert summary.checkpoint_id == checkpoint_id + assert summary.status == "awaiting request response" + + assert len(summary.pending_request_info_events) == 1 + pending_event = summary.pending_request_info_events[0] + assert isinstance(pending_event, RequestInfoEvent) + assert pending_event.request_id == "request-123" + + assert pending_event.source_executor_id == "review_gateway" + assert pending_event.request_type is MockRequest + assert 
pending_event.response_type is bool + assert isinstance(pending_event.data, MockRequest) + + +async def test_request_info_event_serializes_non_json_payloads() -> None: + req_1 = RequestInfoEvent( + request_id="req-1", + source_executor_id="source", + request_type=TimedApproval, + request_data=TimedApproval(issued_at=datetime(2024, 5, 4, 12, 30, 45)), + response_type=bool, + ) + req_2 = RequestInfoEvent( + request_id="req-2", + source_executor_id="source", + request_type=SlottedApproval, + request_data=SlottedApproval(note="slot-based"), + response_type=bool, + ) + + runner_context = InProcRunnerContext(InMemoryCheckpointStorage()) + await runner_context.add_request_info_event(req_1) + await runner_context.add_request_info_event(req_2) + + checkpoint_id = await runner_context.create_checkpoint(SharedState(), iteration_count=1) + checkpoint = await runner_context.load_checkpoint(checkpoint_id) + + # Should be JSON serializable despite datetime/slots + serialized = json.dumps(encode_checkpoint_value(checkpoint)) + deserialized = json.loads(serialized) + + assert "value" in deserialized + deserialized = deserialized["value"] + + assert "pending_request_info_events" in deserialized + pending_request_info_events = deserialized["pending_request_info_events"] + assert "req-1" in pending_request_info_events + assert isinstance(pending_request_info_events["req-1"]["data"]["value"]["issued_at"], str) + + assert "req-2" in pending_request_info_events + assert pending_request_info_events["req-2"]["data"]["value"]["note"] == "slot-based" diff --git a/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py b/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py deleted file mode 100644 index 91bc716829..0000000000 --- a/python/packages/core/tests/workflow/test_request_info_executor_rehydrate.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -import json -from dataclasses import dataclass, field -from datetime import datetime, timezone -from typing import Any - -from agent_framework._workflows._checkpoint import WorkflowCheckpoint -from agent_framework._workflows._checkpoint_encoding import encode_checkpoint_value -from agent_framework._workflows._checkpoint_summary import get_checkpoint_summary -from agent_framework._workflows._const import EXECUTOR_STATE_KEY -from agent_framework._workflows._events import RequestInfoEvent, WorkflowEvent -from agent_framework._workflows._request_info_executor import ( - PendingRequestDetails, - PendingRequestSnapshot, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, -) -from agent_framework._workflows._runner_context import Message -from agent_framework._workflows._shared_state import SharedState -from agent_framework._workflows._workflow_context import WorkflowContext - -PENDING_STATE_KEY = RequestInfoExecutor._PENDING_SHARED_STATE_KEY # pyright: ignore[reportPrivateUsage] - - -class _StubRunnerContext: - """Minimal runner context stub for exercising WorkflowContext helpers.""" - - async def send_message(self, message: Message) -> None: # pragma: no cover - unused in tests - return None - - async def drain_messages(self) -> dict[str, list[Message]]: # pragma: no cover - unused - return {} - - async def has_messages(self) -> bool: # pragma: no cover - unused - return False - - async def add_event(self, event: WorkflowEvent) -> None: # pragma: no cover - unused - return None - - async def drain_events(self) -> list[WorkflowEvent]: # pragma: no cover - unused - return [] - - async def has_events(self) -> bool: # pragma: no cover - unused - return False - - async def next_event(self) -> WorkflowEvent: # pragma: no cover - unused - raise RuntimeError("Not implemented in stub context") - - def has_checkpointing(self) -> bool: # pragma: no cover - unused - return False - - def set_workflow_id(self, workflow_id: str) -> None: # pragma: no cover - unused - 
pass - - def reset_for_new_run(self) -> None: # pragma: no cover - unused - pass - - async def create_checkpoint( - self, - shared_state: SharedState, - iteration_count: int, - metadata: dict[str, Any] | None = None, - ) -> str: # pragma: no cover - unused - raise RuntimeError("Checkpointing not supported in stub context") - - async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pragma: no cover - unused - return None - - async def apply_checkpoint(self, checkpoint: WorkflowCheckpoint) -> None: # pragma: no cover - unused - pass - - def set_streaming(self, streaming: bool) -> None: # pragma: no cover - unused - pass - - def is_streaming(self) -> bool: # pragma: no cover - unused - return False - - -@dataclass(kw_only=True) -class SimpleApproval(RequestInfoMessage): - prompt: str = "" - draft: str = "" - iteration: int = 0 - - -@dataclass(slots=True) -class SlottedApproval(RequestInfoMessage): - note: str = "" - - -@dataclass -class TimedApproval(RequestInfoMessage): - issued_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - - -async def test_rehydrate_falls_back_when_request_type_missing() -> None: - """Rehydration should succeed even if the original request type cannot be imported. - - This simulates resuming a workflow where the HumanApprovalRequest class is unavailable - in the current process (e.g., defined in __main__ during the original run). 
- """ - request_id = "request-123" - snapshot = PendingRequestSnapshot( - request_id=request_id, - source_executor_id="review_gateway", - request_type="nonexistent.module:MissingRequest", - request_as_json_safe_dict={ - "request_id": request_id, - }, - ) - - ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], SharedState(), _StubRunnerContext()) - await ctx.set_executor_state({PENDING_STATE_KEY: {request_id: snapshot}}) - - executor = RequestInfoExecutor(id="request_info") - - event = await executor._rehydrate_request_event(request_id, ctx) # pyright: ignore[reportPrivateUsage] - - assert event is not None - assert event.request_id == request_id - assert isinstance(event.data, RequestInfoMessage) - - -async def test_has_pending_request_detects_snapshot() -> None: - request_id = "request-123" - snapshot = PendingRequestSnapshot( - request_id=request_id, - source_executor_id="review_gateway", - request_type="nonexistent.module:MissingRequest", - request_as_json_safe_dict={ - "request_id": request_id, - }, - ) - - ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], SharedState(), _StubRunnerContext()) - await ctx.set_executor_state({PENDING_STATE_KEY: {request_id: snapshot}}) - - executor = RequestInfoExecutor(id="request_info") - - assert await executor.has_pending_request(request_id, ctx) - - -async def test_has_pending_request_false_when_snapshot_absent() -> None: - ctx: WorkflowContext[Any] = WorkflowContext("request_info", ["workflow"], SharedState(), _StubRunnerContext()) - await ctx.set_executor_state({PENDING_STATE_KEY: {}}) - - executor = RequestInfoExecutor(id="request_info") - - assert not await executor.has_pending_request("missing", ctx) - - -def test_pending_requests_from_checkpoint_and_summary() -> None: - request = SimpleApproval(prompt="Review draft", draft="Draft text", iteration=3) - request.request_id = "req-42" - - response = RequestResponse[SimpleApproval, str]( - data="approve", - 
original_request=request, - request_id=request.request_id, - ) - - encoded_response = encode_checkpoint_value(response) - - checkpoint = WorkflowCheckpoint( - checkpoint_id="cp-1", - workflow_id="wf", - messages={ - "request_info": [ - { - "data": encoded_response, - "source_id": "request_info", - "target_id": "review_gateway", - } - ] - }, - shared_state={ - PENDING_STATE_KEY: { - request.request_id: { - "request_id": request.request_id, - "prompt": request.prompt, - "draft": request.draft, - "iteration": request.iteration, - "source_executor_id": "review_gateway", - } - } - }, - iteration_count=1, - ) - - summary = get_checkpoint_summary(checkpoint) - assert summary.checkpoint_id == "cp-1" - assert summary.status == "awaiting request response" - assert summary.pending_requests[0].request_id == "req-42" - - pending = summary.pending_requests - assert len(pending) == 1 - entry = pending[0] - assert isinstance(entry, PendingRequestDetails) - assert entry.request_id == "req-42" - assert entry.prompt == "Review draft" - assert entry.draft == "Draft text" - assert entry.iteration == 3 - assert entry.original_request is not None - - -def test_snapshot_state_serializes_non_json_payloads() -> None: - executor = RequestInfoExecutor(id="request_info") - - timed = TimedApproval(issued_at=datetime(2024, 5, 4, 12, 30, 45)) - timed.request_id = "timed" - slotted = SlottedApproval(note="slot-based") - slotted.request_id = "slotted" - - executor._request_events = { # pyright: ignore[reportPrivateUsage] - timed.request_id: RequestInfoEvent( - request_id=timed.request_id, - source_executor_id="source", - request_type=TimedApproval, - request_data=timed, - ), - slotted.request_id: RequestInfoEvent( - request_id=slotted.request_id, - source_executor_id="source", - request_type=SlottedApproval, - request_data=slotted, - ), - } - - state = executor.snapshot_state() - - # Should be JSON serializable despite datetime/slots - serialized = json.dumps(state) - assert "timed" in serialized - 
timed_payload = state["request_events"][timed.request_id]["request_data"]["value"] - assert isinstance(timed_payload["issued_at"], str) - - -def test_restore_state_falls_back_to_base_request_type() -> None: - executor = RequestInfoExecutor(id="request_info") - - approval = SimpleApproval(prompt="Review", draft="Draft", iteration=1) - approval.request_id = "req" - executor._request_events = { # pyright: ignore[reportPrivateUsage] - approval.request_id: RequestInfoEvent( - request_id=approval.request_id, - source_executor_id="source", - request_type=SimpleApproval, - request_data=approval, - ) - } - - state = executor.snapshot_state() - state["request_events"][approval.request_id]["request_type"] = "missing.module:GhostRequest" - - executor.restore_state(state) - - restored = executor._request_events[approval.request_id] # pyright: ignore[reportPrivateUsage] - assert restored.request_type is RequestInfoMessage - assert isinstance(restored.data, RequestInfoMessage) - - -async def test_run_persists_pending_requests_in_runner_state() -> None: - shared_state = SharedState() - runner_ctx = _StubRunnerContext() - ctx: WorkflowContext[None] = WorkflowContext("request_info", ["source"], shared_state, runner_ctx) - - executor = RequestInfoExecutor(id="request_info") - approval = SimpleApproval(prompt="Review", draft="Draft", iteration=1) - approval.request_id = "req-123" - - await executor.execute(approval, ctx.source_executor_ids, shared_state, runner_ctx) - - # Runner state should include both pending snapshot and serialized request events - assert await shared_state.has(EXECUTOR_STATE_KEY) - executor_state = await shared_state.get(EXECUTOR_STATE_KEY) - assert executor.id in executor_state - assert PENDING_STATE_KEY in executor_state[executor.id] - assert approval.request_id in executor_state[executor.id][PENDING_STATE_KEY] - - response_ctx: WorkflowContext[None] = WorkflowContext("request_info", ["source"], shared_state, runner_ctx) - await 
executor.handle_response("approved", approval.request_id, response_ctx) # type: ignore - - assert executor_state[executor.id][PENDING_STATE_KEY] == {} diff --git a/python/packages/core/tests/workflow/test_serialization.py b/python/packages/core/tests/workflow/test_serialization.py index 08245ef534..2bb8f305e9 100644 --- a/python/packages/core/tests/workflow/test_serialization.py +++ b/python/packages/core/tests/workflow/test_serialization.py @@ -6,12 +6,14 @@ import pytest from agent_framework import Executor, WorkflowBuilder, WorkflowContext, handler +from agent_framework._workflows._const import INTERNAL_SOURCE_ID from agent_framework._workflows._edge import ( Case, Default, Edge, FanInEdgeGroup, FanOutEdgeGroup, + InternalEdgeGroup, SingleEdgeGroup, SwitchCaseEdgeGroup, SwitchCaseEdgeGroupCase, @@ -557,16 +559,32 @@ def test_workflow_serialization(self) -> None: # Verify edge groups contain edges edge_groups = data["edge_groups"] - assert len(edge_groups) == 1, "Should have exactly one edge group" - edge_group = edge_groups[0] - assert "edges" in edge_group, "Edge group should contain 'edges' field" - assert len(edge_group["edges"]) == 1, "Should have exactly one edge" - edge = edge_group["edges"][0] - assert "source_id" in edge, "Edge should have source_id" - assert "target_id" in edge, "Edge should have target_id" - assert edge["source_id"] == "executor1", f"Expected source_id 'executor1', got {edge['source_id']}" - assert edge["target_id"] == "executor2", f"Expected target_id 'executor2', got {edge['target_id']}" + single_edge_groups = [SingleEdgeGroup.from_dict(eg) for eg in edge_groups if eg["type"] == "SingleEdgeGroup"] + internal_edge_groups = [ + InternalEdgeGroup.from_dict(eg) for eg in edge_groups if eg["type"] == "InternalEdgeGroup" + ] + + assert len(single_edge_groups) == 1, "Should have exactly one SingleEdgeGroup for the added edge" + assert len(internal_edge_groups) == 2, ( + "Should have exactly two (one per executor) InternalEdgeGroups for 
request/response handling" + ) + + for edge_group in single_edge_groups: + assert len(edge_group.edges) == 1, "Should have exactly one edge" + + edge = edge_group.edges[0] + + assert edge.source_id == "executor1", f"Expected source_id 'executor1', got {edge.source_id}" + assert edge.target_id == "executor2", f"Expected target_id 'executor2', got {edge.target_id}" + + for edge_group in internal_edge_groups: + assert len(edge_group.edges) == 1, "Each InternalEdgeGroup should have exactly one edge" + + edge = edge_group.edges[0] + + assert edge.source_id == INTERNAL_SOURCE_ID(edge.target_id) + assert edge.target_id in [executor1.id, executor2.id] # Test model_dump_json json_str = workflow.to_json() @@ -577,12 +595,21 @@ def test_workflow_serialization(self) -> None: # Verify edges are preserved in JSON serialization json_edge_groups = parsed["edge_groups"] - assert len(json_edge_groups) == 1, "JSON should have exactly one edge group" - json_edge_group = json_edge_groups[0] - assert "edges" in json_edge_group, "JSON edge group should contain 'edges' field" - json_edge = json_edge_group["edges"][0] - assert json_edge["source_id"] == "executor1", "JSON should preserve edge source_id" - assert json_edge["target_id"] == "executor2", "JSON should preserve edge target_id" + assert len(json_edge_groups) == 1 + 2, "JSON should have exactly one SingleEdgeGroup and two InternalEdgeGroups" + + for json_edge_group in json_edge_groups: + assert "edges" in json_edge_group, "JSON edge group should contain 'edges' field" + assert len(json_edge_group["edges"]) == 1, "Each JSON edge group should have exactly one edge" + if json_edge_group["type"] == "SingleEdgeGroup": + json_edge = json_edge_group["edges"][0] + assert json_edge["source_id"] == "executor1", "JSON should preserve edge source_id" + assert json_edge["target_id"] == "executor2", "JSON should preserve edge target_id" + elif json_edge_group["type"] == "InternalEdgeGroup": + json_edge = json_edge_group["edges"][0] + assert 
json_edge["source_id"] == INTERNAL_SOURCE_ID(json_edge["target_id"]) + assert json_edge["target_id"] in [executor1.id, executor2.id] + else: + pytest.fail(f"Unexpected edge group type: {json_edge_group['type']}") def test_workflow_serialization_excludes_non_serializable_fields(self) -> None: """Test that non-serializable fields are excluded from serialization.""" diff --git a/python/packages/core/tests/workflow/test_sub_workflow.py b/python/packages/core/tests/workflow/test_sub_workflow.py index 2c787fe658..8fdde7f62e 100644 --- a/python/packages/core/tests/workflow/test_sub_workflow.py +++ b/python/packages/core/tests/workflow/test_sub_workflow.py @@ -1,20 +1,20 @@ # Copyright (c) Microsoft. All rights reserved. -from dataclasses import dataclass -from typing import Any +from dataclasses import dataclass, field +from uuid import uuid4 from typing_extensions import Never from agent_framework import ( Executor, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, + SubWorkflowRequestMessage, + SubWorkflowResponseMessage, Workflow, WorkflowBuilder, WorkflowContext, WorkflowExecutor, handler, + response_handler, ) @@ -27,9 +27,10 @@ class EmailValidationRequest: @dataclass -class DomainCheckRequest(RequestInfoMessage): +class DomainCheckRequest: """Request to check if a domain is approved.""" + id: str = field(default_factory=lambda: str(uuid4())) domain: str = "" email: str = "" # Include original email for correlation @@ -43,72 +44,93 @@ class ValidationResult: reason: str -# Test helper functions -def create_email_validation_workflow() -> Workflow: - """Create a standard email validation workflow.""" - email_validator = EmailValidator() - email_request_info = RequestInfoExecutor(id="email_request_info") - - return ( - WorkflowBuilder() - .set_start_executor(email_validator) - .add_edge(email_validator, email_request_info) - .add_edge(email_request_info, email_validator) - .build() - ) - - -class BasicParent(Executor): - """Basic parent executor for 
simple sub-workflow tests.""" +class Coordinator(Executor): + """Coordinator executor in the parent workflow for simple sub-workflow tests.""" def __init__(self, cache: dict[str, bool] | None = None) -> None: super().__init__(id="basic_parent") self.result: ValidationResult | None = None self.cache: dict[str, bool] = dict(cache) if cache is not None else {} + self._pending_sub_workflow_requests: dict[str, SubWorkflowRequestMessage] = {} @handler async def start(self, email: str, ctx: WorkflowContext[EmailValidationRequest]) -> None: request = EmailValidationRequest(email=email) - await ctx.send_message(request, target_id="email_workflow") + await ctx.send_message(request) @handler async def handle_domain_request( self, - request: DomainCheckRequest, - ctx: WorkflowContext[RequestResponse[DomainCheckRequest, Any] | DomainCheckRequest], + sub_workflow_request: SubWorkflowRequestMessage, + ctx: WorkflowContext[SubWorkflowResponseMessage], ) -> None: """Handle requests from sub-workflows with optional caching.""" - domain_request = request + if not isinstance(sub_workflow_request.source_event.data, DomainCheckRequest): + raise ValueError("Unexpected request type") + + domain_request = sub_workflow_request.source_event.data if domain_request.domain in self.cache: # Return cached result - response = RequestResponse( - data=self.cache[domain_request.domain], original_request=request, request_id=request.request_id - ) - await ctx.send_message(response, target_id=request.source_executor_id) + await ctx.send_message(sub_workflow_request.create_response(self.cache[domain_request.domain])) else: # Not in cache, forward to external - await ctx.send_message(request) + self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request + await ctx.request_info(domain_request, DomainCheckRequest, bool) + + @response_handler + async def handle_domain_response( + self, + original_request: DomainCheckRequest, + is_approved: bool, + ctx: 
WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + """Handle domain check response with correlation and send the response back to the sub-workflow.""" + if original_request.id not in self._pending_sub_workflow_requests: + raise ValueError("No pending sub-workflow request for the given domain check response") + + sub_workflow_request = self._pending_sub_workflow_requests.pop(original_request.id) + await ctx.send_message(sub_workflow_request.create_response(is_approved)) @handler async def collect(self, result: ValidationResult, ctx: WorkflowContext) -> None: self.result = result -# Test executors -class EmailValidator(Executor): +class EmailFormatValidator(Executor): + """Validates the format of an email address.""" + + def __init__(self): + super().__init__(id="email_format_validator") + + @handler + async def validate( + self, request: EmailValidationRequest, ctx: WorkflowContext[DomainCheckRequest, ValidationResult] + ) -> None: + """Validate email format and extract domain.""" + email = request.email + if "@" not in email: + result = ValidationResult(email=email, is_valid=False, reason="Invalid email format") + await ctx.yield_output(result) + return + + domain = email.split("@")[1] + domain_check = DomainCheckRequest(domain=domain, email=email) + await ctx.send_message(domain_check) + + +class EmailDomainValidator(Executor): """Validates email addresses in a sub-workflow.""" def __init__(self): - super().__init__(id="email_validator") + super().__init__(id="email_domain_validator") @handler async def validate_request( - self, request: EmailValidationRequest, ctx: WorkflowContext[DomainCheckRequest, ValidationResult] + self, request: DomainCheckRequest, ctx: WorkflowContext[DomainCheckRequest, ValidationResult] ) -> None: """Validate an email address.""" - # Extract domain and check if it's approved - domain = request.email.split("@")[1] if "@" in request.email else "" + domain = request.domain if not domain: result = 
ValidationResult(email=request.email, is_valid=False, reason="Invalid email format") @@ -116,62 +138,37 @@ async def validate_request( return # Request domain check from external source - domain_check = DomainCheckRequest(domain=domain, email=request.email) - await ctx.send_message(domain_check) + await ctx.request_info(request, DomainCheckRequest, bool) - @handler + @response_handler async def handle_domain_response( - self, response: RequestResponse[DomainCheckRequest, bool], ctx: WorkflowContext[Never, ValidationResult] + self, + original_request: DomainCheckRequest, + is_approved: bool, + ctx: WorkflowContext[Never, ValidationResult], ) -> None: """Handle domain check response with correlation.""" # Use the original email from the correlated response result = ValidationResult( - email=response.original_request.email, - is_valid=response.data or False, - reason="Domain approved" if response.data else "Domain not approved", + email=original_request.email, + is_valid=is_approved, + reason="Domain approved" if is_approved else "Domain not approved", ) await ctx.yield_output(result) -class ParentOrchestrator(Executor): - """Parent workflow orchestrator with domain knowledge.""" - - def __init__(self, approved_domains: set[str] | None = None) -> None: - super().__init__(id="parent_orchestrator") - self.approved_domains: set[str] = ( - set(approved_domains) if approved_domains is not None else {"example.com", "test.org"} - ) - self.results: list[ValidationResult] = [] - - @handler - async def start(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None: - """Start processing emails.""" - for email in emails: - request = EmailValidationRequest(email=email) - await ctx.send_message(request, target_id="email_workflow") - - @handler - async def handle_domain_request( - self, - request: DomainCheckRequest, - ctx: WorkflowContext[RequestResponse[DomainCheckRequest, Any] | DomainCheckRequest], - ) -> None: - """Handle requests from sub-workflows.""" - 
domain_request = request - - # Check if we know this domain - if domain_request.domain in self.approved_domains: - # Send response back to sub-workflow - response = RequestResponse(data=True, original_request=request, request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) - else: - # We don't know this domain, forward to external - await ctx.send_message(request) +# Test helper functions +def create_email_validation_workflow() -> Workflow: + """Create a standard email validation workflow.""" + email_format_validator = EmailFormatValidator() + email_domain_validator = EmailDomainValidator() - @handler - async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) -> None: - """Collect validation results.""" - self.results.append(result) + return ( + WorkflowBuilder() + .set_start_executor(email_format_validator) + .add_edge(email_format_validator, email_domain_validator) + .build() + ) async def test_basic_sub_workflow() -> None: @@ -180,17 +177,14 @@ async def test_basic_sub_workflow() -> None: validation_workflow = create_email_validation_workflow() # Create parent workflow without interception - parent = BasicParent() - workflow_executor = WorkflowExecutor(validation_workflow, "email_workflow") - main_request_info = RequestInfoExecutor(id="main_request_info") + parent = Coordinator() + workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow") main_workflow = ( WorkflowBuilder() .set_start_executor(parent) .add_edge(parent, workflow_executor) .add_edge(workflow_executor, parent) - .add_edge(workflow_executor, main_request_info) - .add_edge(main_request_info, workflow_executor) # CRITICAL: For RequestResponse routing .build() ) @@ -220,17 +214,14 @@ async def test_sub_workflow_with_interception(): validation_workflow = create_email_validation_workflow() # Create parent workflow with interception cache - parent = BasicParent(cache={"example.com": True, "internal.org": 
True}) + parent = Coordinator(cache={"example.com": True, "internal.org": True}) workflow_executor = WorkflowExecutor(validation_workflow, "email_workflow") - parent_request_info = RequestInfoExecutor(id="request_info") main_workflow = ( WorkflowBuilder() .set_start_executor(parent) .add_edge(parent, workflow_executor) .add_edge(workflow_executor, parent) - .add_edge(parent, parent_request_info) # For forwarded requests - .add_edge(parent_request_info, workflow_executor) # For RequestResponse routing .build() ) @@ -276,6 +267,7 @@ class MultiWorkflowParent(Executor): def __init__(self) -> None: super().__init__(id="multi_parent") self.results: dict[str, ValidationResult] = {} + self._pending_sub_workflow_requests: dict[str, SubWorkflowRequestMessage] = {} @handler async def start(self, data: dict[str, str], ctx: WorkflowContext[EmailValidationRequest]) -> None: @@ -286,30 +278,47 @@ async def start(self, data: dict[str, str], ctx: WorkflowContext[EmailValidation @handler async def handle_domain_request( self, - request: DomainCheckRequest, - ctx: WorkflowContext[RequestResponse[DomainCheckRequest, Any] | DomainCheckRequest], + sub_workflow_request: SubWorkflowRequestMessage, + ctx: WorkflowContext[SubWorkflowResponseMessage], ) -> None: - domain_request = request + """Handle requests from sub-workflows with optional caching.""" + if not isinstance(sub_workflow_request.source_event.data, DomainCheckRequest): + raise ValueError("Unexpected request type") - if request.source_executor_id == "workflow_a": + domain_request = sub_workflow_request.source_event.data + + if sub_workflow_request.executor_id == "workflow_a" and domain_request.domain == "strict.com": # Strict rules for workflow A - if domain_request.domain == "strict.com": - response = RequestResponse(data=True, original_request=request, request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) - else: - # Forward to external - await ctx.send_message(request) - elif 
request.source_executor_id == "workflow_b": + await ctx.send_message( + sub_workflow_request.create_response(True), target_id=sub_workflow_request.executor_id + ) + return + if sub_workflow_request.executor_id == "workflow_b" and domain_request.domain.endswith(".com"): # Lenient rules for workflow B - if domain_request.domain.endswith(".com"): - response = RequestResponse(data=True, original_request=request, request_id=request.request_id) - await ctx.send_message(response, target_id=request.source_executor_id) - else: - # Forward to external - await ctx.send_message(request) - else: - # Unknown source, forward to external - await ctx.send_message(request) + await ctx.send_message( + sub_workflow_request.create_response(True), target_id=sub_workflow_request.executor_id + ) + return + + # Unknown source, forward to external + self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request + await ctx.request_info(domain_request, DomainCheckRequest, bool) + + @response_handler + async def handle_domain_response( + self, + original_request: DomainCheckRequest, + is_approved: bool, + ctx: WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + """Handle domain check response with correlation and send the response back to the sub-workflow.""" + if original_request.id not in self._pending_sub_workflow_requests: + raise ValueError("No pending sub-workflow request for the given domain check response") + + sub_workflow_request = self._pending_sub_workflow_requests.pop(original_request.id) + await ctx.send_message( + sub_workflow_request.create_response(is_approved), target_id=sub_workflow_request.executor_id + ) @handler async def collect(self, result: ValidationResult, ctx: WorkflowContext) -> None: @@ -322,7 +331,6 @@ async def collect(self, result: ValidationResult, ctx: WorkflowContext) -> None: parent = MultiWorkflowParent() executor_a = WorkflowExecutor(workflow_a, "workflow_a") executor_b = WorkflowExecutor(workflow_b, "workflow_b") - 
parent_request_info = RequestInfoExecutor(id="request_info") main_workflow = ( WorkflowBuilder() @@ -331,9 +339,6 @@ async def collect(self, result: ValidationResult, ctx: WorkflowContext) -> None: .add_edge(parent, executor_b) .add_edge(executor_a, parent) .add_edge(executor_b, parent) - .add_edge(parent, parent_request_info) - .add_edge(parent_request_info, executor_a) # For RequestResponse routing - .add_edge(parent_request_info, executor_b) # For RequestResponse routing .build() ) @@ -359,6 +364,7 @@ class ConcurrentProcessor(Executor): def __init__(self) -> None: super().__init__(id="concurrent_processor") self.results: list[ValidationResult] = [] + self._pending_sub_workflow_requests: dict[str, SubWorkflowRequestMessage] = {} @handler async def start(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None: @@ -366,7 +372,35 @@ async def start(self, emails: list[str], ctx: WorkflowContext[EmailValidationReq # Send all requests concurrently to the same workflow executor for email in emails: request = EmailValidationRequest(email=email) - await ctx.send_message(request, target_id="email_workflow") + await ctx.send_message(request) + + @handler + async def handle_domain_request( + self, + sub_workflow_request: SubWorkflowRequestMessage, + ctx: WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + """Handle requests from sub-workflows with optional caching.""" + if not isinstance(sub_workflow_request.source_event.data, DomainCheckRequest): + raise ValueError("Unexpected request type") + + domain_request = sub_workflow_request.source_event.data + self._pending_sub_workflow_requests[domain_request.id] = sub_workflow_request + await ctx.request_info(domain_request, DomainCheckRequest, bool) + + @response_handler + async def handle_domain_response( + self, + original_request: DomainCheckRequest, + is_approved: bool, + ctx: WorkflowContext[SubWorkflowResponseMessage], + ) -> None: + """Handle domain check response with correlation and send 
the response back to the sub-workflow.""" + if original_request.id not in self._pending_sub_workflow_requests: + raise ValueError("No pending sub-workflow request for the given domain check response") + + sub_workflow_request = self._pending_sub_workflow_requests.pop(original_request.id) + await ctx.send_message(sub_workflow_request.create_response(is_approved)) @handler async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) -> None: @@ -379,15 +413,12 @@ async def collect_result(self, result: ValidationResult, ctx: WorkflowContext) - # Create parent workflow processor = ConcurrentProcessor() workflow_executor = WorkflowExecutor(validation_workflow, "email_workflow") - parent_request_info = RequestInfoExecutor(id="request_info") main_workflow = ( WorkflowBuilder() .set_start_executor(processor) .add_edge(processor, workflow_executor) .add_edge(workflow_executor, processor) - .add_edge(workflow_executor, parent_request_info) # For external requests - .add_edge(parent_request_info, workflow_executor) # For RequestResponse routing .build() ) diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index 3f88726601..fc388fddc5 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar, Union -from agent_framework._workflows import RequestInfoMessage, RequestResponse from agent_framework._workflows._typing_utils import is_instance_of @@ -91,22 +90,6 @@ class CustomClass: assert not is_instance_of(instance, dict) -def test_request_response_type() -> None: - """Test RequestResponse generic type checking.""" - - request_instance = RequestResponse[RequestInfoMessage, str]( - data="approve", - request_id="req-1", - original_request=RequestInfoMessage(), - ) - - class CustomRequestInfoMessage(RequestInfoMessage): - 
info: str - - assert is_instance_of(request_instance, RequestResponse[RequestInfoMessage, str]) - assert not is_instance_of(request_instance, RequestResponse[CustomRequestInfoMessage, str]) - - def test_custom_generic_type() -> None: """Test custom generic type checking.""" diff --git a/python/packages/core/tests/workflow/test_validation.py b/python/packages/core/tests/workflow/test_validation.py index 37e7f18ed5..d7fc11aa66 100644 --- a/python/packages/core/tests/workflow/test_validation.py +++ b/python/packages/core/tests/workflow/test_validation.py @@ -183,7 +183,7 @@ def test_graph_connectivity_isolated_executors(): assert "executor3" in str(exc_info.value) -def test_start_executor_not_in_graph(): +def test_disconnected_start_executor_not_in_graph(): executor1 = StringExecutor(id="executor1") executor2 = StringExecutor(id="executor2") executor3 = StringExecutor(id="executor3") # Not in graph @@ -191,7 +191,7 @@ def test_start_executor_not_in_graph(): with pytest.raises(GraphConnectivityError) as exc_info: WorkflowBuilder().add_edge(executor1, executor2).set_start_executor(executor3).build() - assert "not present in the workflow graph" in str(exc_info.value) + assert "The following executors are unreachable from the start executor 'executor3'" in str(exc_info.value) def test_missing_start_executor(): diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index f66c7048e8..a5fc51618a 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -3,8 +3,9 @@ import asyncio import tempfile from collections.abc import AsyncIterable -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any +from uuid import uuid4 import pytest @@ -21,9 +22,6 @@ FileCheckpointStorage, Message, RequestInfoEvent, - RequestInfoExecutor, - RequestInfoMessage, - RequestResponse, Role, TextContent, WorkflowBuilder, @@ 
-33,6 +31,7 @@ WorkflowRunState, WorkflowStatusEvent, handler, + response_handler, ) @@ -68,6 +67,14 @@ async def mock_handler(self, messages: list[NumberMessage], ctx: WorkflowContext await ctx.yield_output(sum(msg.data for msg in messages)) +@dataclass +class MockRequest: + """A mock request message for testing purposes.""" + + request_id: str = field(default_factory=lambda: str(uuid4())) + prompt: str = "" + + @dataclass class ApprovalMessage: """A mock message for approval requests.""" @@ -79,22 +86,22 @@ class MockExecutorRequestApproval(Executor): """A mock executor that simulates a request for approval.""" @handler - async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext[RequestInfoMessage]) -> None: + async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext) -> None: """A mock handler that requests approval.""" await ctx.set_shared_state(self.id, message.data) - await ctx.send_message(RequestInfoMessage()) + await ctx.request_info(MockRequest(prompt="Mock approval request"), MockRequest, ApprovalMessage) - @handler + @response_handler async def mock_handler_b( self, - message: RequestResponse[RequestInfoMessage, ApprovalMessage], + original_request: MockRequest, + response: ApprovalMessage, ctx: WorkflowContext[NumberMessage, int], ) -> None: """A mock handler that processes the approval response.""" data = await ctx.get_shared_state(self.id) assert isinstance(data, int) - assert isinstance(message.data, ApprovalMessage) - if message.data.approved: + if response.approved: await ctx.yield_output(data) else: await ctx.send_message(NumberMessage(data=data)) @@ -182,15 +189,12 @@ async def test_workflow_send_responses_streaming(): """Test the workflow run with approval.""" executor_a = IncrementExecutor(id="executor_a") executor_b = MockExecutorRequestApproval(id="executor_b") - request_info_executor = RequestInfoExecutor(id="request_info") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) 
.add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) - .add_edge(executor_b, request_info_executor) - .add_edge(request_info_executor, executor_b) .build() ) @@ -219,15 +223,12 @@ async def test_workflow_send_responses(): """Test the workflow run with approval.""" executor_a = IncrementExecutor(id="executor_a") executor_b = MockExecutorRequestApproval(id="executor_b") - request_info_executor = RequestInfoExecutor(id="request_info") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) - .add_edge(executor_b, request_info_executor) - .add_edge(request_info_executor, executor_b) .build() ) @@ -480,6 +481,15 @@ async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executo workflow_id="test-workflow", messages={}, shared_state={}, + pending_request_info_events={ + "request_123": RequestInfoEvent( + request_id="request_123", + source_executor_id=simple_executor.id, + request_type=str, + request_data="Mock", + response_type=str, + ).to_dict(), + }, iteration_count=0, ) checkpoint_id = await storage.save_checkpoint(test_checkpoint) @@ -494,17 +504,20 @@ async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executo ) # Test that run_stream_from_checkpoint accepts responses parameter - responses = {"request_123": {"data": "test_response"}} + responses = {"request_123": "test_response"} - try: - events: list[WorkflowEvent] = [] - async for event in workflow.run_stream_from_checkpoint(checkpoint_id, responses=responses): - events.append(event) - if len(events) >= 2: # Limit to avoid infinite loops - break - except Exception: - # Expected since we have minimal setup, but method should accept the parameters - pass + events: list[WorkflowEvent] = [] + async for event in workflow.run_stream_from_checkpoint(checkpoint_id): + events.append(event) + + assert next( + event for event in events if isinstance(event, RequestInfoEvent) and event.request_id == 
"request_123" + ) + + async for event in workflow.send_responses_streaming(responses): + events.append(event) + + assert len(events) > 0 # Just ensure we processed some events @dataclass @@ -735,7 +748,7 @@ async def test_workflow_concurrent_execution_prevention_streaming(): # Create an async generator that will consume the stream slowly async def consume_stream_slowly(): - result = [] + result: list[WorkflowEvent] = [] async for event in workflow.run_stream(NumberMessage(data=0)): result.append(event) await asyncio.sleep(0.01) # Slow consumption @@ -768,7 +781,7 @@ async def test_workflow_concurrent_execution_prevention_mixed_methods(): # Start a streaming execution async def consume_stream(): - result = [] + result: list[WorkflowEvent] = [] async for event in workflow.run_stream(NumberMessage(data=0)): result.append(event) await asyncio.sleep(0.01) @@ -844,6 +857,7 @@ async def test_agent_streaming_vs_non_streaming() -> None: assert len(agent_run_events) == 1, "Expected exactly one AgentRunEvent in non-streaming mode" assert len(agent_update_events) == 0, "Expected no AgentRunUpdateEvent in non-streaming mode" assert agent_run_events[0].executor_id == "agent_exec" + assert agent_run_events[0].data is not None assert agent_run_events[0].data.messages[0].text == "Hello World" # Test streaming mode with run_stream() @@ -864,6 +878,8 @@ async def test_agent_streaming_vs_non_streaming() -> None: # Verify the updates build up to the full message accumulated_text = "".join( - e.data.contents[0].text for e in stream_agent_update_events if e.data.contents and e.data.contents[0].text + e.data.contents[0].text + for e in stream_agent_update_events + if e.data and e.data.contents and e.data.contents[0].text ) assert accumulated_text == "Hello World", f"Expected 'Hello World', got '{accumulated_text}'" diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index 567944a104..44a17d02bf 100644 --- 
a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -14,8 +14,6 @@ FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionCallContent, - RequestInfoExecutor, - RequestInfoMessage, Role, TextContent, UsageContent, @@ -24,6 +22,7 @@ WorkflowBuilder, WorkflowContext, handler, + response_handler, ) @@ -56,15 +55,17 @@ async def handle_message(self, message: list[ChatMessage], ctx: WorkflowContext[ class RequestingExecutor(Executor): - """Executor that sends RequestInfoMessage to trigger RequestInfoEvent.""" + """Executor that requests info.""" @handler - async def handle_message(self, _: list[ChatMessage], ctx: WorkflowContext[RequestInfoMessage]) -> None: + async def handle_message(self, _: list[ChatMessage], ctx: WorkflowContext) -> None: # Send a RequestInfoMessage to trigger the request info process - await ctx.send_message(RequestInfoMessage()) + await ctx.request_info("Mock request data", str, str) - @handler - async def handle_request_response(self, _: Any, ctx: WorkflowContext[ChatMessage]) -> None: + @response_handler + async def handle_request_response( + self, original_request: str, response: str, ctx: WorkflowContext[ChatMessage] + ) -> None: # Handle the response and emit completion response update = AgentRunResponseUpdate( contents=[TextContent(text="Request completed successfully")], @@ -148,14 +149,11 @@ async def test_end_to_end_basic_workflow_streaming(self): async def test_end_to_end_request_info_handling(self): """Test end-to-end workflow with RequestInfoEvent handling.""" # Create workflow with requesting executor -> request info executor (no cycle) + simple_executor = SimpleExecutor(id="simple", response_text="SimpleResponse", emit_streaming=False) requesting_executor = RequestingExecutor(id="requester") - request_info_executor = RequestInfoExecutor(id="request_info") workflow = ( - WorkflowBuilder() - .set_start_executor(requesting_executor) - 
.add_edge(requesting_executor, request_info_executor) - .build() + WorkflowBuilder().set_start_executor(simple_executor).add_edge(simple_executor, requesting_executor).build() ) agent = WorkflowAgent(workflow=workflow, name="Request Test Agent") diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 6876ce6614..b88271ed26 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -100,7 +100,7 @@ def test_workflow_builder_fluent_api(): .build() ) - assert len(workflow.edge_groups) == 4 + assert len(workflow.edge_groups) == 4 + 6 # 4 defined edges + 6 internal edges for request-response handling assert workflow.start_executor_id == executor_a.id assert len(workflow.executors) == 6 diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index 1c4ba561d5..c21da08d52 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -1,7 +1,5 @@ # Copyright (c) Microsoft. All rights reserved. 
-from dataclasses import dataclass - import pytest from typing_extensions import Never @@ -10,8 +8,6 @@ ExecutorFailedEvent, InProcRunnerContext, RequestInfoEvent, - RequestInfoExecutor, - RequestInfoMessage, SharedState, Workflow, WorkflowBuilder, @@ -69,18 +65,26 @@ async def test_executor_failed_event_emitted_on_direct_execute(): assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed) +class SimpleExecutor(Executor): + """Executor that does nothing, for testing.""" + + @handler + async def run(self, msg: str, ctx: WorkflowContext[str]) -> None: # pragma: no cover + await ctx.send_message(msg) + + class Requester(Executor): """Executor that always requests external info to test idle-with-requests state.""" @handler - async def ask(self, _: str, ctx: WorkflowContext[RequestInfoMessage]) -> None: # pragma: no cover - await ctx.send_message(RequestInfoMessage()) + async def ask(self, _: str, ctx: WorkflowContext) -> None: # pragma: no cover + await ctx.request_info("Mock request data", str, str) async def test_idle_with_pending_requests_status_streaming(): - req = Requester(id="req") - rie = RequestInfoExecutor(id="rie") - wf = WorkflowBuilder().set_start_executor(req).add_edge(req, rie).build() + simple_executor = SimpleExecutor(id="simple") + requester = Requester(id="req") + wf = WorkflowBuilder().set_start_executor(simple_executor).add_edge(simple_executor, requester).build() events = [ev async for ev in wf.run_stream("start")] # Consume stream fully @@ -134,9 +138,9 @@ async def test_non_streaming_final_state_helpers(): assert result1.get_final_state() == WorkflowRunState.IDLE # Idle-with-pending-request case - req = Requester(id="req") - rie = RequestInfoExecutor(id="rie") - wf2 = WorkflowBuilder().set_start_executor(req).add_edge(req, rie).build() + simple_executor = SimpleExecutor(id="simple") + requester = Requester(id="req") + wf2 = WorkflowBuilder().set_start_executor(simple_executor).add_edge(simple_executor, requester).build() result2: 
WorkflowRunResult = await wf2.run("start") assert result2.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS @@ -151,32 +155,12 @@ async def test_run_includes_status_events_completed(): async def test_run_includes_status_events_idle_with_requests(): - req = Requester(id="req2") - rie = RequestInfoExecutor(id="rie2") - wf = WorkflowBuilder().set_start_executor(req).add_edge(req, rie).build() + simple_executor = SimpleExecutor(id="simple") + requester = Requester(id="req2") + wf = WorkflowBuilder().set_start_executor(simple_executor).add_edge(simple_executor, requester).build() result: WorkflowRunResult = await wf.run("start") timeline = result.status_timeline() assert timeline, "Expected status timeline in non-streaming run() results" assert len(timeline) >= 3 assert timeline[-2].state == WorkflowRunState.IN_PROGRESS_PENDING_REQUESTS assert timeline[-1].state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS - - -@dataclass -class SnapshotRequest(RequestInfoMessage): - prompt: str = "" - draft: str = "" - iteration: int = 0 - - -class SnapshotRequester(Executor): - """Executor that emits a rich RequestInfoMessage for persistence tests.""" - - def __init__(self, id: str, prompt: str, draft: str) -> None: - super().__init__(id=id) - self._prompt = prompt - self._draft = draft - - @handler - async def ask(self, _: str, ctx: WorkflowContext[SnapshotRequest]) -> None: # pragma: no cover - simple helper - await ctx.send_message(SnapshotRequest(prompt=self._prompt, draft=self._draft, iteration=1)) From fdfe83df53e4edc9f4449e76cb20cd0680dc15a1 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Fri, 24 Oct 2025 17:50:29 -0700 Subject: [PATCH 13/26] Fix formatting --- .../_workflows/_edge_runner.py | 2 +- .../agent_framework/_workflows/_magentic.py | 6 +- .../_workflows/_typing_utils.py | 4 +- .../_workflows/_workflow_builder.py | 2 +- .../_workflows/_workflow_executor.py | 2 +- .../core/tests/workflow/test_typing_utils.py | 78 ++++++++++++++++++- 6 files changed, 85 
insertions(+), 9 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_edge_runner.py b/python/packages/core/agent_framework/_workflows/_edge_runner.py index d97236530a..e621b0fd30 100644 --- a/python/packages/core/agent_framework/_workflows/_edge_runner.py +++ b/python/packages/core/agent_framework/_workflows/_edge_runner.py @@ -86,7 +86,7 @@ async def _execute_on_target( class SingleEdgeRunner(EdgeRunner): """Runner for single edge groups.""" - def __init__(self, edge_group: SingleEdgeGroup, executors: dict[str, Executor]) -> None: + def __init__(self, edge_group: SingleEdgeGroup | InternalEdgeGroup, executors: dict[str, Executor]) -> None: super().__init__(edge_group, executors) self._edge = edge_group.edges[0] diff --git a/python/packages/core/agent_framework/_workflows/_magentic.py b/python/packages/core/agent_framework/_workflows/_magentic.py index 9130fdbc3b..98cad170f2 100644 --- a/python/packages/core/agent_framework/_workflows/_magentic.py +++ b/python/packages/core/agent_framework/_workflows/_magentic.py @@ -1115,7 +1115,7 @@ async def handle_start_message( # If a human must sign off, ask now and return. The response handler will resume. 
if self._require_plan_signoff: - await self._send_plan_review_request(context) + await self._send_plan_review_request(cast(WorkflowContext, context)) return # Add task ledger to conversation history @@ -1267,7 +1267,7 @@ async def handle_plan_review_response( plan=response.edited_plan_text, ) self._task_ledger = ChatMessage(role=Role.ASSISTANT, text=combined, author_name=MAGENTIC_MANAGER_NAME) - await self._send_plan_review_request(context) + await self._send_plan_review_request(cast(WorkflowContext, context)) return # Else pass comments into the chat history and replan with the manager @@ -1278,7 +1278,7 @@ async def handle_plan_review_response( # Ask the manager to replan; this only adjusts the plan stage, not a full reset self._task_ledger = await self._manager.replan(self._context.clone(deep=True)) - await self._send_plan_review_request(context) + await self._send_plan_review_request(cast(WorkflowContext, context)) async def _run_outer_loop( self, diff --git a/python/packages/core/agent_framework/_workflows/_typing_utils.py b/python/packages/core/agent_framework/_workflows/_typing_utils.py index f2e355d5b3..1450f682c5 100644 --- a/python/packages/core/agent_framework/_workflows/_typing_utils.py +++ b/python/packages/core/agent_framework/_workflows/_typing_utils.py @@ -3,7 +3,7 @@ import logging from dataclasses import fields, is_dataclass from types import UnionType -from typing import Any, Union, get_args, get_origin +from typing import Any, Union, cast, get_args, get_origin logger = logging.getLogger(__name__) @@ -148,4 +148,4 @@ def deserialize_type(serialized_type_string: str) -> type: module_name, _, type_name = serialized_type_string.rpartition(".") module = importlib.import_module(module_name) - return getattr(module, type_name) + return cast(type, getattr(module, type_name)) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_builder.py b/python/packages/core/agent_framework/_workflows/_workflow_builder.py index 8bc1888f86..a1b90408be 
100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_builder.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_builder.py @@ -83,7 +83,7 @@ def _add_executor(self, executor: Executor) -> str: # New executor self._executors[executor.id] = executor # Add an internal edge group for each unique executor - self._edge_groups.append(InternalEdgeGroup(executor.id)) + self._edge_groups.append(InternalEdgeGroup(executor.id)) # type: ignore[call-arg] return executor.id diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 4b12d819f2..8eb6bd0982 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -471,7 +471,7 @@ async def _ensure_state_loaded(self, ctx: WorkflowContext[Any]) -> None: if isinstance(state, dict) and state: with contextlib.suppress(Exception): - self.restore_state(state) + await self.restore_state(state) self._state_loaded = True else: self._state_loaded = True diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index fc388fddc5..77af045180 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -3,7 +3,8 @@ from dataclasses import dataclass from typing import Any, Generic, TypeVar, Union -from agent_framework._workflows._typing_utils import is_instance_of +from agent_framework import RequestInfoEvent +from agent_framework._workflows._typing_utils import deserialize_type, is_instance_of, serialize_type def test_basic_types() -> None: @@ -116,3 +117,78 @@ def test_edge_cases() -> None: assert is_instance_of({}, dict[str, int]) # Empty dict should be valid assert is_instance_of(None, int | None) # Optional type with None assert not is_instance_of(5, str | None) # 
Optional type without matching type + + +def test_serialize_type() -> None: + """Test serialization of types to strings.""" + # Test built-in types + assert serialize_type(int) == "builtins.int" + assert serialize_type(str) == "builtins.str" + assert serialize_type(float) == "builtins.float" + assert serialize_type(bool) == "builtins.bool" + assert serialize_type(list) == "builtins.list" + assert serialize_type(dict) == "builtins.dict" + assert serialize_type(tuple) == "builtins.tuple" + assert serialize_type(set) == "builtins.set" + + # Test custom class + @dataclass + class TestClass: + value: int + + # The custom class will be in the test module + expected = f"{TestClass.__module__}.{TestClass.__qualname__}" + assert serialize_type(TestClass) == expected + + +def test_deserialize_type() -> None: + """Test deserialization of type strings back to types.""" + # Test built-in types + assert deserialize_type("builtins.int") is int + assert deserialize_type("builtins.str") is str + assert deserialize_type("builtins.float") is float + assert deserialize_type("builtins.bool") is bool + assert deserialize_type("builtins.list") is list + assert deserialize_type("builtins.dict") is dict + assert deserialize_type("builtins.tuple") is tuple + assert deserialize_type("builtins.set") is set + + +def test_serialize_deserialize_roundtrip() -> None: + """Test that serialization and deserialization are inverse operations.""" + # Test built-in types + types_to_test = [int, str, float, bool, list, dict, tuple, set] + + for type_to_test in types_to_test: + serialized = serialize_type(type_to_test) + deserialized = deserialize_type(serialized) + assert deserialized is type_to_test + + # Test agent framework type roundtrip + + serialized = serialize_type(RequestInfoEvent) + deserialized = deserialize_type(serialized) + assert deserialized is RequestInfoEvent + + # Verify we can instantiate the deserialized type + instance = deserialized( + request_id="request-123", + 
source_executor_id="executor_1", + request_type=str, + request_data="test", + response_type=str, + ) + assert isinstance(instance, RequestInfoEvent) + + +def test_deserialize_type_error_handling() -> None: + """Test error handling in deserialize_type function.""" + import pytest + + # Test with non-existent module + with pytest.raises(ModuleNotFoundError): + deserialize_type("nonexistent.module.Type") + + # Test with non-existent type in existing module + with pytest.raises(AttributeError): + deserialize_type("builtins.NonExistentType") From ad19795e1adbdf28a4b0486f21ffbd02a34ce09b Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 10:24:37 -0700 Subject: [PATCH 14/26] Resolve comments --- .../guessing_game_with_human_input.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py index 24de2c8ffb..d492ff8d60 100644 --- a/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py +++ b/python/samples/getting_started/workflows/human-in-the-loop/guessing_game_with_human_input.py @@ -45,7 +45,7 @@ - Basic familiarity with WorkflowBuilder, executors, edges, events, and streaming runs. """ -# How human-in-the-loops is achieved via `request_info` and `send_responses_streaming`: +# How human-in-the-loop is achieved via `request_info` and `send_responses_streaming`: # - An executor (TurnManager) calls `ctx.request_info` with a payload (HumanFeedbackRequest). # - The workflow run pauses and emits a RequestInfoEvent with the payload and the request_id. # - The application captures the event, prompts the user, and collects replies. @@ -113,7 +113,11 @@ async def on_agent_response( "lower (your number is lower than this guess), correct, or exit." ) # Send a request with a prompt as the payload and expect a string reply. 
- await ctx.request_info(HumanFeedbackRequest(prompt=prompt), HumanFeedbackRequest, str) + await ctx.request_info( + request_data=HumanFeedbackRequest(prompt=prompt), + request_type=HumanFeedbackRequest, + response_type=str, + ) @response_handler async def on_human_feedback( @@ -180,7 +184,7 @@ async def main() -> None: # flush=True, # ) - while not workflow_output: + while workflow_output is None: # First iteration uses run_stream("start"). # Subsequent iterations use send_responses_streaming with pending_responses from the console. stream = ( From 26c2accccbaf31cae14c79f9b674a59909943f0f Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 10:53:43 -0700 Subject: [PATCH 15/26] Address comment --- .../core/agent_framework/_workflows/_workflow_executor.py | 2 +- .../composition/sub_workflow_request_interception.py | 4 ++-- .../workflows/orchestration/magentic_checkpoint.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_workflow_executor.py b/python/packages/core/agent_framework/_workflows/_workflow_executor.py index 8eb6bd0982..e7614b5cdb 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_executor.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_executor.py @@ -517,7 +517,7 @@ async def restore_state(self, state: dict[str, Any]) -> None: # The proper way would be to rehydrate the workflow from a checkpoint on a Workflow # API instead of the '_runner_context' object that should be hidden. And the sub workflow # should be rehydrated from a checkpoint object instead of from a subset of the state. - # TODO(@taochen#1614): how to handle the case when the parent workflow has checkpointing + # TODO(@taochen): Issue #1614 - how to handle the case when the parent workflow has checkpointing # set up but not the sub workflow? 
request_info_events = [ request_info_event diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py index d3fcf75f89..95633b3df0 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py @@ -54,9 +54,9 @@ class SanitizedEmailResult: def build_email_address_validation_workflow() -> Workflow: - """Build a email address validation workflow. + """Build an email address validation workflow. - This workflow consists of three steps (exach is represented by an executor): + This workflow consists of three steps (each is represented by an executor): 1. Sanitize the email address, such as removing leading/trailing spaces. 2. Validate the email address format, such as checking for "@" and domain. 3. Extract the domain from the email address and request domain validation, diff --git a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py index e4a57f9dbd..6ae2dd18b5 100644 --- a/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py +++ b/python/samples/getting_started/workflows/orchestration/magentic_checkpoint.py @@ -142,7 +142,7 @@ async def main() -> None: # Resume execution and capture the re-emitted plan review request. 
request_info_event: RequestInfoEvent | None = None async for event in resumed_workflow.run_stream_from_checkpoint(resume_checkpoint.checkpoint_id): - if isinstance(event, RequestInfoEvent) and isinstance(event, MagenticPlanReviewRequest): + if isinstance(event, RequestInfoEvent) and isinstance(event.data, MagenticPlanReviewRequest): request_info_event = event if request_info_event is None: From 6dde1672c6474f5c65d8c5448de7224bff203b74 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 11:00:18 -0700 Subject: [PATCH 16/26] Add checkpoint tests --- .../core/tests/workflow/test_checkpoint.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/python/packages/core/tests/workflow/test_checkpoint.py b/python/packages/core/tests/workflow/test_checkpoint.py index 236696e79c..74ac524883 100644 --- a/python/packages/core/tests/workflow/test_checkpoint.py +++ b/python/packages/core/tests/workflow/test_checkpoint.py @@ -20,6 +20,7 @@ def test_workflow_checkpoint_default_values(): assert checkpoint.timestamp != "" assert checkpoint.messages == {} assert checkpoint.shared_state == {} + assert checkpoint.pending_request_info_events == {} assert checkpoint.iteration_count == 0 assert checkpoint.metadata == {} assert checkpoint.version == "1.0" @@ -32,6 +33,7 @@ def test_workflow_checkpoint_custom_values(): workflow_id="test-workflow-456", timestamp=custom_timestamp, messages={"executor1": [{"data": "test"}]}, + pending_request_info_events={"req123": {"data": "test"}}, shared_state={"key": "value"}, iteration_count=5, metadata={"test": True}, @@ -43,6 +45,7 @@ def test_workflow_checkpoint_custom_values(): assert checkpoint.timestamp == custom_timestamp assert checkpoint.messages == {"executor1": [{"data": "test"}]} assert checkpoint.shared_state == {"key": "value"} + assert checkpoint.pending_request_info_events == {"req123": {"data": "test"}} assert checkpoint.iteration_count == 5 assert checkpoint.metadata == {"test": True} assert 
checkpoint.version == "2.0" @@ -50,7 +53,11 @@ def test_workflow_checkpoint_custom_values(): async def test_memory_checkpoint_storage_save_and_load(): storage = InMemoryCheckpointStorage() - checkpoint = WorkflowCheckpoint(workflow_id="test-workflow", messages={"executor1": [{"data": "hello"}]}) + checkpoint = WorkflowCheckpoint( + workflow_id="test-workflow", + messages={"executor1": [{"data": "hello"}]}, + pending_request_info_events={"req123": {"data": "test"}}, + ) # Save checkpoint saved_id = await storage.save_checkpoint(checkpoint) @@ -62,6 +69,7 @@ async def test_memory_checkpoint_storage_save_and_load(): assert loaded_checkpoint.checkpoint_id == checkpoint.checkpoint_id assert loaded_checkpoint.workflow_id == checkpoint.workflow_id assert loaded_checkpoint.messages == checkpoint.messages + assert loaded_checkpoint.pending_request_info_events == checkpoint.pending_request_info_events async def test_memory_checkpoint_storage_load_nonexistent(): @@ -152,6 +160,7 @@ async def test_file_checkpoint_storage_save_and_load(): workflow_id="test-workflow", messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, shared_state={"key": "value"}, + pending_request_info_events={"req123": {"data": "test"}}, ) # Save checkpoint @@ -169,6 +178,7 @@ async def test_file_checkpoint_storage_save_and_load(): assert loaded_checkpoint.workflow_id == checkpoint.workflow_id assert loaded_checkpoint.messages == checkpoint.messages assert loaded_checkpoint.shared_state == checkpoint.shared_state + assert loaded_checkpoint.pending_request_info_events == checkpoint.pending_request_info_events async def test_file_checkpoint_storage_load_nonexistent(): @@ -284,6 +294,7 @@ async def test_file_checkpoint_storage_json_serialization(): workflow_id="complex-workflow", messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, shared_state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None}, 
+ pending_request_info_events={"req123": {"data": "test"}}, ) # Save and load @@ -303,6 +314,7 @@ async def test_file_checkpoint_storage_json_serialization(): assert data["shared_state"]["list"] == [1, 2, 3] assert data["shared_state"]["bool"] is True assert data["shared_state"]["null"] is None + assert data["pending_request_info_events"]["req123"]["data"] == "test" def test_checkpoint_storage_protocol_compliance(): From 82ec03bc0e538c780c70b862add011770de26c36 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 13:01:06 -0700 Subject: [PATCH 17/26] Add tests --- .../_workflows/_request_info_mixin.py | 1 - .../test_request_info_and_response.py | 413 ++++++++++++++++++ .../tests/workflow/test_request_info_mixin.py | 307 +++++++++++++ 3 files changed, 720 insertions(+), 1 deletion(-) create mode 100644 python/packages/core/tests/workflow/test_request_info_and_response.py create mode 100644 python/packages/core/tests/workflow/test_request_info_mixin.py diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index cbd86de0de..b68c7174ea 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -20,7 +20,6 @@ class RequestInfoMixin: """Mixin providing common functionality for request info handling.""" - _PENDING_SHARED_STATE_KEY: ClassVar[str] = "_af_pending_request_info" def _discover_response_handlers(self) -> None: """Discover and register response handlers defined in the class.""" diff --git a/python/packages/core/tests/workflow/test_request_info_and_response.py b/python/packages/core/tests/workflow/test_request_info_and_response.py new file mode 100644 index 0000000000..198a058e85 --- /dev/null +++ b/python/packages/core/tests/workflow/test_request_info_and_response.py @@ -0,0 +1,413 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from dataclasses import dataclass + +from agent_framework import ( + FileCheckpointStorage, + RequestInfoEvent, + WorkflowBuilder, + WorkflowContext, + WorkflowRunState, + WorkflowStatusEvent, + handler, + response_handler, +) +from agent_framework._workflows._executor import Executor +from agent_framework._workflows._request_info_mixin import RequestInfoMixin + + +@dataclass +class UserApprovalRequest: + """A request for user approval with context.""" + + prompt: str + context: str + request_id: str = "" + + def __post_init__(self): + if not self.request_id: + import uuid + + self.request_id = str(uuid.uuid4()) + + +@dataclass +class CalculationRequest: + """A request for a complex calculation.""" + + operation: str + operands: list[float] + request_id: str = "" + + def __post_init__(self): + if not self.request_id: + import uuid + + self.request_id = str(uuid.uuid4()) + + +class ApprovalRequiredExecutor(Executor, RequestInfoMixin): + """Executor that requires approval before proceeding.""" + + def __init__(self, id: str): + super().__init__(id=id) + self.approval_received = False + self.final_result = None + + @handler + async def start_process(self, message: str, ctx: WorkflowContext) -> None: + """Start a process that requires approval.""" + # Request approval from external system + approval_request = UserApprovalRequest( + prompt=f"Please approve the operation: {message}", + context="This is a critical operation that requires human approval.", + ) + await ctx.request_info(approval_request, UserApprovalRequest, bool) + + @response_handler + async def handle_approval_response( + self, original_request: UserApprovalRequest, approved: bool, ctx: WorkflowContext[str] + ) -> None: + """Handle the approval response.""" + self.approval_received = True + + if approved: + self.final_result = f"Operation approved: {original_request.prompt}" + await ctx.send_message(f"APPROVED: {original_request.context}") + else: + self.final_result = "Operation denied by user" + 
await ctx.send_message("DENIED: Operation was not approved") + + +class CalculationExecutor(Executor, RequestInfoMixin): + """Executor that delegates complex calculations to external services.""" + + def __init__(self, id: str): + super().__init__(id=id) + self.calculations_performed: list[tuple[str, list[float], float]] = [] + + @handler + async def process_calculation(self, message: str, ctx: WorkflowContext[str]) -> None: + """Process a calculation request.""" + # Parse the message to extract operation + parts = message.split() + if len(parts) >= 3: + operation = parts[0] + try: + operands = [float(x) for x in parts[1:]] + calc_request = CalculationRequest(operation=operation, operands=operands) + await ctx.request_info(calc_request, CalculationRequest, float) + except ValueError: + await ctx.send_message("Invalid calculation format") + else: + await ctx.send_message("Insufficient parameters for calculation") + + @response_handler + async def handle_calculation_response( + self, original_request: CalculationRequest, result: float, ctx: WorkflowContext[str] + ) -> None: + """Handle the calculation response.""" + self.calculations_performed.append((original_request.operation, original_request.operands, result)) + operands_str = ", ".join(map(str, original_request.operands)) + await ctx.send_message(f"Calculation complete: {original_request.operation}({operands_str}) = {result}") + + +class MultiRequestExecutor(Executor, RequestInfoMixin): + """Executor that makes multiple requests and waits for all responses.""" + + def __init__(self, id: str): + super().__init__(id=id) + self.responses_received: list[tuple[str, bool | float]] = [] + + @handler + async def start_multi_request(self, message: str, ctx: WorkflowContext) -> None: + """Start multiple requests simultaneously.""" + # Request approval + approval_request = UserApprovalRequest( + prompt="Approve batch operation", context="Multiple operations will be performed" + ) + await ctx.request_info(approval_request, 
UserApprovalRequest, bool) + + # Request calculation + calc_request = CalculationRequest(operation="multiply", operands=[10.0, 5.0]) + await ctx.request_info(calc_request, CalculationRequest, float) + + @response_handler + async def handle_approval_response( + self, original_request: UserApprovalRequest, approved: bool, ctx: WorkflowContext[str] + ) -> None: + """Handle approval response.""" + self.responses_received.append(("approval", approved)) + await self._check_completion(ctx) + + @response_handler + async def handle_calculation_response( + self, original_request: CalculationRequest, result: float, ctx: WorkflowContext[str] + ) -> None: + """Handle calculation response.""" + self.responses_received.append(("calculation", result)) + await self._check_completion(ctx) + + async def _check_completion(self, ctx: WorkflowContext[str]) -> None: + """Check if all responses are received and send final result.""" + if len(self.responses_received) == 2: + approval_result = next((r[1] for r in self.responses_received if r[0] == "approval"), None) + calc_result = next((r[1] for r in self.responses_received if r[0] == "calculation"), None) + + if approval_result and calc_result is not None: + await ctx.send_message(f"All operations complete. 
Calculation result: {calc_result}") + else: + await ctx.send_message("Operations completed with mixed results") + + +class OutputCollector(Executor): + """Simple executor that collects outputs for testing.""" + + def __init__(self, id: str): + super().__init__(id=id) + self.collected_outputs: list[str] = [] + + @handler + async def collect_output(self, message: str, ctx: WorkflowContext) -> None: + """Collect output messages.""" + self.collected_outputs.append(message) + + +class TestRequestInfoAndResponse: + """Test cases for end-to-end request info and response handling at the workflow level.""" + + async def test_approval_workflow(self): + """Test end-to-end workflow with approval request.""" + executor = ApprovalRequiredExecutor(id="approval_executor") + workflow = WorkflowBuilder().set_start_executor(executor).build() + + # First run the workflow until it emits a request + request_info_event: RequestInfoEvent | None = None + async for event in workflow.run_stream("test operation"): + if isinstance(event, RequestInfoEvent): + request_info_event = event + + assert request_info_event is not None + assert isinstance(request_info_event.data, UserApprovalRequest) + assert request_info_event.data.prompt == "Please approve the operation: test operation" + + # Send response and continue workflow + completed = False + async for event in workflow.send_responses_streaming({request_info_event.request_id: True}): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + completed = True + + assert completed + assert executor.approval_received is True + assert executor.final_result == "Operation approved: Please approve the operation: test operation" + + async def test_calculation_workflow(self): + """Test end-to-end workflow with calculation request.""" + executor = CalculationExecutor(id="calc_executor") + workflow = WorkflowBuilder().set_start_executor(executor).build() + + # First run the workflow until it emits a calculation request + 
request_info_event: RequestInfoEvent | None = None + async for event in workflow.run_stream("multiply 15.5 2.0"): + if isinstance(event, RequestInfoEvent): + request_info_event = event + + assert request_info_event is not None + assert isinstance(request_info_event.data, CalculationRequest) + assert request_info_event.data.operation == "multiply" + assert request_info_event.data.operands == [15.5, 2.0] + + # Send response with calculated result + calculated_result = 31.0 + completed = False + async for event in workflow.send_responses_streaming({request_info_event.request_id: calculated_result}): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + completed = True + + assert completed + assert len(executor.calculations_performed) == 1 + assert executor.calculations_performed[0] == ("multiply", [15.5, 2.0], calculated_result) + + async def test_multiple_requests_workflow(self): + """Test workflow with multiple concurrent requests.""" + executor = MultiRequestExecutor(id="multi_executor") + workflow = WorkflowBuilder().set_start_executor(executor).build() + + # Collect all request events by running the full stream + request_events: list[RequestInfoEvent] = [] + async for event in workflow.run_stream("start batch"): + if isinstance(event, RequestInfoEvent): + request_events.append(event) + + assert len(request_events) == 2 + + # Find the approval and calculation requests + approval_event: RequestInfoEvent | None = next( + (e for e in request_events if isinstance(e.data, UserApprovalRequest)), None + ) + calc_event: RequestInfoEvent | None = next( + (e for e in request_events if isinstance(e.data, CalculationRequest)), None + ) + + assert approval_event is not None + assert calc_event is not None + + # Send responses for both requests + responses = {approval_event.request_id: True, calc_event.request_id: 50.0} + completed = False + async for event in workflow.send_responses_streaming(responses): + if isinstance(event, 
WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + completed = True + + assert completed + assert len(executor.responses_received) == 2 + + async def test_denied_approval_workflow(self): + """Test workflow when approval is denied.""" + executor = ApprovalRequiredExecutor(id="approval_executor") + workflow = WorkflowBuilder().set_start_executor(executor).build() + + # First run the workflow until it emits a request + request_info_event: RequestInfoEvent | None = None + async for event in workflow.run_stream("sensitive operation"): + if isinstance(event, RequestInfoEvent): + request_info_event = event + + assert request_info_event is not None + + # Deny the request + completed = False + async for event in workflow.send_responses_streaming({request_info_event.request_id: False}): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + completed = True + + assert completed + assert executor.approval_received is True + assert executor.final_result == "Operation denied by user" + + async def test_workflow_state_with_pending_requests(self): + """Test workflow state when waiting for responses.""" + executor = ApprovalRequiredExecutor(id="approval_executor") + workflow = WorkflowBuilder().set_start_executor(executor).build() + + # Run workflow until idle with pending requests + request_info_event: RequestInfoEvent | None = None + idle_with_pending = False + async for event in workflow.run_stream("test operation"): + if isinstance(event, RequestInfoEvent): + request_info_event = event + elif isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS: + idle_with_pending = True + + assert request_info_event is not None + assert idle_with_pending + + # Continue with response + completed = False + async for event in workflow.send_responses_streaming({request_info_event.request_id: True}): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + completed = 
True + + assert completed + + async def test_invalid_calculation_input(self): + """Test workflow handling of invalid calculation input.""" + executor = CalculationExecutor(id="calc_executor") + workflow = WorkflowBuilder().set_start_executor(executor).build() + + # Send invalid input (no numbers) + completed = False + async for event in workflow.run_stream("invalid input"): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + completed = True + + assert completed + # Should not have any calculations performed due to invalid input + assert len(executor.calculations_performed) == 0 + + async def test_checkpoint_with_pending_request_info_events(self): + """Test that request info events are properly serialized in checkpoints and can be restored.""" + import tempfile + + with tempfile.TemporaryDirectory() as temp_dir: + # Use file-based storage to test full serialization + storage = FileCheckpointStorage(temp_dir) + + # Create workflow with checkpointing enabled + executor = ApprovalRequiredExecutor(id="approval_executor") + workflow = WorkflowBuilder().set_start_executor(executor).with_checkpointing(storage).build() + + # Step 1: Run the workflow until it emits a request, so a checkpoint recording the pending request is created + request_info_event: RequestInfoEvent | None = None + async for event in workflow.run_stream("checkpoint test operation"): + if isinstance(event, RequestInfoEvent): + request_info_event = event + + # Verify request was emitted + assert request_info_event is not None + assert isinstance(request_info_event.data, UserApprovalRequest) + assert request_info_event.data.prompt == "Please approve the operation: checkpoint test operation" + assert request_info_event.source_executor_id == "approval_executor" + + # Step 2: List checkpoints to find the one with our pending request + checkpoints = await storage.list_checkpoints() + assert len(checkpoints) > 0, "No checkpoints were created during workflow execution" + + # Find the checkpoint with our pending
request + checkpoint_with_request = None + for checkpoint in checkpoints: + if request_info_event.request_id in checkpoint.pending_request_info_events: + checkpoint_with_request = checkpoint + break + + assert checkpoint_with_request is not None, "No checkpoint found with pending request info event" + + # Step 3: Verify the pending request info event was properly serialized + serialized_event = checkpoint_with_request.pending_request_info_events[request_info_event.request_id] + assert "data" in serialized_event + assert "request_id" in serialized_event + assert "source_executor_id" in serialized_event + assert "request_type" in serialized_event + assert serialized_event["request_id"] == request_info_event.request_id + assert serialized_event["source_executor_id"] == "approval_executor" + + # Step 4: Create a fresh workflow and restore from checkpoint + new_executor = ApprovalRequiredExecutor(id="approval_executor") + restored_workflow = WorkflowBuilder().set_start_executor(new_executor).with_checkpointing(storage).build() + + # Step 5: Resume from checkpoint and verify the request can be continued + completed = False + restored_request_event: RequestInfoEvent | None = None + async for event in restored_workflow.run_stream_from_checkpoint(checkpoint_with_request.checkpoint_id): + # Should re-emit the pending request info event + if isinstance(event, RequestInfoEvent) and event.request_id == request_info_event.request_id: + restored_request_event = event + elif ( + isinstance(event, WorkflowStatusEvent) + and event.state == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + ): + completed = True + + assert completed, "Workflow should reach idle with pending requests state after restoration" + assert restored_request_event is not None, "Restored request info event should be emitted" + + # Verify the restored event matches the original + assert restored_request_event.source_executor_id == request_info_event.source_executor_id + assert isinstance(restored_request_event.data, 
UserApprovalRequest) + assert restored_request_event.data.prompt == request_info_event.data.prompt + assert restored_request_event.data.context == request_info_event.data.context + + # Step 6: Provide response to the restored request and complete the workflow + final_completed = False + async for event in restored_workflow.send_responses_streaming({ + request_info_event.request_id: True # Approve the request + }): + if isinstance(event, WorkflowStatusEvent) and event.state == WorkflowRunState.IDLE: + final_completed = True + + assert final_completed, "Workflow should complete after providing response to restored request" + + # Step 7: Verify the executor state was properly restored and response was processed + assert new_executor.approval_received is True + expected_result = "Operation approved: Please approve the operation: checkpoint test operation" + assert new_executor.final_result == expected_result diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py new file mode 100644 index 0000000000..d8046a100f --- /dev/null +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -0,0 +1,307 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +import asyncio +import inspect +from typing import Any + +import pytest + +from agent_framework._workflows._executor import Executor, handler +from agent_framework._workflows._request_info_mixin import RequestInfoMixin, response_handler +from agent_framework._workflows._workflow_context import WorkflowContext + + +class TestRequestInfoMixin: + """Test cases for RequestInfoMixin functionality.""" + + def test_request_info_mixin_initialization(self): + """Test that RequestInfoMixin can be initialized.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + executor = TestExecutor() + # After calling _discover_response_handlers, it should have the attributes + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + assert hasattr(executor, "_response_handlers") + assert hasattr(executor, "_response_handler_specs") + assert hasattr(executor, "is_request_response_capable") + assert executor.is_request_response_capable is False + + def test_response_handler_decorator_creates_metadata(self): + """Test that the response_handler decorator creates proper metadata.""" + + @response_handler + async def test_handler(self: Any, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + """Test handler docstring.""" + pass + + # Check that the decorator preserves function attributes + assert test_handler.__name__ == "test_handler" + assert test_handler.__doc__ == "Test handler docstring." 
+ assert hasattr(test_handler, "_response_handler_spec") + + # Check the spec attributes + spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] + assert spec["name"] == "test_handler" + assert spec["message_type"] is int + + def test_response_handler_with_workflow_context_types(self): + """Test response handler with different WorkflowContext type parameters.""" + + @response_handler + async def handler_with_output_types( + self: Any, original_request: str, response: int, ctx: WorkflowContext[str, bool] + ) -> None: + pass + + spec = handler_with_output_types._response_handler_spec # type: ignore[reportAttributeAccessIssue] + assert "output_types" in spec + assert "workflow_output_types" in spec + + def test_response_handler_preserves_signature(self): + """Test that response_handler preserves the original function signature.""" + + async def original_handler(self: Any, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + decorated = response_handler(original_handler) + + # Check that signature is preserved + original_sig = inspect.signature(original_handler) + decorated_sig = inspect.signature(decorated) + + # Both should have the same parameter names and types + assert list(original_sig.parameters.keys()) == list(decorated_sig.parameters.keys()) + + def test_executor_with_response_handlers(self): + """Test an executor with valid response handlers.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_string_response( + self, original_request: str, response: int, ctx: WorkflowContext[str] + ) -> None: + pass + + @response_handler + async def handle_dict_response( + self, original_request: dict[str, Any], response: bool, ctx: WorkflowContext[bool] + ) -> None: + pass + + executor = 
TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should be request-response capable + assert executor.is_request_response_capable is True + + # Should have registered handlers + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 2 + assert int in response_handlers + assert bool in response_handlers + + def test_executor_without_response_handlers(self): + """Test an executor without response handlers.""" + + class PlainExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="plain_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + executor = PlainExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should not be request-response capable + assert executor.is_request_response_capable is False + + # Should have empty handlers + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 0 + + def test_duplicate_response_handlers_raise_error(self): + """Test that duplicate response handlers for the same message type raise an error.""" + + class DuplicateExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="duplicate_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_first(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + @response_handler + async def handle_second(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + executor = DuplicateExecutor() + + with pytest.raises(ValueError, match="Duplicate response handler for message type.*int"): + executor._discover_response_handlers() # type: 
ignore[reportAttributeAccessIssue] + + def test_response_handler_function_callable(self): + """Test that response handlers can actually be called.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + self.handled_request = None + self.handled_response = None + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_response(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + self.handled_request = original_request + self.handled_response = response + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Get the handler + response_handler_func = executor._response_handlers[int] # type: ignore[reportAttributeAccessIssue] + + # Create a mock context - we'll just use None since the handler doesn't use it + asyncio.run(response_handler_func("test_request", 42, None)) # type: ignore[reportArgumentType] + + assert executor.handled_request == "test_request" + assert executor.handled_response == 42 + + def test_inheritance_with_response_handlers(self): + """Test that response handlers work correctly with inheritance.""" + + class BaseExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="base_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def base_handler(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + class ChildExecutor(BaseExecutor): + def __init__(self): + super().__init__() + self.id = "child_executor" + + @response_handler + async def child_handler(self, original_request: str, response: bool, ctx: WorkflowContext[str]) -> None: + pass + + child = ChildExecutor() + child._discover_response_handlers() # type: 
ignore[reportAttributeAccessIssue] + + # Should have both handlers + response_handlers = child._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 2 + assert int in response_handlers + assert bool in response_handlers + assert child.is_request_response_capable is True + + def test_response_handler_spec_attributes(self): + """Test that response handler specs contain expected attributes.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def test_handler(self, original_request: str, response: int, ctx: WorkflowContext[str, bool]) -> None: + pass + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + specs = executor._response_handler_specs # type: ignore[reportAttributeAccessIssue] + assert len(specs) == 1 + + spec = specs[0] + assert spec["name"] == "test_handler" + assert spec["message_type"] is int + assert "output_types" in spec + assert "workflow_output_types" in spec + assert "ctx_annotation" in spec + assert spec["source"] == "class_method" + + def test_multiple_discovery_calls_raise_error(self): + """Test that multiple calls to _discover_response_handlers raise an error for duplicates.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def test_handler(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + executor = TestExecutor() + + # First call should work fine + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + first_handlers = len(executor._response_handlers) 
# type: ignore[reportAttributeAccessIssue] + + # Second call should raise an error due to duplicate registration + with pytest.raises(ValueError, match="Duplicate response handler for message type.*int"): + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Handlers count should remain the same + assert first_handlers == 1 + + def test_non_callable_attributes_ignored(self): + """Test that non-callable attributes are ignored during discovery.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + + some_variable = "not_a_function" + another_attr = 42 + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def valid_handler(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should only have one handler despite other attributes + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 1 + assert int in response_handlers From 968577f8bc7b42d83f278de8e1984b5f897a84ff Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 13:11:11 -0700 Subject: [PATCH 18/26] misc --- python/packages/core/agent_framework/_workflows/_events.py | 7 ++----- .../core/agent_framework/_workflows/_request_info_mixin.py | 3 +-- python/packages/core/agent_framework/_workflows/_runner.py | 5 +---- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index 57f3f9c3d9..d03727e87e 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -6,20 +6,17 @@ from contextvars 
import ContextVar from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, TypeAlias +from typing import Any, TypeAlias from agent_framework import AgentRunResponse, AgentRunResponseUpdate from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value from ._typing_utils import deserialize_type, serialize_type -if TYPE_CHECKING: - pass - class WorkflowEventSource(str, Enum): """Identifies whether a workflow event came from the framework or an executor. - +runn Use `FRAMEWORK` for events emitted by built-in orchestration paths—even when the code that raises them lives in runner-related modules—and `EXECUTOR` for events surfaced by developer-provided executor implementations. diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index b68c7174ea..3205f17c78 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -6,7 +6,7 @@ import logging from builtins import type as builtin_type from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, ClassVar, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar from ._workflow_context import WorkflowContext, validate_workflow_context_annotation @@ -20,7 +20,6 @@ class RequestInfoMixin: """Mixin providing common functionality for request info handling.""" - def _discover_response_handlers(self) -> None: """Discover and register response handlers defined in the class.""" # Initialize handler storage if not already present diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index 3fe6f1ca7a..51ff79a864 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -4,7 +4,7 @@ 
import logging from collections import defaultdict from collections.abc import AsyncGenerator, Sequence -from typing import TYPE_CHECKING, Any +from typing import Any from ._checkpoint import CheckpointStorage, WorkflowCheckpoint from ._checkpoint_encoding import DATACLASS_MARKER, MODEL_MARKER, decode_checkpoint_value @@ -19,9 +19,6 @@ ) from ._shared_state import SharedState -if TYPE_CHECKING: - pass - logger = logging.getLogger(__name__) From 28d2811080b89d7b3414d03f3e2742441242fba5 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 13:59:55 -0700 Subject: [PATCH 19/26] fix mypy --- python/packages/core/agent_framework/_workflows/_events.py | 2 +- python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_events.py b/python/packages/core/agent_framework/_workflows/_events.py index d03727e87e..e5a8c3e3c4 100644 --- a/python/packages/core/agent_framework/_workflows/_events.py +++ b/python/packages/core/agent_framework/_workflows/_events.py @@ -16,7 +16,7 @@ class WorkflowEventSource(str, Enum): """Identifies whether a workflow event came from the framework or an executor. -runn + Use `FRAMEWORK` for events emitted by built-in orchestration paths—even when the code that raises them lives in runner-related modules—and `EXECUTOR` for events surfaced by developer-provided executor implementations. 
diff --git a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py index 1ee2aae167..d5894d086a 100644 --- a/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py +++ b/python/packages/lab/gaia/agent_framework_lab_gaia/gaia.py @@ -298,7 +298,6 @@ def _ensure_data(self) -> Path: repo_type="dataset", token=token, local_dir=str(self.data_dir), - local_dir_use_symlinks=False, force_download=False, ) return Path(local_dir) From ccb18c61e86b7e91faf5597acfc6bf9a38dfc863 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Mon, 27 Oct 2025 14:44:48 -0700 Subject: [PATCH 20/26] fix mypy --- .../agents/azure_chat_agents_tool_calls_with_feedback.py | 3 ++- .../workflows/composition/sub_workflow_parallel_requests.py | 3 ++- .../workflows/composition/sub_workflow_request_interception.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py index a6c87a5899..b5d6262a8b 100644 --- a/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py +++ b/python/samples/getting_started/workflows/agents/azure_chat_agents_tool_calls_with_feedback.py @@ -3,7 +3,7 @@ import asyncio import json from dataclasses import dataclass, field -from typing import Annotated, Never +from typing import Annotated from agent_framework import ( AgentExecutorRequest, @@ -26,6 +26,7 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import AzureCliCredential from pydantic import Field +from typing_extensions import Never """ Sample: Tool-enabled agents with human feedback diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py index b22e452d5b..4d9db405a7 
100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py @@ -3,7 +3,7 @@ import asyncio import uuid from dataclasses import dataclass -from typing import Literal, Never +from typing import Literal from agent_framework import ( Executor, @@ -17,6 +17,7 @@ handler, response_handler, ) +from typing_extensions import Never """ This sample demonstrates how to handle multiple parallel requests from a sub-workflow to diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py index 95633b3df0..749e454fa8 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py @@ -2,7 +2,6 @@ import asyncio from dataclasses import dataclass -from typing import Never from agent_framework import ( Executor, @@ -16,6 +15,7 @@ handler, response_handler, ) +from typing_extensions import Never """ This sample demostrates how to handle request from the sub-workflow in the main workflow. 
From 983072489fc9f5f40d14733e0b0d9f83607b0ba0 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Oct 2025 13:34:44 -0700 Subject: [PATCH 21/26] Use request type as part of the key --- .../_workflows/_edge_runner.py | 4 +- .../agent_framework/_workflows/_executor.py | 59 ++-- .../_workflows/_request_info_mixin.py | 125 +++++--- .../tests/workflow/test_request_info_mixin.py | 280 +++++++++++++++++- 4 files changed, 391 insertions(+), 77 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_edge_runner.py b/python/packages/core/agent_framework/_workflows/_edge_runner.py index e621b0fd30..0aa4139c48 100644 --- a/python/packages/core/agent_framework/_workflows/_edge_runner.py +++ b/python/packages/core/agent_framework/_workflows/_edge_runner.py @@ -93,8 +93,8 @@ def __init__(self, edge_group: SingleEdgeGroup | InternalEdgeGroup, executors: d async def send_message(self, message: Message, shared_state: SharedState, ctx: RunnerContext) -> bool: """Send a message through the single edge.""" should_execute = False - target_id = None - source_id = None + target_id: str | None = None + source_id: str | None = None with create_edge_group_processing_span( self._edge_group.__class__.__name__, edge_group_id=self._edge_group.id, diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 814a12ec7a..fe74fd8723 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -233,19 +233,6 @@ async def execute( Returns: An awaitable that resolves to the result of the execution. 
""" - # Default to find handler in regular handlers - target_handlers = self._handlers - - if isinstance(message, Message): - # Wrap the response handlers to include original_request parameter - if message.type == MessageType.RESPONSE: - target_handlers = { - message_type: functools.partial(handler, message.original_request) - for message_type, handler in self._response_handlers.items() - } - # Handle case where Message wrapper is passed instead of raw data - message = message.data - # Create processing span for tracing (gracefully handles disabled tracing) with create_processing_span( self.id, @@ -255,14 +242,10 @@ async def execute( source_span_ids=source_span_ids, ): # Find the handler and handler spec that matches the message type. - handler: Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]] | None = None - for message_type in target_handlers: - if is_instance_of(message, message_type): - handler = target_handlers[message_type] - break - - if handler is None: - raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.") + handler = self._find_handler(message) + if isinstance(message, Message): + # Unwrap raw data for handler call + message = message.data # Create the appropriate WorkflowContext based on handler specs context = self._create_context_for_handler( @@ -442,6 +425,40 @@ def to_dict(self) -> dict[str, Any]: """Serialize executor definition for workflow topology export.""" return {"id": self.id, "type": self.type} + def _find_handler(self, message: Any) -> Callable[[Any, WorkflowContext[Any, Any]], Awaitable[None]]: + """Find the handler for a given message. + + Args: + message: The message to find the handler for. 
+ + Returns: + The handler function if found, None otherwise + """ + if isinstance(message, Message): + # Case where Message wrapper is passed instead of raw data + # Handler can be a standard handler or a response handler + if message.type == MessageType.STANDARD: + for message_type in self._handlers: + if is_instance_of(message.data, message_type): + return self._handlers[message_type] + raise RuntimeError( + f"Executor {self.__class__.__name__} cannot handle message of type {type(message.data)}." + ) + # Response message case - find response handler based on original request and response types + handler = self._find_response_handler(message.original_request, message.data) + if not handler: + raise RuntimeError( + f"Executor {self.__class__.__name__} cannot handle request of type " + f"{type(message.original_request)} and response of type {type(message.data)}." + ) + return handler + + # Standard raw message data case - only standard handlers apply + for message_type in self._handlers: + if is_instance_of(message, message_type): + return self._handlers[message_type] + raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.") + # endregion: Executor diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 3205f17c78..1d4aa5d2f0 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -8,6 +8,7 @@ from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING, Any, TypeVar +from ._typing_utils import is_instance_of from ._workflow_context import WorkflowContext, validate_workflow_context_annotation if TYPE_CHECKING: @@ -20,12 +21,31 @@ class RequestInfoMixin: """Mixin providing common functionality for request info handling.""" + def _find_response_handler(self, request: Any, response: Any) 
-> Callable[..., Awaitable[None]] | None: + """Find a registered response handler for the given request and response types. + + Args: + request: The original request + response: The response message + Returns: + The response handler function with the request bound as the first argument, or None if not found + """ + if not hasattr(self, "_response_handlers"): + return None + + for (request_type, response_type), handler in self._response_handlers.items(): + if is_instance_of(request, request_type) and is_instance_of(response, response_type): + return functools.partial(handler, request) + + return None + def _discover_response_handlers(self) -> None: """Discover and register response handlers defined in the class.""" # Initialize handler storage if not already present if not hasattr(self, "_response_handlers"): self._response_handlers: dict[ - builtin_type[Any], Callable[[Any, Any, WorkflowContext[Any, Any]], Awaitable[None]] + tuple[builtin_type[Any], builtin_type[Any]], # key + Callable[[Any, Any, WorkflowContext[Any, Any]], Awaitable[None]], # value ] = {} if not hasattr(self, "_response_handler_specs"): self._response_handler_specs: list[dict[str, Any]] = [] @@ -35,17 +55,21 @@ def _discover_response_handlers(self) -> None: attr = getattr(self.__class__, attr_name) if callable(attr) and hasattr(attr, "_response_handler_spec"): handler_spec = attr._response_handler_spec # type: ignore - message_type = handler_spec["message_type"] - if self._response_handlers.get(message_type): + request_type = handler_spec["request_type"] + response_type = handler_spec["response_type"] + + if self._response_handlers.get((request_type, response_type)): raise ValueError( - f"Duplicate response handler for message type {message_type} in {self.__class__.__name__}" + f"Duplicate response handler for request type {request_type} " + f"and response type {response_type} in {self.__class__.__name__}" ) - self._response_handlers[message_type] = getattr(self, attr_name) + 
self._response_handlers[request_type, response_type] = getattr(self, attr_name) self._response_handler_specs.append({ "name": handler_spec["name"], - "message_type": message_type, + "request_type": request_type, + "response_type": response_type, "output_types": handler_spec.get("output_types", []), "workflow_output_types": handler_spec.get("workflow_output_types", []), "ctx_annotation": handler_spec.get("ctx_annotation"), @@ -78,38 +102,42 @@ def response_handler( The decorated function with handler metadata. Example: - @handler - async def run(self, message: int, context: WorkflowContext[str]) -> None: - # Example of a handler that sends a request - ... - # Send a request with a `CustomRequest` payload and expect a `str` response. - await context.request_info(CustomRequest(...), CustomRequest, str) - - @response_handler - async def handle_response( - self, - original_request: CustomRequest, - response: str, - context: WorkflowContext[str], - ) -> None: - # Example of a response handler for the above request - ... - - @response_handler - async def handle_response( - self, - original_request: CustomRequest, - response: dict, - context: WorkflowContext[int], - ) -> None: - # Example of a response handler for a request expecting a dict response - ... + .. code-block:: python + + @handler + async def run(self, message: int, context: WorkflowContext[str]) -> None: + # Example of a handler that sends a request + ... + # Send a request with a `CustomRequest` payload and expect a `str` response. + await context.request_info(CustomRequest(...), CustomRequest, str) + + + @response_handler + async def handle_response( + self, + original_request: CustomRequest, + response: str, + context: WorkflowContext[str], + ) -> None: + # Example of a response handler for the above request + ... 
+ + + @response_handler + async def handle_response( + self, + original_request: CustomRequest, + response: dict, + context: WorkflowContext[int], + ) -> None: + # Example of a response handler for a request expecting a dict response + ... """ def decorator( func: Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]], ) -> Callable[[ExecutorT, Any, Any, ContextT], Awaitable[None]]: - message_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( + request_type, response_type, ctx_annotation, inferred_output_types, inferred_workflow_output_types = ( _validate_response_handler_signature(func) ) @@ -117,9 +145,9 @@ def decorator( sig = inspect.signature(func) @functools.wraps(func) - async def wrapper(self: ExecutorT, original_request: Any, message: Any, ctx: ContextT) -> Any: + async def wrapper(self: ExecutorT, original_request: Any, response: Any, ctx: ContextT) -> Any: """Wrapper function to call the handler.""" - return await func(self, original_request, message, ctx) + return await func(self, original_request, response, ctx) # Preserve the original function signature for introspection during validation with contextlib.suppress(AttributeError, TypeError): @@ -127,7 +155,8 @@ async def wrapper(self: ExecutorT, original_request: Any, message: Any, ctx: Con wrapper._response_handler_spec = { # type: ignore "name": func.__name__, - "message_type": message_type, + "request_type": request_type, + "response_type": response_type, # Keep output_types and workflow_output_types in spec for validators "output_types": inferred_output_types, "workflow_output_types": inferred_workflow_output_types, @@ -146,14 +175,14 @@ async def wrapper(self: ExecutorT, original_request: Any, message: Any, ctx: Con def _validate_response_handler_signature( func: Callable[..., Any], -) -> tuple[type, Any, list[type[Any]], list[type[Any]]]: +) -> tuple[type, type, Any, list[type[Any]], list[type[Any]]]: """Validate function signature for executor functions. 
Args: func: The function to validate Returns: - Tuple of (message_type, ctx_annotation, output_types, workflow_output_types) + Tuple of (request_type, response_type, ctx_annotation, output_types, workflow_output_types) Raises: ValueError: If the function signature is invalid @@ -166,15 +195,22 @@ def _validate_response_handler_signature( # to the original request when registering the handler, while maintaining # the order of parameters as if the response handler is a normal handler. expected_counts = 4 # self, original_request, message, ctx - param_description = "(self, original_request: Any, message: T, ctx: WorkflowContext[U, V])" + param_description = "(self, original_request: TRequest, message: TResponse, ctx: WorkflowContext[U, V])" if len(params) != expected_counts: raise ValueError( f"Response handler {func.__name__} must have {param_description}. Got {len(params)} parameters." ) - # Check message parameter has type annotation - message_param = params[2] - if message_param.annotation == inspect.Parameter.empty: + # Check original_request parameter exists + original_request_param = params[1] + if original_request_param.annotation == inspect.Parameter.empty: + raise ValueError( + f"Response handler {func.__name__} must have a type annotation for the original_request parameter" + ) + + # Check response parameter has type annotation + response_param = params[2] + if response_param.annotation == inspect.Parameter.empty: raise ValueError(f"Response handler {func.__name__} must have a type annotation for the message parameter") # Validate ctx parameter is WorkflowContext and extract type args @@ -183,10 +219,11 @@ def _validate_response_handler_signature( ctx_param.annotation, f"parameter '{ctx_param.name}'", "Response handler" ) - message_type = message_param.annotation + request_type = original_request_param.annotation + response_type = response_param.annotation ctx_annotation = ctx_param.annotation - return message_type, ctx_annotation, output_types, 
workflow_output_types + return request_type, response_type, ctx_annotation, output_types, workflow_output_types # endregion: Response Handler Validation diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py index d8046a100f..892b9341fd 100644 --- a/python/packages/core/tests/workflow/test_request_info_mixin.py +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -49,7 +49,8 @@ async def test_handler(self: Any, original_request: str, response: int, ctx: Wor # Check the spec attributes spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] assert spec["name"] == "test_handler" - assert spec["message_type"] is int + assert spec["response_type"] is int + assert spec["request_type"] is str def test_response_handler_with_workflow_context_types(self): """Test response handler with different WorkflowContext type parameters.""" @@ -111,8 +112,8 @@ async def handle_dict_response( # Should have registered handlers response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] assert len(response_handlers) == 2 - assert int in response_handlers - assert bool in response_handlers + assert (str, int) in response_handlers + assert (dict[str, Any], bool) in response_handlers def test_executor_without_response_handlers(self): """Test an executor without response handlers.""" @@ -156,7 +157,10 @@ async def handle_second(self, original_request: str, response: int, ctx: Workflo executor = DuplicateExecutor() - with pytest.raises(ValueError, match="Duplicate response handler for message type.*int"): + with pytest.raises( + ValueError, + match="Duplicate response handler for request type and response type ", + ): executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] def test_response_handler_function_callable(self): @@ -181,7 +185,7 @@ async def handle_response(self, original_request: str, 
response: int, ctx: Workf executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] # Get the handler - response_handler_func = executor._response_handlers[int] # type: ignore[reportAttributeAccessIssue] + response_handler_func = executor._response_handlers[(str, int)] # type: ignore[reportAttributeAccessIssue] # Create a mock context - we'll just use None since the handler doesn't use it asyncio.run(response_handler_func("test_request", 42, None)) # type: ignore[reportArgumentType] @@ -219,8 +223,8 @@ async def child_handler(self, original_request: str, response: bool, ctx: Workfl # Should have both handlers response_handlers = child._response_handlers # type: ignore[reportAttributeAccessIssue] assert len(response_handlers) == 2 - assert int in response_handlers - assert bool in response_handlers + assert (str, int) in response_handlers + assert (str, bool) in response_handlers assert child.is_request_response_capable is True def test_response_handler_spec_attributes(self): @@ -246,7 +250,8 @@ async def test_handler(self, original_request: str, response: int, ctx: Workflow spec = specs[0] assert spec["name"] == "test_handler" - assert spec["message_type"] is int + assert spec["request_type"] is str + assert spec["response_type"] is int assert "output_types" in spec assert "workflow_output_types" in spec assert "ctx_annotation" in spec @@ -274,7 +279,10 @@ async def test_handler(self, original_request: str, response: int, ctx: Workflow first_handlers = len(executor._response_handlers) # type: ignore[reportAttributeAccessIssue] # Second call should raise an error due to duplicate registration - with pytest.raises(ValueError, match="Duplicate response handler for message type.*int"): + with pytest.raises( + ValueError, + match="Duplicate response handler for request type and response type ", + ): executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] # Handlers count should remain the same @@ -304,4 +312,256 @@ async 
def valid_handler(self, original_request: str, response: int, ctx: Workflo # Should only have one handler despite other attributes response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] assert len(response_handlers) == 1 - assert int in response_handlers + assert (str, int) in response_handlers + + def test_same_request_type_different_response_types(self): + """Test that handlers with same request type but different response types are distinct.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + self.str_int_handler_called = False + self.str_bool_handler_called = False + self.str_dict_handler_called = False + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_str_int(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + self.str_int_handler_called = True + + @response_handler + async def handle_str_bool(self, original_request: str, response: bool, ctx: WorkflowContext[str]) -> None: + self.str_bool_handler_called = True + + @response_handler + async def handle_str_dict( + self, original_request: str, response: dict[str, Any], ctx: WorkflowContext[str] + ) -> None: + self.str_dict_handler_called = True + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should have three distinct handlers + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 3 + assert (str, int) in response_handlers + assert (str, bool) in response_handlers + assert (str, dict[str, Any]) in response_handlers + + # Test that each handler can be found correctly + str_int_handler = executor._find_response_handler("test", 42) # pyright: ignore[reportPrivateUsage] + str_bool_handler = executor._find_response_handler("test", True) # 
pyright: ignore[reportPrivateUsage] + str_dict_handler = executor._find_response_handler("test", {"key": "value"}) # pyright: ignore[reportPrivateUsage] + + assert str_int_handler is not None + assert str_bool_handler is not None + assert str_dict_handler is not None + + # Test that handlers are called correctly + asyncio.run(str_int_handler(42, None)) # type: ignore[reportArgumentType] + asyncio.run(str_bool_handler(True, None)) # type: ignore[reportArgumentType] + asyncio.run(str_dict_handler({"key": "value"}, None)) # type: ignore[reportArgumentType] + + assert executor.str_int_handler_called + assert executor.str_bool_handler_called + assert executor.str_dict_handler_called + + def test_different_request_types_same_response_type(self): + """Test that handlers with different request types but same response type are distinct.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + self.str_int_handler_called = False + self.dict_int_handler_called = False + self.list_int_handler_called = False + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_str_int(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + self.str_int_handler_called = True + + @response_handler + async def handle_dict_int( + self, original_request: dict[str, Any], response: int, ctx: WorkflowContext[str] + ) -> None: + self.dict_int_handler_called = True + + @response_handler + async def handle_list_int( + self, original_request: list[str], response: int, ctx: WorkflowContext[str] + ) -> None: + self.list_int_handler_called = True + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should have three distinct handlers + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert 
len(response_handlers) == 3 + assert (str, int) in response_handlers + assert (dict[str, Any], int) in response_handlers + assert (list[str], int) in response_handlers + + # Test that each handler can be found correctly + str_int_handler = executor._find_response_handler("test", 42) # pyright: ignore[reportPrivateUsage] + dict_int_handler = executor._find_response_handler({"key": "value"}, 42) # pyright: ignore[reportPrivateUsage] + list_int_handler = executor._find_response_handler(["test"], 42) # pyright: ignore[reportPrivateUsage] + + assert str_int_handler is not None + assert dict_int_handler is not None + assert list_int_handler is not None + + # Test that handlers are called correctly + asyncio.run(str_int_handler(42, None)) # type: ignore[reportArgumentType] + asyncio.run(dict_int_handler(42, None)) # type: ignore[reportArgumentType] + asyncio.run(list_int_handler(42, None)) # type: ignore[reportArgumentType] + + assert executor.str_int_handler_called + assert executor.dict_int_handler_called + assert executor.list_int_handler_called + + def test_complex_type_combinations(self): + """Test response handlers with complex type combinations.""" + + class CustomRequest: + pass + + class CustomResponse: + pass + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + self.custom_custom_called = False + self.custom_str_called = False + self.str_custom_called = False + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_custom_custom( + self, original_request: CustomRequest, response: CustomResponse, ctx: WorkflowContext[str] + ) -> None: + self.custom_custom_called = True + + @response_handler + async def handle_custom_str( + self, original_request: CustomRequest, response: str, ctx: WorkflowContext[str] + ) -> None: + self.custom_str_called = True + + @response_handler + async def handle_str_custom( + 
self, original_request: str, response: CustomResponse, ctx: WorkflowContext[str] + ) -> None: + self.str_custom_called = True + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should have three distinct handlers + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 3 + assert (CustomRequest, CustomResponse) in response_handlers + assert (CustomRequest, str) in response_handlers + assert (str, CustomResponse) in response_handlers + + # Test that each handler can be found correctly + custom_request = CustomRequest() + custom_response = CustomResponse() + + custom_custom_handler = executor._find_response_handler(custom_request, custom_response) # pyright: ignore[reportPrivateUsage] + custom_str_handler = executor._find_response_handler(custom_request, "test") # pyright: ignore[reportPrivateUsage] + str_custom_handler = executor._find_response_handler("test", custom_response) # pyright: ignore[reportPrivateUsage] + + assert custom_custom_handler is not None + assert custom_str_handler is not None + assert str_custom_handler is not None + + def test_handler_key_uniqueness(self): + """Test that handler keys (request_type, response_type) are truly unique.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle1(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + @response_handler + async def handle2(self, original_request: int, response: str, ctx: WorkflowContext[str]) -> None: + pass + + @response_handler + async def handle3(self, original_request: str, response: str, ctx: WorkflowContext[str]) -> None: + pass + + @response_handler + async def handle4(self, original_request: 
int, response: int, ctx: WorkflowContext[str]) -> None: + pass + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Should have four distinct handlers based on different combinations + response_handlers = executor._response_handlers # type: ignore[reportAttributeAccessIssue] + assert len(response_handlers) == 4 + + # Verify all expected combinations exist + expected_keys = { + (str, int), # handle1 + (int, str), # handle2 + (str, str), # handle3 + (int, int), # handle4 + } + + actual_keys = set(response_handlers.keys()) + assert actual_keys == expected_keys + + def test_no_false_matches_with_similar_types(self): + """Test that handlers don't match with similar but different types.""" + + class TestExecutor(Executor, RequestInfoMixin): + def __init__(self): + super().__init__(id="test_executor", defer_discovery=True) + + @handler + async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: + pass + + @response_handler + async def handle_str_int(self, original_request: str, response: int, ctx: WorkflowContext[str]) -> None: + pass + + @response_handler + async def handle_list_str_float( + self, original_request: list[str], response: float, ctx: WorkflowContext[str] + ) -> None: + pass + + executor = TestExecutor() + executor._discover_response_handlers() # type: ignore[reportAttributeAccessIssue] + + # Test that wrong combinations don't match + assert executor._find_response_handler("test", 3.14) is None # pyright: ignore[reportPrivateUsage] # str request, float response - no handler + assert executor._find_response_handler(["test"], 42) is None # pyright: ignore[reportPrivateUsage] # list request, int response - no handler + assert executor._find_response_handler(42, "test") is None # pyright: ignore[reportPrivateUsage] # int request, str response - no handler + + # Test that correct combinations do match + assert executor._find_response_handler("test", 42) is not None # pyright: 
ignore[reportPrivateUsage] # str request, int response - has handler + assert executor._find_response_handler(["test"], 3.14) is not None # pyright: ignore[reportPrivateUsage] # list request, float response - has handler From 9caa7aa63f85b5b37b6b90a698e5dc1c6fa7cae1 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Oct 2025 14:39:51 -0700 Subject: [PATCH 22/26] Log warning if there is not response handler for a request --- .../agent_framework/_workflows/_executor.py | 2 +- .../_workflows/_request_info_mixin.py | 19 ++++++++++++++++++ .../_workflows/_workflow_context.py | 20 +++++++++++++++---- .../tests/workflow/test_workflow_context.py | 15 +++++++++++++- .../workflow/test_workflow_observability.py | 5 +++-- 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index fe74fd8723..254df4dad6 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -294,7 +294,7 @@ def _create_context_for_handler( """ # Create WorkflowContext return WorkflowContext( - executor_id=self.id, + executor=self, source_executor_ids=source_executor_ids, shared_state=shared_state, runner_context=runner_context, diff --git a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py index 1d4aa5d2f0..ded3652a0e 100644 --- a/python/packages/core/agent_framework/_workflows/_request_info_mixin.py +++ b/python/packages/core/agent_framework/_workflows/_request_info_mixin.py @@ -21,6 +21,25 @@ class RequestInfoMixin: """Mixin providing common functionality for request info handling.""" + def is_request_supported(self, request_type: builtin_type[Any], response_type: builtin_type[Any]) -> bool: + """Check if the executor supports request of the given type and handling a response of the given type. 
+ + Args: + request_type: The type of the request message + response_type: The type of the expected response message + Returns: + True if a response handler is registered for the given request and response types, False otherwise + """ + if not hasattr(self, "_response_handlers"): + return False + + for request_type_key, response_type_key in self._response_handlers: + # TODO(@taochen): #1753 - Consider covariance/contravariance for request/response types + if issubclass(request_type, request_type_key) and issubclass(response_type, response_type_key): + return True + + return False + def _find_response_handler(self, request: Any, response: Any) -> Callable[..., Awaitable[None]] | None: """Find a registered response handler for the given request and response types. diff --git a/python/packages/core/agent_framework/_workflows/_workflow_context.py b/python/packages/core/agent_framework/_workflows/_workflow_context.py index 12b0a90a33..177e85563f 100644 --- a/python/packages/core/agent_framework/_workflows/_workflow_context.py +++ b/python/packages/core/agent_framework/_workflows/_workflow_context.py @@ -4,7 +4,7 @@ import logging import uuid from types import UnionType -from typing import Any, Generic, Union, cast, get_args, get_origin +from typing import TYPE_CHECKING, Any, Generic, Union, cast, get_args, get_origin from opentelemetry.propagate import inject from opentelemetry.trace import SpanKind @@ -27,6 +27,9 @@ from ._runner_context import Message, RunnerContext from ._shared_state import SharedState +if TYPE_CHECKING: + from ._executor import Executor + T_Out = TypeVar("T_Out", default=Never) T_W_Out = TypeVar("T_W_Out", default=Never) @@ -259,7 +262,7 @@ async def flexible(message: str, ctx: WorkflowContext[int | str, bool | dict]) - def __init__( self, - executor_id: str, + executor: "Executor", source_executor_ids: list[str], shared_state: SharedState, runner_context: RunnerContext, @@ -269,7 +272,7 @@ def __init__( """Initialize the executor context with the 
given workflow context. Args: - executor_id: The unique identifier of the executor that this context belongs to. + executor: The executor instance that this context belongs to. source_executor_ids: The IDs of the source executors that sent messages to this executor. This is a list to support fan_in scenarios where multiple sources send aggregated messages to the same executor. @@ -278,7 +281,8 @@ def __init__( trace_contexts: Optional trace contexts from multiple sources for OpenTelemetry propagation. source_span_ids: Optional source span IDs from multiple sources for linking (not for nesting). """ - self._executor_id = executor_id + self._executor = executor + self._executor_id = executor.id self._source_executor_ids = source_executor_ids self._runner_context = runner_context self._shared_state = shared_state @@ -359,6 +363,14 @@ async def request_info(self, request_data: Any, request_type: type, response_typ request_type: The type of the request, used to match with response handlers. response_type: The expected type of the response, used for validation. """ + if not self._executor.is_request_supported(request_type, response_type): + logger.warning( + f"Executor '{self._executor_id}' requested info of type {request_type.__name__} " + f"with expected response type {response_type.__name__}, but no matching " + "response handler is defined. The request will not be ignored but responses will " + "not be processed. Please define a response handler using the @response_handler decorator." 
+ ) + request_info_event = RequestInfoEvent( request_id=str(uuid.uuid4()), source_executor_id=self._executor_id, diff --git a/python/packages/core/tests/workflow/test_workflow_context.py b/python/packages/core/tests/workflow/test_workflow_context.py index 6abdb33406..b63742d16f 100644 --- a/python/packages/core/tests/workflow/test_workflow_context.py +++ b/python/packages/core/tests/workflow/test_workflow_context.py @@ -24,6 +24,18 @@ from agent_framework._workflows._runner_context import InProcRunnerContext +class MockExecutor(Executor): + """Mock executor for testing.""" + + def __init__(self, id: str) -> None: + super().__init__(id=id) + + @handler + async def handle_message(self, message: str, ctx: WorkflowContext[str]) -> None: + """Handle string messages.""" + ... + + @asynccontextmanager async def make_context( executor_id: str = "exec", @@ -31,10 +43,11 @@ async def make_context( from agent_framework._workflows._runner_context import InProcRunnerContext from agent_framework._workflows._shared_state import SharedState + mock_executor = MockExecutor(executor_id) runner_ctx = InProcRunnerContext() shared_state = SharedState() workflow_ctx: WorkflowContext[object] = WorkflowContext( - executor_id, + mock_executor, ["source"], shared_state, runner_ctx, diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py index d7ceb18f90..51f39599c1 100644 --- a/python/packages/core/tests/workflow/test_workflow_observability.py +++ b/python/packages/core/tests/workflow/test_workflow_observability.py @@ -175,7 +175,7 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No # Test trace context propagation in messages workflow_ctx: WorkflowContext[str] = WorkflowContext( - "test-executor", + executor, ["source"], shared_state, ctx, @@ -225,11 +225,12 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No async def 
test_trace_context_disabled_when_tracing_disabled(enable_otel, span_exporter: InMemorySpanExporter) -> None: """Test that no trace context is added when tracing is disabled.""" # Tracing should be disabled by default + executor = MockExecutor("test-executor") shared_state = SharedState() ctx = InProcRunnerContext() workflow_ctx: WorkflowContext[str] = WorkflowContext( - "test-executor", + executor, ["source"], shared_state, ctx, From f6a1ccb9c440e5bb39d087e9511ea6f1cfb72f27 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Oct 2025 15:04:38 -0700 Subject: [PATCH 23/26] Update Internal edge group comments --- .../core/agent_framework/_workflows/_edge.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 70eafc7d38..2d144657fe 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -873,8 +873,20 @@ def to_dict(self) -> dict[str, Any]: class InternalEdgeGroup(EdgeGroup): """Special edge group used to route internal messages to executors. - This group is not serialized and is only used at runtime to link internal - executors that should not be exposed as part of the public workflow graph. + This group is created automatically when a new executor is added to the workflow + builder. It contains a single edge that routes messages from the internal source + to the executor itself. Internal source represent messages that are generated by + the system rather than by another executor. This includes request and response + handling. + + This edge group only contains one edge from the internal source to the executor. + And it does not support any conditions or complex routing logic. + + During workflow serialization and deserialization, the internal edge group is + preserved and visible to systems consuming the workflow definition. 
+ + Messages sent along this edge will also be captured by monitoring and logging systems, + allowing for observability into internal message flows (when tracing is enabled). """ def __init__(self, executor_id: str) -> None: From 4635c4dc08af134a7122e6ad9edfe02cf97dded3 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Oct 2025 15:15:29 -0700 Subject: [PATCH 24/26] REcord message type in executor processing span --- .../core/agent_framework/_workflows/_executor.py | 1 + .../packages/core/agent_framework/observability.py | 11 +++++++++++ .../tests/workflow/test_workflow_observability.py | 12 ++++++++---- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/python/packages/core/agent_framework/_workflows/_executor.py b/python/packages/core/agent_framework/_workflows/_executor.py index 254df4dad6..aa63d2576a 100644 --- a/python/packages/core/agent_framework/_workflows/_executor.py +++ b/python/packages/core/agent_framework/_workflows/_executor.py @@ -237,6 +237,7 @@ async def execute( with create_processing_span( self.id, self.__class__.__name__, + str(MessageType.STANDARD if not isinstance(message, Message) else message.type), type(message).__name__, source_trace_contexts=trace_contexts, source_span_ids=source_span_ids, diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 645499471d..5ac176d295 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -210,6 +210,7 @@ class OtelAttr(str, Enum): MESSAGE_SOURCE_ID = "message.source_id" MESSAGE_TARGET_ID = "message.target_id" MESSAGE_TYPE = "message.type" + MESSAGE_PAYLOAD_TYPE = "message.payload_type" MESSAGE_DESTINATION_EXECUTOR_ID = "message.destination_executor_id" # Activity events @@ -1567,6 +1568,7 @@ def create_processing_span( executor_id: str, executor_type: str, message_type: str, + payload_type: str, source_trace_contexts: list[dict[str, str]] | 
None = None, source_span_ids: list[str] | None = None, ) -> "_AgnosticContextManager[trace.Span]": @@ -1575,6 +1577,14 @@ def create_processing_span( Processing spans are created as children of the current workflow span and linked (not nested) to the source publishing spans for causality tracking. This supports multiple links for fan-in scenarios. + + Args: + executor_id: The unique ID of the executor processing the message. + executor_type: The type of the executor (class name). + message_type: The type of the message being processed ("standard" or "response"). + payload_type: The data type of the message being processed. + source_trace_contexts: Optional trace contexts from source spans for linking. + source_span_ids: Optional source span IDs for linking. """ # Create links to source spans for causality without nesting links: list[trace.Link] = [] @@ -1608,6 +1618,7 @@ def create_processing_span( OtelAttr.EXECUTOR_ID: executor_id, OtelAttr.EXECUTOR_TYPE: executor_type, OtelAttr.MESSAGE_TYPE: message_type, + OtelAttr.MESSAGE_PAYLOAD_TYPE: payload_type, }, links=links, ) diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py index 51f39599c1..5856a80035 100644 --- a/python/packages/core/tests/workflow/test_workflow_observability.py +++ b/python/packages/core/tests/workflow/test_workflow_observability.py @@ -8,7 +8,7 @@ from agent_framework import InMemoryCheckpointStorage, WorkflowBuilder from agent_framework._workflows._executor import Executor, handler -from agent_framework._workflows._runner_context import InProcRunnerContext, Message +from agent_framework._workflows._runner_context import InProcRunnerContext, Message, MessageType from agent_framework._workflows._shared_state import SharedState from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_context import WorkflowContext @@ -127,7 +127,9 @@ async def 
test_span_creation_and_attributes(span_exporter: InMemorySpanExporter) OtelAttr.MESSAGE_DESTINATION_EXECUTOR_ID: "target-789", } with ( - create_processing_span("executor-456", "TestExecutor", "TestMessage") as processing_span, + create_processing_span( + "executor-456", "TestExecutor", str(MessageType.STANDARD), "TestMessage" + ) as processing_span, create_workflow_span( OtelAttr.MESSAGE_SEND_SPAN, sending_attributes, kind=trace.SpanKind.PRODUCER ) as sending_span, @@ -155,7 +157,8 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter) assert processing_span.attributes is not None assert processing_span.attributes.get("executor.id") == "executor-456" assert processing_span.attributes.get("executor.type") == "TestExecutor" - assert processing_span.attributes.get("message.type") == "TestMessage" + assert processing_span.attributes.get("message.type") == str(MessageType.STANDARD) + assert processing_span.attributes.get("message.payload_type") == "TestMessage" # Check sending span sending_span = next(s for s in spans if s.name == "message.send") @@ -218,7 +221,8 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No assert processing_span.attributes is not None assert processing_span.attributes.get("executor.id") == "test-executor" assert processing_span.attributes.get("executor.type") == "MockExecutor" - assert processing_span.attributes.get("message.type") == "str" + assert processing_span.attributes.get("message.type") == str(MessageType.STANDARD) + assert processing_span.attributes.get("message.payload_type") == "str" @pytest.mark.parametrize("enable_otel", [False], indirect=True) From 476b0c79b790bedd7e94d901cd819de75d992843 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Oct 2025 16:01:34 -0700 Subject: [PATCH 25/26] Update sample --- .../workflows/composition/sub_workflow_parallel_requests.py | 1 - .../workflows/composition/sub_workflow_request_interception.py | 2 +- 2 files changed, 1 
insertion(+), 2 deletions(-) diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py index 4d9db405a7..ca30d87f6f 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_parallel_requests.py @@ -356,7 +356,6 @@ async def main() -> None: if outputs: print("\nWorkflow completed with outputs:") for output in outputs: - # TODO(@taochen): Allow the sub-workflow to output directly print(f"- {output}") else: raise RuntimeError("Workflow did not produce an output.") diff --git a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py index 749e454fa8..769397b972 100644 --- a/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py +++ b/python/samples/getting_started/workflows/composition/sub_workflow_request_interception.py @@ -18,7 +18,7 @@ from typing_extensions import Never """ -This sample demostrates how to handle request from the sub-workflow in the main workflow. +This sample demonstrates how to handle request from the sub-workflow in the main workflow. Prerequisite: - Understanding of sub-workflows. 
From 8cd07703cb53cce897773b80ed0abc225ba322df Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 28 Oct 2025 21:40:30 -0700 Subject: [PATCH 26/26] Improve tests --- .../tests/workflow/test_request_info_mixin.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py index 20e4a7587a..d5528f721d 100644 --- a/python/packages/core/tests/workflow/test_request_info_mixin.py +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -7,7 +7,7 @@ import pytest from agent_framework._workflows._executor import Executor, handler -from agent_framework._workflows._request_info_mixin import RequestInfoMixin, response_handler +from agent_framework._workflows._request_info_mixin import response_handler from agent_framework._workflows._workflow_context import WorkflowContext @@ -17,7 +17,7 @@ class TestRequestInfoMixin: def test_request_info_mixin_initialization(self): """Test that RequestInfoMixin can be initialized.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test") @@ -82,7 +82,7 @@ async def original_handler(self: Any, original_request: str, response: int, ctx: def test_executor_with_response_handlers(self): """Test an executor with valid response handlers.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -116,7 +116,7 @@ async def handle_dict_response( def test_executor_without_response_handlers(self): """Test an executor without response handlers.""" - class PlainExecutor(Executor, RequestInfoMixin): + class PlainExecutor(Executor): def __init__(self): super().__init__(id="plain_executor") @@ -136,7 +136,7 @@ async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: def 
test_duplicate_response_handlers_raise_error(self): """Test that duplicate response handlers for the same message type raise an error.""" - class DuplicateExecutor(Executor, RequestInfoMixin): + class DuplicateExecutor(Executor): def __init__(self): super().__init__(id="duplicate_executor") @@ -161,7 +161,7 @@ async def handle_second(self, original_request: str, response: int, ctx: Workflo def test_response_handler_function_callable(self): """Test that response handlers can actually be called.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") self.handled_request = None @@ -190,7 +190,7 @@ async def handle_response(self, original_request: str, response: int, ctx: Workf def test_inheritance_with_response_handlers(self): """Test that response handlers work correctly with inheritance.""" - class BaseExecutor(Executor, RequestInfoMixin): + class BaseExecutor(Executor): def __init__(self): super().__init__(id="base_executor") @@ -223,7 +223,7 @@ async def child_handler(self, original_request: str, response: bool, ctx: Workfl def test_response_handler_spec_attributes(self): """Test that response handler specs contain expected attributes.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -252,7 +252,7 @@ async def test_handler(self, original_request: str, response: int, ctx: Workflow def test_multiple_discovery_calls_raise_error(self): """Test that multiple calls to _discover_response_handlers raise an error for duplicates.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -282,7 +282,7 @@ async def test_handler(self, original_request: str, response: int, ctx: Workflow def test_non_callable_attributes_ignored(self): """Test that non-callable attributes are ignored during discovery.""" - class 
TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -307,7 +307,7 @@ async def valid_handler(self, original_request: str, response: int, ctx: Workflo def test_same_request_type_different_response_types(self): """Test that handlers with same request type but different response types are distinct.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") self.str_int_handler_called = False @@ -362,7 +362,7 @@ async def handle_str_dict( def test_different_request_types_same_response_type(self): """Test that handlers with different request types but same response type are distinct.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") self.str_int_handler_called = False @@ -425,7 +425,7 @@ class CustomRequest: class CustomResponse: pass - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") self.custom_custom_called = False @@ -478,7 +478,7 @@ async def handle_str_custom( def test_handler_key_uniqueness(self): """Test that handler keys (request_type, response_type) are truly unique.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -522,7 +522,7 @@ async def handle4(self, original_request: int, response: int, ctx: WorkflowConte def test_no_false_matches_with_similar_types(self): """Test that handlers don't match with similar but different types.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -554,7 +554,7 @@ async def handle_list_str_float( def test_is_request_supported_with_exact_matches(self): """Test is_request_supported with exact type 
matches.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -586,7 +586,7 @@ async def handle_dict_bool( def test_is_request_supported_without_handlers(self): """Test is_request_supported when no handlers are registered.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -604,7 +604,7 @@ async def dummy_handler(self, message: str, ctx: WorkflowContext) -> None: def test_is_request_supported_before_discovery(self): """Test is_request_supported before response handlers are discovered.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor", defer_discovery=True) @@ -638,7 +638,7 @@ class BaseResponse: class DerivedResponse(BaseResponse): pass - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -677,7 +677,7 @@ async def handle_str_int(self, original_request: str, response: int, ctx: Workfl def test_is_request_supported_with_multiple_handlers(self): """Test is_request_supported when multiple handlers are registered.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -722,7 +722,7 @@ async def handle_list_float( def test_is_request_supported_with_complex_types(self): """Test is_request_supported with complex generic types.""" - class TestExecutor(Executor, RequestInfoMixin): + class TestExecutor(Executor): def __init__(self): super().__init__(id="test_executor") @@ -756,7 +756,7 @@ async def handle_list_dict( def test_is_request_supported_with_inheritance(self): """Test is_request_supported with inherited response handlers.""" - class BaseExecutor(Executor, RequestInfoMixin): + class BaseExecutor(Executor): def 
__init__(self): super().__init__(id="base_executor")