Skip to content

Commit 00ffb37

Browse files
authored
Update group chat and message types (#20)
* Update group chat and message types * fix type based router
1 parent ce58c5b commit 00ffb37

File tree

9 files changed

+121
-87
lines changed

9 files changed

+121
-87
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,4 @@ cython_debug/
162162
.ruff_cache/
163163

164164
/docs/src/reference
165+
.DS_Store

examples/patterns.py

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import asyncio
3+
from typing import Any
34

45
import openai
56
from agnext.agent_components.models_clients.openai_client import OpenAI
@@ -8,8 +9,27 @@
89
)
910
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
1011
from agnext.chat.messages import ChatMessage
11-
from agnext.chat.patterns.group_chat import GroupChat
12+
from agnext.chat.patterns.group_chat import GroupChat, Output
1213
from agnext.chat.patterns.orchestrator import Orchestrator
14+
from agnext.chat.types import TextMessage
15+
16+
17+
class ConcatOutput(Output):
18+
def __init__(self) -> None:
19+
self._output = ""
20+
21+
def on_message_received(self, message: Any) -> None:
22+
match message:
23+
case TextMessage(content=content):
24+
self._output += content
25+
case _:
26+
...
27+
28+
def get_output(self) -> Any:
29+
return self._output
30+
31+
def reset(self) -> None:
32+
self._output = ""
1333

1434

1535
async def group_chat(message: str) -> None:
@@ -45,13 +65,7 @@ async def group_chat(message: str) -> None:
4565
thread_id=cathy_oai_thread.id,
4666
)
4767

48-
chat = GroupChat(
49-
"Host",
50-
"A round-robin chat room.",
51-
runtime,
52-
[joe, cathy],
53-
num_rounds=5,
54-
)
68+
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
5569

5670
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)
5771

src/agnext/agent_components/type_routed_agent.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,16 @@
1414

1515
# NOTE: this works on concrete types and not inheritance
1616
def message_handler(
17-
target_type: Type[ReceivesT],
17+
*target_types: Type[ReceivesT],
1818
) -> Callable[
1919
[Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]],
2020
Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
2121
]:
2222
def decorator(
2323
func: Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]],
2424
) -> Callable[[Any, ReceivesT, bool, CancellationToken], Coroutine[Any, Any, ProducesT | None]]:
25-
func._target_type = target_type # type: ignore
25+
# Convert target_types to list and stash
26+
func._target_types = list(target_types) # type: ignore
2627
return func
2728

2829
return decorator
@@ -40,8 +41,9 @@ def __init__(self, name: str, router: AgentRuntime) -> None:
4041
for attr in dir(self):
4142
if callable(getattr(self, attr, None)):
4243
handler = getattr(self, attr)
43-
if hasattr(handler, "_target_type"):
44-
self._handlers[handler._target_type] = handler
44+
if hasattr(handler, "_target_types"):
45+
for target_type in handler._target_types:
46+
self._handlers[target_type] = handler
4547

