Skip to content

Commit c3283c6

Browse files
authored
Agentchat refactor (#4062)
* Agentchat refactor * Move termination stop message to a separate field in task result * Update quick start example * Use string stop reason instead of stop message in task result for simpler API * Use main function
1 parent 1098768 commit c3283c6

18 files changed

+284
-412
lines changed

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py

+9-31
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
UserMessage,
1616
)
1717
from autogen_core.components.tools import FunctionTool, Tool
18-
from pydantic import BaseModel, ConfigDict, Field, model_validator
18+
from pydantic import BaseModel, Field, model_validator
1919

2020
from .. import EVENT_LOGGER_NAME
2121
from ..base import Response
@@ -33,30 +33,6 @@
3333
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
3434

3535

36-
class ToolCallEvent(BaseModel):
37-
"""A tool call event."""
38-
39-
source: str
40-
"""The source of the event."""
41-
42-
tool_calls: List[FunctionCall]
43-
"""The tool call message."""
44-
45-
model_config = ConfigDict(arbitrary_types_allowed=True)
46-
47-
48-
class ToolCallResultEvent(BaseModel):
49-
"""A tool call result event."""
50-
51-
source: str
52-
"""The source of the event."""
53-
54-
tool_call_results: List[FunctionExecutionResult]
55-
"""The tool call result message."""
56-
57-
model_config = ConfigDict(arbitrary_types_allowed=True)
58-
59-
6036
class Handoff(BaseModel):
6137
"""Handoff configuration for :class:`AssistantAgent`."""
6238

@@ -264,19 +240,21 @@ async def on_messages_stream(
264240

265241
# Run tool calls until the model produces a string response.
266242
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
267-
event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
243+
tool_call_msg = ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
244+
event_logger.debug(tool_call_msg)
268245
# Add the tool call message to the output.
269-
inner_messages.append(ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage))
270-
yield ToolCallMessage(content=result.content, source=self.name, models_usage=result.usage)
246+
inner_messages.append(tool_call_msg)
247+
yield tool_call_msg
271248

272249
# Execute the tool calls.
273250
results = await asyncio.gather(
274251
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
275252
)
276-
event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
253+
tool_call_result_msg = ToolCallResultMessage(content=results, source=self.name)
254+
event_logger.debug(tool_call_result_msg)
277255
self._model_context.append(FunctionExecutionResultMessage(content=results))
278-
inner_messages.append(ToolCallResultMessage(content=results, source=self.name))
279-
yield ToolCallResultMessage(content=results, source=self.name)
256+
inner_messages.append(tool_call_result_msg)
257+
yield tool_call_result_msg
280258

281259
# Detect handoff requests.
282260
handoffs: List[Handoff] = []

python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from autogen_core.base import CancellationToken
55

66
from ..base import ChatAgent, Response, TaskResult
7-
from ..messages import ChatMessage, InnerMessage, TextMessage
7+
from ..messages import AgentMessage, ChatMessage, InnerMessage, TextMessage
88

99

