Skip to content

Commit e3a2f79

Browse files
authored
Orchestrator Chat and OAI Assistant update (#31)
1 parent ecbc3b7 commit e3a2f79

File tree

4 files changed

+138
-153
lines changed

4 files changed

+138
-153
lines changed

examples/patterns.py

+50-15
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@
33
import logging
44

55
import openai
6-
from agnext.agent_components.model_client import OpenAI
76
from agnext.application_components import (
87
SingleThreadedAgentRuntime,
98
)
109
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
1110
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
12-
from agnext.chat.patterns.orchestrator import Orchestrator
11+
from agnext.chat.patterns.orchestrator_chat import OrchestratorChat
1312
from agnext.chat.types import TextMessage
1413
from agnext.core._agent import Agent
1514
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
@@ -38,20 +37,28 @@ def reset(self) -> None:
3837

3938

4039
class LoggingHandler(DefaultInterventionHandler):
40+
send_color = "\033[31m"
41+
response_color = "\033[34m"
42+
reset_color = "\033[0m"
43+
4144
@override
4245
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
4346
if sender is None:
44-
print(f"Sending message to {recipient.name}: {message}")
47+
print(f"{self.send_color}Sending message to {recipient.name}:{self.reset_color} {message}")
4548
else:
46-
print(f"Sending message from {sender.name} to {recipient.name}: {message}")
49+
print(
50+
f"{self.send_color}Sending message from {sender.name} to {recipient.name}:{self.reset_color} {message}"
51+
)
4752
return message
4853

4954
@override
5055
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
5156
if recipient is None:
52-
print(f"Received response from {sender.name}: {message}")
57+
print(f"{self.response_color}Received response from {sender.name}:{self.reset_color} {message}")
5358
else:
54-
print(f"Received response from {sender.name} to {recipient.name}: {message}")
59+
print(
60+
f"{self.response_color}Received response from {sender.name} to {recipient.name}:{self.reset_color} {message}"
61+
)
5562
return message
5663

5764

@@ -131,19 +138,47 @@ async def orchestrator(message: str) -> None:
131138
thread_id=product_manager_oai_thread.id,
132139
)
133140

134-
chat = Orchestrator(
135-
"Manager",
136-
"A software development team manager.",
137-
runtime,
138-
[developer, product_manager],
139-
model_client=OpenAI(model="gpt-3.5-turbo"),
141+
planner_oai_assistant = openai.beta.assistants.create(
142+
model="gpt-4-turbo",
143+
name="Planner",
144+
instructions="You are a planner of complex tasks.",
145+
)
146+
planner_oai_thread = openai.beta.threads.create()
147+
planner = OpenAIAssistantAgent(
148+
name="Planner",
149+
description="A planner that organizes and schedules tasks.",
150+
runtime=runtime,
151+
client=openai.AsyncClient(),
152+
assistant_id=planner_oai_assistant.id,
153+
thread_id=planner_oai_thread.id,
154+
)
155+
156+
orchestrator_oai_assistant = openai.beta.assistants.create(
157+
model="gpt-4-turbo",
158+
name="Orchestrator",
159+
instructions="You are an orchestrator that coordinates the team to complete a complex task.",
160+
)
161+
orchestrator_oai_thread = openai.beta.threads.create()
162+
orchestrator = OpenAIAssistantAgent(
163+
name="Orchestrator",
164+
description="An orchestrator that coordinates the team.",
165+
runtime=runtime,
166+
client=openai.AsyncClient(),
167+
assistant_id=orchestrator_oai_assistant.id,
168+
thread_id=orchestrator_oai_thread.id,
140169
)
141170

142-
response = runtime.send_message(
143-
TextMessage(content=message, source="customer"),
144-
chat,
171+
chat = OrchestratorChat(
172+
"Orchestrator Chat",
173+
"A software development team.",
174+
runtime,
175+
orchestrator=orchestrator,
176+
planner=planner,
177+
specialists=[developer, product_manager],
145178
)
146179

180+
response = runtime.send_message(TextMessage(content=message, source="Customer"), chat)
181+
147182
while not response.done():
148183
await runtime.process_next()
149184

src/agnext/chat/agents/oai_assistant.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Callable, Dict
1+
from typing import Callable, Dict, List
22

33
import openai
44

@@ -23,8 +23,6 @@ def __init__(
2323
self._client = client
2424
self._assistant_id = assistant_id
2525
self._thread_id = thread_id
26-
# TODO: investigate why this is 1, as setting this to 0 causes the earlest message in the window to be ignored.
27-
self._current_session_window_length = 1
2826
self._tools = tools or {}
2927

3028
@message_handler(TextMessage)
@@ -36,32 +34,40 @@ async def on_text_message(self, message: TextMessage, cancellation_token: Cancel
3634
role="user",
3735
metadata={"sender": message.source},
3836
)
39-
self._current_session_window_length += 1
4037

4138
@message_handler(Reset)
4239
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
43-
# Reset the current session window.
44-
self._current_session_window_length = 1
40+
# Get all messages in this thread.
41+
all_msgs: List[str] = []
42+
while True:
43+
if not all_msgs:
44+
msgs = await self._client.beta.threads.messages.list(self._thread_id)
45+
else:
46+
msgs = await self._client.beta.threads.messages.list(self._thread_id, after=all_msgs[-1])
47+
for msg in msgs.data:
48+
all_msgs.append(msg.id)
49+
if not msgs.has_next_page():
50+
break
51+
# Delete all the messages.
52+
for msg_id in all_msgs:
53+
status = await self._client.beta.threads.messages.delete(message_id=msg_id, thread_id=self._thread_id)
54+
assert status.deleted is True
4555

4656
@message_handler(RespondNow)
4757
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
58+
# Handle response format.
59+
4860
# Create a run and wait until it finishes.
4961
run = await self._client.beta.threads.runs.create_and_poll(
5062
thread_id=self._thread_id,
5163
assistant_id=self._assistant_id,
52-
truncation_strategy={
53-
"type": "last_messages",
54-
"last_messages": self._current_session_window_length,
55-
},
64+
response_format=message.response_format,
5665
)
5766

5867
if run.status != "completed":
5968
# TODO: handle other statuses.
6069
raise ValueError(f"Run did not complete successfully: {run}")
6170

62-
# Increment the current session window length.
63-
self._current_session_window_length += 1
64-
6571
# Get the last message from the run.
6672
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
6773
last_message_content = response.data[0].content

0 commit comments

Comments
 (0)