-
Notifications
You must be signed in to change notification settings - Fork 681
Python: Fix orchestration patterns when using workflow as an agent #1470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,11 +2,13 @@ | |
|
|
||
| import json | ||
| import logging | ||
| import types | ||
| import typing | ||
| import uuid | ||
| from collections.abc import AsyncIterable, Sequence | ||
| from collections.abc import AsyncIterable, Callable, Sequence | ||
| from dataclasses import dataclass | ||
| from datetime import datetime | ||
| from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast | ||
| from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast, get_args, get_origin | ||
|
|
||
| from agent_framework import ( | ||
| AgentRunResponse, | ||
|
|
@@ -29,6 +31,7 @@ | |
| ) | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ._executor import Executor | ||
| from ._workflow import Workflow | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
@@ -91,8 +94,7 @@ def __init__( | |
| except KeyError as exc: # Defensive: workflow lacks a configured entry point | ||
| raise ValueError("Workflow's start executor is not defined.") from exc | ||
|
|
||
| if list[ChatMessage] not in start_executor.input_types: | ||
| raise ValueError("Workflow's start executor cannot handle list[ChatMessage]") | ||
| self._start_payload_encoder = self._resolve_start_payload_encoder(start_executor) | ||
|
|
||
| super().__init__(id=id, name=name, description=description, **kwargs) | ||
| self._workflow: "Workflow" = workflow | ||
|
|
@@ -106,6 +108,143 @@ def workflow(self) -> "Workflow": | |
| def pending_requests(self) -> dict[str, RequestInfoEvent]: | ||
| return self._pending_requests | ||
|
|
||
| def _resolve_start_payload_encoder(self, start_executor: "Executor") -> Callable[[list[ChatMessage]], Any]: | ||
| """Determine how to map agent chat messages to the workflow's start executor input.""" | ||
| probe_conversation = [ChatMessage(role=Role.USER, text="__agent_probe__")] | ||
| if start_executor.can_handle(probe_conversation): | ||
| return lambda messages: list(messages) | ||
|
|
||
| for adapter in self._candidate_adapters_from_input_types(start_executor.input_types): | ||
| try: | ||
| probe_payload = adapter(probe_conversation) | ||
| except ValueError: | ||
| continue | ||
| if start_executor.can_handle(probe_payload): | ||
| return adapter | ||
|
|
||
| raise ValueError("Workflow's start executor cannot be adapted to agent chat inputs.") | ||
|
|
||
| def _candidate_adapters_from_input_types( | ||
| self, | ||
| input_types: Sequence[type[Any]], | ||
| ) -> list[Callable[[list[ChatMessage]], Any]]: | ||
| adapters: list[Callable[[list[ChatMessage]], Any]] = [] | ||
| for annotation in input_types: | ||
| for candidate in self._flatten_type_annotation(annotation): | ||
| adapter = self._adapter_for_concrete_type(candidate) | ||
| if adapter is not None and adapter not in adapters: | ||
| adapters.append(adapter) | ||
| return adapters | ||
|
|
||
| def _flatten_type_annotation(self, annotation: Any) -> list[Any]: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wonder if we need to support the cases when the input executor cannot handle chat message. |
||
| origin = get_origin(annotation) | ||
| if origin is None: | ||
| return [annotation] | ||
|
|
||
| if origin in (types.UnionType, typing.Union): | ||
| flattened: list[Any] = [] | ||
| for arg in get_args(annotation): | ||
| flattened.extend(self._flatten_type_annotation(arg)) | ||
| return flattened | ||
|
|
||
| if origin is typing.Annotated: | ||
| args = get_args(annotation) | ||
| return self._flatten_type_annotation(args[0]) if args else [] | ||
|
|
||
| return [annotation] | ||
|
|
||
| def _adapter_for_concrete_type(self, message_type: Any) -> Callable[[list[ChatMessage]], Any] | None: | ||
| if self._is_chat_message_list_type(message_type): | ||
| return lambda messages: list(messages) | ||
|
|
||
| if self._is_chat_message_type(message_type): | ||
| return self._messages_to_single_chat_message | ||
|
|
||
| if message_type is str: | ||
| return self._messages_to_text | ||
|
|
||
| if isinstance(message_type, type): | ||
| specialized = self._adapter_for_specialized_class(message_type) | ||
| if specialized is not None: | ||
| return specialized | ||
|
|
||
| return None | ||
|
|
||
| def _adapter_for_specialized_class(self, message_cls: type[Any]) -> Callable[[list[ChatMessage]], Any] | None: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: this is probably ok, but I have observed that we have too many lower-level modules taking dependencies on higher level modules. At some point, we should think about some of the design choices we are making to make the code a bit organized. |
||
| try: | ||
| from ._magentic import MagenticStartMessage | ||
| except Exception: # pragma: no cover - optional dependency | ||
moonbox3 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| MagenticStartMessage = None # type: ignore | ||
|
|
||
| if MagenticStartMessage is not None and message_cls is MagenticStartMessage: | ||
| return self._build_magentic_start_adapter(MagenticStartMessage) | ||
|
|
||
| return None | ||
|
|
||
| def _build_magentic_start_adapter( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here too, though it's not taking a dependency on the MagenticBuilder but it's logically related. I don't think this module should know anything about the Magentic orchestration.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I may abandon this PR. I'm working on introducing the base group chat pattern, which the Magentic pattern will extend (as we talked about a while back). There's too much custom "MagenticWorkflow" handling that I don't like as well. I expect we will be able to move away from these custom methods. |
||
| self, | ||
| message_cls: type[Any], | ||
| ) -> Callable[[list[ChatMessage]], Any]: | ||
| def adapter(messages: list[ChatMessage]) -> Any: | ||
| task_message = self._select_user_or_last_message(messages) | ||
| try: | ||
| return message_cls(task=task_message) | ||
| except TypeError as exc: | ||
| raise ValueError("Cannot construct MagenticStartMessage from provided chat messages.") from exc | ||
|
|
||
| return adapter | ||
|
|
||
| def _is_chat_message_list_type(self, annotation: Any) -> bool: | ||
| if annotation is list: | ||
| return True | ||
|
|
||
| origin = get_origin(annotation) | ||
| if origin in (list, Sequence): | ||
| args = get_args(annotation) | ||
| if not args: | ||
| return origin is list | ||
| return all(self._is_chat_message_type(arg) for arg in args) | ||
|
|
||
| return False | ||
|
|
||
| def _is_chat_message_type(self, annotation: Any) -> bool: | ||
| if annotation is ChatMessage: | ||
| return True | ||
| return isinstance(annotation, type) and issubclass(annotation, ChatMessage) | ||
|
|
||
| def _messages_to_single_chat_message(self, messages: list[ChatMessage]) -> ChatMessage: | ||
| return self._select_user_or_last_message(messages) | ||
|
|
||
| def _messages_to_text(self, messages: list[ChatMessage]) -> str: | ||
| message = self._select_user_or_last_message(messages) | ||
| text = message.text.strip() | ||
| if text: | ||
| return text | ||
|
|
||
| fallback_parts: list[str] = [] | ||
| for content in message.contents: | ||
| candidate = getattr(content, "text", None) | ||
| if isinstance(candidate, str) and candidate: | ||
| fallback_parts.append(candidate) | ||
| else: | ||
| rendered = str(content) | ||
| if rendered: | ||
| fallback_parts.append(rendered) | ||
|
|
||
| if fallback_parts: | ||
| return " ".join(fallback_parts) | ||
|
|
||
| raise ValueError("Cannot derive plain-text prompt from chat message contents.") | ||
|
|
||
| def _select_user_or_last_message(self, messages: list[ChatMessage]) -> ChatMessage: | ||
| if not messages: | ||
| raise ValueError("At least one ChatMessage is required to start the workflow.") | ||
|
|
||
| for message in reversed(messages): | ||
| if getattr(message, "role", None) == Role.USER: | ||
| return message | ||
| return messages[-1] | ||
|
|
||
| async def run( | ||
| self, | ||
| messages: str | ChatMessage | list[str] | list[ChatMessage] | None = None, | ||
|
|
@@ -213,8 +352,8 @@ async def _run_stream_impl( | |
| event_stream = self.workflow.send_responses_streaming(function_responses) | ||
| else: | ||
| # Execute workflow with streaming (initial run or no function responses) | ||
| # Pass the new input messages directly to the workflow | ||
| event_stream = self.workflow.run_stream(input_messages) | ||
| start_payload = self._start_payload_encoder(input_messages) | ||
| event_stream = self.workflow.run_stream(start_payload) | ||
|
|
||
| # Process events from the stream | ||
| async for event in event_stream: | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,7 @@ | |
| from collections.abc import AsyncIterable, Awaitable, Callable | ||
| from dataclasses import dataclass, field | ||
| from enum import Enum | ||
| from typing import Any, Literal, Protocol, TypeVar, Union, cast | ||
| from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, Union, cast | ||
| from uuid import uuid4 | ||
|
|
||
| from agent_framework import ( | ||
|
|
@@ -38,6 +38,9 @@ | |
| else: | ||
| from typing_extensions import Self # pragma: no cover | ||
|
|
||
| if TYPE_CHECKING: | ||
| from ._agent import WorkflowAgent | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
| # Consistent author name for messages produced by the Magentic manager/orchestrator | ||
|
|
@@ -2043,6 +2046,10 @@ def workflow(self) -> Workflow: | |
| """Access the underlying workflow.""" | ||
| return self._workflow | ||
|
|
||
| def as_agent(self, name: str | None = None) -> "WorkflowAgent": | ||
| """Expose the underlying workflow as a WorkflowAgent.""" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this doc string we should note the message type conversion behavior rather than making this behavior implicit.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to call out the message type conversion? This is simply creating a "WorkflowAgent". We don't call anything specifically even in the workflow's method: def as_agent(self, name: str | None = None) -> WorkflowAgent:
"""Create a WorkflowAgent that wraps this workflow.
Args:
name: Optional name for the agent. If None, a default name will be generated.
Returns:
A WorkflowAgent instance that wraps this workflow.
"""
# Import here to avoid circular imports
from ._agent import WorkflowAgent
return WorkflowAgent(workflow=self, name=name)
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought user may want to understand how their input messages are being handled under the hood to become input to the first executor/agent? The |
||
| return self._workflow.as_agent(name=name) | ||
|
|
||
| async def run_streaming_with_string(self, task_text: str) -> AsyncIterable[WorkflowEvent]: | ||
| """Run the workflow with a task string. | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I know these are "private" methods, but it'd still be beneficial to have some comments for them. It'd make reading and understanding the code so much easier for others and potential AI. And maybe copilot could detect errors if the method doesn't do what the description says.