4648
@property
4749
def subscriptions(self) -> Sequence[Type[Any]]:
@@ -60,4 +62,4 @@ async def on_message(
6062
async def on_unhandled_message(
6163
self, message: Any, require_response: bool, cancellation_token: CancellationToken
6264
) -> NoReturn:
63-
raise CantHandleException()
65+
raise CantHandleException(f"Unhandled message: {message}")

src/agnext/chat/agents/base.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
from agnext.core.agent_runtime import AgentRuntime
2+
from agnext.core.base_agent import BaseAgent
23

3-
from ...agent_components.type_routed_agent import TypeRoutedAgent
44

5-
6-
class BaseChatAgent(TypeRoutedAgent):
5+
class BaseChatAgent(BaseAgent):
76
"""The BaseAgent class for the chat API."""
87

98
def __init__(self, name: str, description: str, runtime: AgentRuntime) -> None:

src/agnext/chat/agents/oai_assistant.py

+25-17
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import openai
22

3-
from agnext.agent_components.type_routed_agent import message_handler
3+
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
44
from agnext.chat.agents.base import BaseChatAgent
5+
from agnext.chat.types import Reset, RespondNow, TextMessage
56
from agnext.core.agent_runtime import AgentRuntime
67
from agnext.core.cancellation_token import CancellationToken
78

8-
from ..messages import ChatMessage
99

10-
11-
class OpenAIAssistantAgent(BaseChatAgent):
10+
class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
1211
def __init__(
1312
self,
1413
name: str,
@@ -25,29 +24,38 @@ def __init__(
2524
self._current_session_window_length = 0
2625

2726
# TODO: use require_response
28-
@message_handler(ChatMessage)
27+
@message_handler(TextMessage)
2928
async def on_chat_message_with_cancellation(
30-
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
31-
) -> ChatMessage | None:
29+
self, message: TextMessage, require_response: bool, cancellation_token: CancellationToken
30+
) -> None:
3231
print("---------------")
33-
print(f"{self.name} received message from {message.sender}: {message.body}")
32+
print(f"{self.name} received message from {message.source}: {message.content}")
3433
print("---------------")
35-
if message.reset:
36-
# Reset the current session window.
37-
self._current_session_window_length = 0
3834

3935
# Save the message to the thread.
4036
_ = await self._client.beta.threads.messages.create(
4137
thread_id=self._thread_id,
42-
content=message.body,
38+
content=message.content,
4339
role="user",
44-
metadata={"sender": message.sender},
40+
metadata={"sender": message.source},
4541
)
4642
self._current_session_window_length += 1
4743

48-
# If the message is a save_message_only message, return early.
49-
if message.save_message_only:
50-
return ChatMessage(body="OK", sender=self.name)
44+
if require_response:
45+
# TODO ?
46+
...
47+
48+
@message_handler(Reset)
49+
async def on_reset(self, message: Reset, require_response: bool, cancellation_token: CancellationToken) -> None:
50+
# Reset the current session window.
51+
self._current_session_window_length = 0
52+
53+
@message_handler(RespondNow)
54+
async def on_respond_now(
55+
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
56+
) -> TextMessage | None:
57+
if not require_response:
58+
return None
5159

5260
# Create a run and wait until it finishes.
5361
run = await self._client.beta.threads.runs.create_and_poll(
@@ -73,4 +81,4 @@ async def on_chat_message_with_cancellation(
7381
raise ValueError(f"Expected text content in the last message: {last_message_content}")
7482

7583
# TODO: handle multiple text content.
76-
return ChatMessage(body=text_content[0].text.value, sender=self.name)
84+
return TextMessage(content=text_content[0].text.value, source=self.name)

src/agnext/chat/agents/random_agent.py

+7-10
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
11
import random
22

3-
from agnext.agent_components.type_routed_agent import message_handler
3+
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
4+
from agnext.chat.types import RespondNow, TextMessage
45
from agnext.core.cancellation_token import CancellationToken
56

67
from ..agents.base import BaseChatAgent
7-
from ..messages import ChatMessage
88

99

10-
class RandomResponseAgent(BaseChatAgent):
10+
class RandomResponseAgent(BaseChatAgent, TypeRoutedAgent):
1111
# TODO: use require_response
12-
@message_handler(ChatMessage)
12+
@message_handler(RespondNow)
1313
async def on_chat_message_with_cancellation(
14-
self, message: ChatMessage, require_response: bool, cancellation_token: CancellationToken
15-
) -> ChatMessage | None:
16-
print(f"{self.name} received message from {message.sender}: {message.body}")
17-
if message.save_message_only:
18-
return ChatMessage(body="OK", sender=self.name)
14+
self, message: RespondNow, require_response: bool, cancellation_token: CancellationToken
15+
) -> TextMessage:
1916
# Generate a random response.
2017
response_body = random.choice(
2118
[
@@ -36,4 +33,4 @@ async def on_chat_message_with_cancellation(
3633
"See you!",
3734
]
3835
)
39-
return ChatMessage(body=response_body, sender=self.name)
36+
return TextMessage(content=response_body, source=self.name)
+49-42
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,18 @@
1-
from typing import List, Sequence
1+
from typing import Any, List, Protocol, Sequence
2+
3+
from agnext.chat.types import Reset, RespondNow
24

3-
from ...agent_components.type_routed_agent import message_handler
45
from ...core.agent_runtime import AgentRuntime
56
from ...core.cancellation_token import CancellationToken
67
from ..agents.base import BaseChatAgent
7-
from ..messages import ChatMessage
8+
9+
10+
class Output(Protocol):
11+
def on_message_received(self, message: Any) -> None: ...
12+
13+
def get_output(self) -> Any: ...
14+
15+
def reset(self) -> None: ...
816

917

1018
class GroupChat(BaseChatAgent):
@@ -15,70 +23,69 @@ def __init__(
1523
runtime: AgentRuntime,
1624
agents: Sequence[BaseChatAgent],
1725
num_rounds: int,
26+
output: Output,
1827
) -> None:
1928
super().__init__(name, description, runtime)
2029
self._agents = agents
2130
self._num_rounds = num_rounds
22-
self._history: List[ChatMessage] = []
31+
self._history: List[Any] = []
32+
self._output = output
2333

24-
@message_handler(ChatMessage)
25-
async def on_chat_message(
26-
self,
27-
message: ChatMessage,
28-
require_response: bool,
29-
cancellation_token: CancellationToken,
30-
) -> ChatMessage | None:
31-
if message.reset:
34+
@property
35+
def subscriptions(self) -> Sequence[type]:
36+
agent_sublists = [agent.subscriptions for agent in self._agents]
37+
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
38+
39+
async def on_message(
40+
self, message: Any, require_response: bool, cancellation_token: CancellationToken
41+
) -> Any | None:
42+
if isinstance(message, Reset):
3243
# Reset the history.
3344
self._history = []
34-
if message.save_message_only:
35-
# TODO: what should we do with save_message_only messages for this pattern?
36-
return ChatMessage(body="OK", sender=self.name)
45+
# TODO: reset sub-agents?
46+
47+
if isinstance(message, RespondNow):
48+
# TODO reset...
49+
return self._output.get_output()
50+
51+
# TODO: should we do nothing here?
52+
# Perhaps it should be saved into the message history?
53+
if not require_response:
54+
return None
3755

3856
self._history.append(message)
39-
previous_speaker: BaseChatAgent | None = None
4057
round = 0
4158

4259
while round < self._num_rounds:
4360
# TODO: add support for advanced speaker selection.
4461
# Select speaker (round-robin for now).
4562
speaker = self._agents[round % len(self._agents)]
4663

47-
# Send the last message to non-speaking agents.
48-
for agent in [agent for agent in self._agents if agent is not previous_speaker and agent is not speaker]:
64+
# Send the last message to all agents.
65+
for agent in [agent for agent in self._agents]:
66+
# TODO gather and await
4967
_ = await self._send_message(
50-
ChatMessage(
51-
body=self._history[-1].body,
52-
sender=self._history[-1].sender,
53-
save_message_only=True,
54-
),
68+
self._history[-1],
5569
agent,
70+
require_response=False,
71+
cancellation_token=cancellation_token,
5672
)
5773

58-
# Send the last message to the speaking agent and ask to speak.
59-
if previous_speaker is not speaker:
60-
response = await self._send_message(
61-
ChatMessage(body=self._history[-1].body, sender=self._history[-1].sender),
62-
speaker,
63-
)
64-
else:
65-
# The same speaker is speaking again.
66-
# TODO: should support a separate message type for request to speak only.
67-
response = await self._send_message(
68-
ChatMessage(body="", sender=self.name),
69-
speaker,
70-
)
74+
response = await self._send_message(
75+
RespondNow(),
76+
speaker,
77+
require_response=True,
78+
cancellation_token=cancellation_token,
79+
)
7180

7281
if response is not None:
7382
# 4. Append the response to the history.
7483
self._history.append(response)
75-
76-
# 5. Update the previous speaker.
77-
previous_speaker = speaker
84+
self._output.on_message_received(response)
7885

7986
# 6. Increment the round.
8087
round += 1
8188

82-
# Construct the final response.
83-
response_body = "\n".join([f"{message.sender}: {message.body}" for message in self._history])
84-
return ChatMessage(body=response_body, sender=self.name)
89+
output = self._output.get_output()
90+
self._output.reset()
91+
return output

src/agnext/chat/patterns/orchestrator.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@
22
from typing import Any, List, Sequence, Tuple
33

44
from ...agent_components.model_client import ModelClient
5-
from ...agent_components.type_routed_agent import message_handler
5+
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
66
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
77
from ...core.agent_runtime import AgentRuntime
88
from ...core.cancellation_token import CancellationToken
99
from ..agents.base import BaseChatAgent
1010
from ..messages import ChatMessage
1111

1212

13-
class Orchestrator(BaseChatAgent):
13+
class Orchestrator(BaseChatAgent, TypeRoutedAgent):
1414
def __init__(
1515
self,
1616
name: str,

src/agnext/chat/types.py

+6
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,9 @@ class FunctionExecutionResultMessage(BaseMessage):
4040

4141

4242
Message = Union[TextMessage, MultiModalMessage, FunctionCallMessage, FunctionExecutionResultMessage]
43+
44+
45+
class RespondNow: ...
46+
47+
48+
class Reset: ...

0 commit comments

Comments
 (0)