Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, Dict, List, Sequence
from typing import Any, AsyncGenerator, Awaitable, Callable, Dict, List, Sequence

from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
Expand All @@ -27,7 +27,7 @@
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessages,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseChatAgent

Expand Down Expand Up @@ -98,7 +98,11 @@ def set_defaults(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@property
def handoff_tool(self) -> Tool:
    """Create a handoff tool from this handoff configuration.

    A named inner function is used instead of a lambda so the resulting
    tool function carries an explicit ``-> str`` return annotation
    (presumably used by ``FunctionTool`` when building the tool schema —
    confirm against ``FunctionTool``'s implementation).
    """

    def _handoff_tool() -> str:
        # Returns the handoff message configured on this handoff.
        return self.message

    return FunctionTool(_handoff_tool, name=self.name, description=self.description)


class AssistantAgent(BaseChatAgent):
Expand Down Expand Up @@ -138,7 +142,7 @@ class AssistantAgent(BaseChatAgent):


The following example demonstrates how to create an assistant agent with
a model client and a tool, and generate a response to a simple task using the tool.
a model client and a tool, and generate a stream of messages for a task.

.. code-block:: python

Expand All @@ -154,7 +158,11 @@ async def get_current_time() -> str:
model_client = OpenAIChatCompletionClient(model="gpt-4o")
agent = AssistantAgent(name="assistant", model_client=model_client, tools=[get_current_time])

await agent.run("What is the current time?", termination_condition=MaxMessageTermination(3))
stream = agent.run_stream("What is the current time?", termination_condition=MaxMessageTermination(3))

async for message in stream:
print(message)


"""

Expand Down Expand Up @@ -219,6 +227,14 @@ def produced_message_types(self) -> List[type[ChatMessage]]:
return [TextMessage, StopMessage]

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
    """Handle incoming messages by draining :meth:`on_messages_stream`
    and returning the final :class:`Response` item it yields."""
    stream = self.on_messages_stream(messages, cancellation_token)
    async for item in stream:
        if not isinstance(item, Response):
            # Intermediate inner messages are ignored here; only the
            # terminal Response is returned to the caller.
            continue
        return item
    raise AssertionError("The stream should have returned the final result.")

async def on_messages_stream(
self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ResetMessage):
Expand All @@ -243,14 +259,16 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
# Add the tool call message to the output.
inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
yield ToolCallMessage(content=result.content, source=self.name)

# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
self._model_context.append(FunctionExecutionResultMessage(content=results))
inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
yield ToolCallResultMessage(content=results, source=self.name)

# Detect handoff requests.
handoffs: List[Handoff] = []
Expand All @@ -261,12 +279,13 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
if len(handoffs) > 1:
raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
# Return the output messages to signal the handoff.
return Response(
yield Response(
chat_message=HandoffMessage(
content=handoffs[0].message, target=handoffs[0].target, source=self.name
),
inner_messages=inner_messages,
)
return

# Generate an inference result based on the current model context.
result = await self._model_client.create(
Expand All @@ -278,13 +297,13 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
# Detect stop request.
request_stop = "terminate" in result.content.strip().lower()
if request_stop:
return Response(
yield Response(
chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
)

return Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)
else:
yield Response(
chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
)

async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Sequence
from typing import AsyncGenerator, List, Sequence

from autogen_core.base import CancellationToken

from ..base import ChatAgent, Response, TaskResult, TerminationCondition
from ..base import ChatAgent, Response, TaskResult
from ..messages import ChatMessage, InnerMessage, TextMessage


Expand Down Expand Up @@ -40,12 +40,22 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
"""Handles incoming messages and returns a response."""
...

async def on_messages_stream(
    self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
    """Handles incoming messages and returns a stream of messages where
    the final item is the response. The base implementation in :class:`BaseChatAgent`
    simply calls :meth:`on_messages` and yields the messages in the response."""
    response = await self.on_messages(messages, cancellation_token)
    # inner_messages may be None; `or []` keeps the loop safe in that case.
    for inner_message in response.inner_messages or []:
        yield inner_message
    # The Response itself is always the last item yielded.
    yield response

async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
if cancellation_token is None:
Expand All @@ -57,3 +67,25 @@ async def run(
messages += response.inner_messages
messages.append(response.chat_message)
return TaskResult(messages=messages)

async def run_stream(
    self,
    task: str,
    *,
    cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
    """Run the agent with the given task and return a stream of messages
    and the final task result as the last item in the stream.

    The task is wrapped in a user :class:`TextMessage`, yielded first,
    then forwarded to :meth:`on_messages_stream`.
    """
    if cancellation_token is None:
        cancellation_token = CancellationToken()
    first_message = TextMessage(content=task, source="user")
    yield first_message
    messages: List[InnerMessage | ChatMessage] = [first_message]
    async for message in self.on_messages_stream([first_message], cancellation_token):
        if isinstance(message, Response):
            yield message.chat_message
            messages.append(message.chat_message)
            # The Response marks the end of the agent's turn: emit the
            # accumulated transcript as the final TaskResult item.
            yield TaskResult(messages=messages)
        else:
            messages.append(message)
            yield message
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from dataclasses import dataclass
from typing import List, Protocol, Sequence, runtime_checkable
from typing import AsyncGenerator, List, Protocol, Sequence, runtime_checkable

from autogen_core.base import CancellationToken

from ..messages import ChatMessage, InnerMessage
from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
from ._task import TaskRunner


@dataclass(kw_only=True)
Expand Down Expand Up @@ -45,12 +44,9 @@ async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token:
"""Handles incoming messages and returns a response."""
...

async def run(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the agent with the given task and return the result."""
def on_messages_stream(
    self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken
) -> AsyncGenerator[InnerMessage | Response, None]:
    """Handles incoming messages and returns a stream of inner messages
    where the final item is the response."""
    ...
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from dataclasses import dataclass
from typing import Protocol, Sequence
from typing import AsyncGenerator, Protocol, Sequence

from autogen_core.base import CancellationToken

from ..messages import ChatMessage, InnerMessage
from ._termination import TerminationCondition


@dataclass
Expand All @@ -23,7 +22,16 @@ async def run(
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the task."""
...

def run_stream(
    self,
    task: str,
    *,
    cancellation_token: CancellationToken | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
    """Run the task and produce a stream of messages, with the final result
    as the last item in the stream."""
    ...
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
from typing import Protocol

from autogen_core.base import CancellationToken

from ._task import TaskResult, TaskRunner
from ._termination import TerminationCondition
from ._task import TaskRunner


class Team(TaskRunner, Protocol):
    """Protocol for a team of agents.

    All task-running behavior (``run`` / ``run_stream``) is inherited from
    :class:`TaskRunner`; this protocol adds no members of its own.
    """
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ class ToolCallMessage(BaseMessage):
"""The tool calls."""


class ToolCallResultMessage(BaseMessage):
    """A message signaling the results of tool calls."""

    content: List[FunctionExecutionResult]
    """The tool call results."""


InnerMessage = ToolCallMessage | ToolCallResultMessage
"""Messages for intra-agent monologues."""
"""Messages for intra-agent monologues."""


Expand All @@ -80,6 +80,6 @@ class ToolCallResultMessages(BaseMessage):
"HandoffMessage",
"ResetMessage",
"ToolCallMessage",
"ToolCallResultMessages",
"ToolCallResultMessage",
"ChatMessage",
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import uuid
from abc import ABC, abstractmethod
from typing import Callable, List
from typing import AsyncGenerator, Callable, List

from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import (
Expand Down Expand Up @@ -75,9 +76,24 @@ async def run(
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> TaskResult:
"""Run the team and return the result."""
# Create intervention handler for termination.

"""Run the team and return the result. The base implementation uses
:meth:`run_stream` to run the team and then returns the final result."""
async for message in self.run_stream(
task, cancellation_token=cancellation_token, termination_condition=termination_condition
):
if isinstance(message, TaskResult):
return message
raise AssertionError("The stream should have returned the final result.")

async def run_stream(
self,
task: str,
*,
cancellation_token: CancellationToken | None = None,
termination_condition: TerminationCondition | None = None,
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
"""Run the team and produces a stream of messages and the final result
as the last item in the stream."""
# Create the runtime.
runtime = SingleThreadedAgentRuntime()

Expand Down Expand Up @@ -132,6 +148,7 @@ async def run(
)

output_messages: List[InnerMessage | ChatMessage] = []
output_message_queue: asyncio.Queue[InnerMessage | ChatMessage | None] = asyncio.Queue()

async def collect_output_messages(
_runtime: AgentRuntime,
Expand All @@ -140,6 +157,7 @@ async def collect_output_messages(
ctx: MessageContext,
) -> None:
output_messages.append(message)
await output_message_queue.put(message)

await ClosureAgent.register(
runtime,
Expand All @@ -158,14 +176,29 @@ async def collect_output_messages(
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
first_chat_message = TextMessage(content=task, source="user")
output_messages.append(first_chat_message)
await output_message_queue.put(first_chat_message)
await runtime.publish_message(
GroupChatPublishEvent(agent_message=first_chat_message),
topic_id=team_topic_id,
)
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)

# Wait for the runtime to stop.
await runtime.stop_when_idle()
# Start a coroutine to stop the runtime and signal the output message queue is complete.
async def stop_runtime() -> None:
await runtime.stop_when_idle()
await output_message_queue.put(None)

shutdown_task = asyncio.create_task(stop_runtime())

# Yield the messsages until the queue is empty.
while True:
message = await output_message_queue.get()
if message is None:
break
yield message

# Wait for the shutdown task to finish.
await shutdown_task

# Return the result.
return TaskResult(messages=output_messages)
# Yield the final result.
yield TaskResult(messages=output_messages)
Loading
Loading