1010
class BaseChatAgent(ChatAgent, ABC):
@@ -62,7 +62,7 @@ async def run(
6262
cancellation_token = CancellationToken()
6363
first_message = TextMessage(content=task, source="user")
6464
response = await self.on_messages([first_message], cancellation_token)
65-
messages: List[InnerMessage | ChatMessage] = [first_message]
65+
messages: List[AgentMessage] = [first_message]
6666
if response.inner_messages is not None:
6767
messages += response.inner_messages
6868
messages.append(response.chat_message)
@@ -73,14 +73,14 @@ async def run_stream(
7373
task: str,
7474
*,
7575
cancellation_token: CancellationToken | None = None,
76-
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
76+
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
7777
"""Run the agent with the given task and return a stream of messages
7878
and the final task result as the last item in the stream."""
7979
if cancellation_token is None:
8080
cancellation_token = CancellationToken()
8181
first_message = TextMessage(content=task, source="user")
8282
yield first_message
83-
messages: List[InnerMessage | ChatMessage] = [first_message]
83+
messages: List[AgentMessage] = [first_message]
8484
async for message in self.on_messages_stream([first_message], cancellation_token):
8585
if isinstance(message, Response):
8686
yield message.chat_message

python/packages/autogen-agentchat/src/autogen_agentchat/base/_task.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@
33

44
from autogen_core.base import CancellationToken
55

6-
from ..messages import ChatMessage, InnerMessage
6+
from ..messages import AgentMessage
77

88

99
@dataclass
1010
class TaskResult:
1111
"""Result of running a task."""
1212

13-
messages: Sequence[InnerMessage | ChatMessage]
13+
messages: Sequence[AgentMessage]
1414
"""Messages produced by the task."""
1515

16+
stop_reason: str | None = None
17+
"""The reason the task stopped."""
18+
1619

1720
class TaskRunner(Protocol):
1821
"""A task runner."""
@@ -31,7 +34,7 @@ def run_stream(
3134
task: str,
3235
*,
3336
cancellation_token: CancellationToken | None = None,
34-
) -> AsyncGenerator[InnerMessage | ChatMessage | TaskResult, None]:
37+
) -> AsyncGenerator[AgentMessage | TaskResult, None]:
3538
"""Run the task and produces a stream of messages and the final result
3639
:class:`TaskResult` as the last item in the stream."""
3740
...

python/packages/autogen-agentchat/src/autogen_agentchat/base/_termination.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from abc import ABC, abstractmethod
33
from typing import List, Sequence
44

5-
from ..messages import ChatMessage, StopMessage
5+
from ..messages import AgentMessage, StopMessage
66

77

88
class TerminatedException(BaseException): ...
@@ -50,7 +50,7 @@ def terminated(self) -> bool:
5050
...
5151

5252
@abstractmethod
53-
async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
53+
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
5454
"""Check if the conversation should be terminated based on the messages received
5555
since the last time the condition was called.
5656
Return a StopMessage if the conversation should be terminated, or None otherwise.
@@ -88,7 +88,7 @@ def __init__(self, *conditions: TerminationCondition) -> None:
8888
def terminated(self) -> bool:
8989
return all(condition.terminated for condition in self._conditions)
9090

91-
async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
91+
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
9292
if self.terminated:
9393
raise TerminatedException("Termination condition has already been reached.")
9494
# Check all remaining conditions.
@@ -120,7 +120,7 @@ def __init__(self, *conditions: TerminationCondition) -> None:
120120
def terminated(self) -> bool:
121121
return any(condition.terminated for condition in self._conditions)
122122

123-
async def __call__(self, messages: Sequence[ChatMessage]) -> StopMessage | None:
123+
async def __call__(self, messages: Sequence[AgentMessage]) -> StopMessage | None:
124124
if self.terminated:
125125
raise RuntimeError("Termination condition has already been reached")
126126
stop_messages = await asyncio.gather(*[condition(messages) for condition in self._conditions])

python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py

+9-53
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,18 @@
33
import sys
44
from datetime import datetime
55

6-
from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
7-
from ..messages import ChatMessage, StopMessage, TextMessage
8-
from ..teams._events import (
9-
GroupChatPublishEvent,
10-
GroupChatSelectSpeakerEvent,
11-
TerminationEvent,
12-
)
6+
from pydantic import BaseModel
137

148

159
class ConsoleLogHandler(logging.Handler):
16-
@staticmethod
17-
def serialize_chat_message(message: ChatMessage) -> str:
18-
if isinstance(message, TextMessage | StopMessage):
19-
return message.content
20-
else:
21-
d = message.model_dump()
22-
assert "content" in d
23-
return json.dumps(d["content"], indent=2)
24-
2510
def emit(self, record: logging.LogRecord) -> None:
2611
ts = datetime.fromtimestamp(record.created).isoformat()
27-
if isinstance(record.msg, GroupChatPublishEvent):
28-
if record.msg.source is None:
29-
sys.stdout.write(
30-
f"\n{'-'*75} \n"
31-
f"\033[91m[{ts}]:\033[0m\n"
32-
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
33-
)
34-
else:
35-
sys.stdout.write(
36-
f"\n{'-'*75} \n"
37-
f"\033[91m[{ts}], {record.msg.source.type}:\033[0m\n"
38-
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
39-
)
40-
sys.stdout.flush()
41-
elif isinstance(record.msg, ToolCallEvent):
42-
sys.stdout.write(
43-
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}"
44-
)
45-
sys.stdout.flush()
46-
elif isinstance(record.msg, ToolCallResultEvent):
47-
sys.stdout.write(
48-
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}"
49-
)
50-
sys.stdout.flush()
51-
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
52-
sys.stdout.write(
53-
f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}"
54-
)
55-
sys.stdout.flush()
56-
elif isinstance(record.msg, TerminationEvent):
57-
sys.stdout.write(
58-
f"\n{'-'*75} \n"
59-
f"\033[91m[{ts}], Termination:\033[0m\n"
60-
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
12+
if isinstance(record.msg, BaseModel):
13+
record.msg = json.dumps(
14+
{
15+
"timestamp": ts,
16+
"message": record.msg.model_dump_json(indent=2),
17+
"type": record.msg.__class__.__name__,
18+
},
6119
)
62-
sys.stdout.flush()
63-
else:
64-
raise ValueError(f"Unexpected log record: {record.msg}")
20+
sys.stdout.write(f"{record.msg}\n")
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,8 @@
11
import json
22
import logging
3-
from dataclasses import asdict, is_dataclass
43
from datetime import datetime
5-
from typing import Any
64

7-
from ..agents._assistant_agent import ToolCallEvent, ToolCallResultEvent
8-
from ..teams._events import (
9-
GroupChatPublishEvent,
10-
GroupChatSelectSpeakerEvent,
11-
TerminationEvent,
12-
)
5+
from pydantic import BaseModel
136

147

158
class FileLogHandler(logging.Handler):
@@ -20,65 +13,12 @@ def __init__(self, filename: str) -> None:
2013

2114
def emit(self, record: logging.LogRecord) -> None:
2215
ts = datetime.fromtimestamp(record.created).isoformat()
23-
if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent):
24-
log_entry = json.dumps(
16+
if isinstance(record.msg, BaseModel):
17+
record.msg = json.dumps(
2518
{
2619
"timestamp": ts,
27-
"source": record.msg.source,
28-
"agent_message": record.msg.agent_message.model_dump(),
20+
"message": record.msg.model_dump(),
2921
"type": record.msg.__class__.__name__,
3022
},
31-
default=self.json_serializer,
3223
)
33-
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
34-
log_entry = json.dumps(
35-
{
36-
"timestamp": ts,
37-
"source": record.msg.source,
38-
"selected_speaker": record.msg.selected_speaker,
39-
"type": "SelectSpeakerEvent",
40-
},
41-
default=self.json_serializer,
42-
)
43-
elif isinstance(record.msg, ToolCallEvent):
44-
log_entry = json.dumps(
45-
{
46-
"timestamp": ts,
47-
"tool_calls": record.msg.model_dump(),
48-
"type": "ToolCallEvent",
49-
},
50-
default=self.json_serializer,
51-
)
52-
elif isinstance(record.msg, ToolCallResultEvent):
53-
log_entry = json.dumps(
54-
{
55-
"timestamp": ts,
56-
"tool_call_results": record.msg.model_dump(),
57-
"type": "ToolCallResultEvent",
58-
},
59-
default=self.json_serializer,
60-
)
61-
else:
62-
raise ValueError(f"Unexpected log record: {record.msg}")
63-
file_record = logging.LogRecord(
64-
name=record.name,
65-
level=record.levelno,
66-
pathname=record.pathname,
67-
lineno=record.lineno,
68-
msg=log_entry,
69-
args=(),
70-
exc_info=record.exc_info,
71-
)
72-
self.file_handler.emit(file_record)
73-
74-
def close(self) -> None:
75-
self.file_handler.close()
76-
super().close()
77-
78-
@staticmethod
79-
def json_serializer(obj: Any) -> Any:
80-
if is_dataclass(obj) and not isinstance(obj, type):
81-
return asdict(obj)
82-
elif isinstance(obj, type):
83-
return str(obj)
84-
return str(obj)
24+
self.file_handler.emit(record)

python/packages/autogen-agentchat/src/autogen_agentchat/messages.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from autogen_core.components import FunctionCall, Image
44
from autogen_core.components.models import FunctionExecutionResult, RequestUsage
5-
from pydantic import BaseModel
5+
from pydantic import BaseModel, ConfigDict
66

77

88
class BaseMessage(BaseModel):
@@ -14,6 +14,8 @@ class BaseMessage(BaseModel):
1414
models_usage: RequestUsage | None = None
1515
"""The model client usage incurred when producing this message."""
1616

17+
model_config = ConfigDict(arbitrary_types_allowed=True)
18+
1719

1820
class TextMessage(BaseMessage):
1921
"""A text message."""
@@ -75,6 +77,10 @@ class ToolCallResultMessage(BaseMessage):
7577
"""Messages for agent-to-agent communication."""
7678

7779

80+
AgentMessage = InnerMessage | ChatMessage
81+
"""All message types."""
82+
83+
7884
__all__ = [
7985
"BaseMessage",
8086
"TextMessage",
@@ -85,4 +91,6 @@ class ToolCallResultMessage(BaseMessage):
8591
"ToolCallMessage",
8692
"ToolCallResultMessage",
8793
"ChatMessage",
94+
"InnerMessage",
95+
"AgentMessage",
8896
]

0 commit comments

Comments
 (0)