Skip to content

Commit 3712566

Browse files
davorrunjesonichi
authored andcommitted
Add documentation and raise exception when registering async reply function in sync chat (#1208)
* documentation update and added tests for register_reply function * added raising an exception on an async reply function in sync chat * big fixing * test expanded * Update autogen/agentchat/conversable_agent.py Co-authored-by: Chi Wang <[email protected]> * Update autogen/agentchat/conversable_agent.py Co-authored-by: Chi Wang <[email protected]> * refactorization --------- Co-authored-by: Chi Wang <[email protected]>
1 parent baf8da1 commit 3712566

File tree

3 files changed

+180
-6
lines changed

3 files changed

+180
-6
lines changed

autogen/agentchat/conversable_agent.py

+50-4
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,21 @@ def __init__(
141141
)
142142
self._default_auto_reply = default_auto_reply
143143
self._reply_func_list = []
144+
self._ignore_async_func_in_sync_chat_list = []
144145
self.reply_at_receive = defaultdict(bool)
145146
self.register_reply([Agent, None], ConversableAgent.generate_oai_reply)
146-
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply)
147+
self.register_reply([Agent, None], ConversableAgent.a_generate_oai_reply, ignore_async_in_sync_chat=True)
147148
self.register_reply([Agent, None], ConversableAgent.generate_code_execution_reply)
148149
self.register_reply([Agent, None], ConversableAgent.generate_tool_calls_reply)
149-
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply)
150+
self.register_reply([Agent, None], ConversableAgent.a_generate_tool_calls_reply, ignore_async_in_sync_chat=True)
150151
self.register_reply([Agent, None], ConversableAgent.generate_function_call_reply)
151-
self.register_reply([Agent, None], ConversableAgent.a_generate_function_call_reply)
152+
self.register_reply(
153+
[Agent, None], ConversableAgent.a_generate_function_call_reply, ignore_async_in_sync_chat=True
154+
)
152155
self.register_reply([Agent, None], ConversableAgent.check_termination_and_human_reply)
153-
self.register_reply([Agent, None], ConversableAgent.a_check_termination_and_human_reply)
156+
self.register_reply(
157+
[Agent, None], ConversableAgent.a_check_termination_and_human_reply, ignore_async_in_sync_chat=True
158+
)
154159

155160
# Registered hooks are kept in lists, indexed by hookable method, to be called in their order of registration.
156161
# New hookable methods should be added to this list as required to support new agent capabilities.
@@ -163,13 +168,22 @@ def register_reply(
163168
position: int = 0,
164169
config: Optional[Any] = None,
165170
reset_config: Optional[Callable] = None,
171+
*,
172+
ignore_async_in_sync_chat: bool = False,
166173
):
167174
"""Register a reply function.
168175
169176
The reply function will be called when the trigger matches the sender.
170177
The function registered later will be checked earlier by default.
171178
To change the order, set the position to a positive integer.
172179
180+
Both sync and async reply functions can be registered. The sync reply function will be triggered
181+
from both sync and async chats. However, an async reply function will only be triggered from async
182+
chats (initiated with `ConversableAgent.a_initiate_chat`). If an `async` reply function is registered
183+
and a chat is initialized with a sync function, `ignore_async_in_sync_chat` determines the behaviour as follows:
184+
- if `ignore_async_in_sync_chat` is set to `False` (default value), an exception will be raised, and
185+
- if `ignore_async_in_sync_chat` is set to `True`, the reply function will be ignored.
186+
173187
Args:
174188
trigger (Agent class, str, Agent instance, callable, or list): the trigger.
175189
- If a class is provided, the reply function will be called when the sender is an instance of the class.
@@ -181,6 +195,12 @@ def register_reply(
181195
Note: Be sure to register `None` as a trigger if you would like to trigger an auto-reply function with non-empty messages and `sender=None`.
182196
reply_func (Callable): the reply function.
183197
The function takes a recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
198+
position: the position of the reply function in the reply function list.
199+
config: the config to be passed to the reply function, see below.
200+
reset_config: the function to reset the config, see below.
201+
ignore_async_in_sync_chat: whether to ignore the async reply function in sync chats. If `False`, an exception
202+
will be raised if an async reply function is registered and a chat is initialized with a sync
203+
function.
184204
```python
185205
def reply_func(
186206
recipient: ConversableAgent,
@@ -209,6 +229,8 @@ def reply_func(
209229
"reset_config": reset_config,
210230
},
211231
)
232+
if ignore_async_in_sync_chat and asyncio.coroutines.iscoroutinefunction(reply_func):
233+
self._ignore_async_func_in_sync_chat_list.append(reply_func)
212234

213235
@property
214236
def system_message(self) -> Union[str, List]:
@@ -597,6 +619,25 @@ def _prepare_chat(self, recipient, clear_history):
597619
self.clear_history(recipient)
598620
recipient.clear_history(self)
599621

622+
def _raise_exception_on_async_reply_functions(self) -> None:
623+
"""Raise an exception if any async reply functions are registered.
624+
625+
Raises:
626+
RuntimeError: if any async reply functions are registered.
627+
"""
628+
reply_functions = {f["reply_func"] for f in self._reply_func_list}.difference(
629+
self._ignore_async_func_in_sync_chat_list
630+
)
631+
632+
async_reply_functions = [f for f in reply_functions if asyncio.coroutines.iscoroutinefunction(f)]
633+
if async_reply_functions != []:
634+
msg = (
635+
"Async reply functions can only be used with ConversableAgent.a_initiate_chat(). The following async reply functions are found: "
636+
+ ", ".join([f.__name__ for f in async_reply_functions])
637+
)
638+
639+
raise RuntimeError(msg)
640+
600641
def initiate_chat(
601642
self,
602643
recipient: "ConversableAgent",
@@ -616,7 +657,12 @@ def initiate_chat(
616657
silent (bool or None): (Experimental) whether to print the messages for this conversation.
617658
**context: any context information.
618659
"message" needs to be provided if the `generate_init_message` method is not overridden.
660+
661+
Raises:
662+
RuntimeError: if any async reply functions are registered and not ignored in sync chat.
619663
"""
664+
for agent in [self, recipient]:
665+
agent._raise_exception_on_async_reply_functions()
620666
self._prepare_chat(recipient, clear_history)
621667
self.send(self.generate_init_message(**context), recipient, silent=silent)
622668

autogen/agentchat/groupchat.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -457,11 +457,20 @@ def __init__(
457457
system_message=system_message,
458458
**kwargs,
459459
)
460+
# Store groupchat
461+
self._groupchat = groupchat
462+
460463
# Order of register_reply is important.
461464
# Allow sync chat if initiated using initiate_chat
462465
self.register_reply(Agent, GroupChatManager.run_chat, config=groupchat, reset_config=GroupChat.reset)
463466
# Allow async chat if initiated using a_initiate_chat
464-
self.register_reply(Agent, GroupChatManager.a_run_chat, config=groupchat, reset_config=GroupChat.reset)
467+
self.register_reply(
468+
Agent,
469+
GroupChatManager.a_run_chat,
470+
config=groupchat,
471+
reset_config=GroupChat.reset,
472+
ignore_async_in_sync_chat=True,
473+
)
465474

466475
def run_chat(
467476
self,
@@ -567,3 +576,14 @@ async def a_run_chat(
567576
await speaker.a_send(reply, self, request_reply=False)
568577
message = self.last_message(speaker)
569578
return True, None
579+
580+
def _raise_exception_on_async_reply_functions(self) -> None:
581+
"""Raise an exception if any async reply functions are registered.
582+
583+
Raises:
584+
RuntimeError: if any async reply functions are registered.
585+
"""
586+
super()._raise_exception_on_async_reply_functions()
587+
588+
for agent in self._groupchat.agents:
589+
agent._raise_exception_on_async_reply_functions()

test/agentchat/test_conversable_agent.py

+109-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def conversable_agent():
3434
)
3535

3636

37-
def test_trigger():
37+
def test_sync_trigger():
3838
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
3939
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
4040
agent.register_reply(agent1, lambda recipient, messages, sender, config: (True, "hello"))
@@ -72,6 +72,114 @@ def test_trigger():
7272
pytest.raises(ValueError, agent._match_trigger, 1, agent1)
7373

7474

75+
@pytest.mark.asyncio
76+
async def test_async_trigger():
77+
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
78+
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
79+
80+
async def a_reply(recipient, messages, sender, config):
81+
print("hello from a_reply")
82+
return (True, "hello")
83+
84+
agent.register_reply(agent1, a_reply)
85+
await agent1.a_initiate_chat(agent, message="hi")
86+
assert agent1.last_message(agent)["content"] == "hello"
87+
88+
async def a_reply_a1(recipient, messages, sender, config):
89+
print("hello from a_reply_a1")
90+
return (True, "hello a1")
91+
92+
agent.register_reply("a1", a_reply_a1)
93+
await agent1.a_initiate_chat(agent, message="hi")
94+
assert agent1.last_message(agent)["content"] == "hello a1"
95+
96+
async def a_reply_conversable_agent(recipient, messages, sender, config):
97+
print("hello from a_reply_conversable_agent")
98+
return (True, "hello conversable agent")
99+
100+
agent.register_reply(ConversableAgent, a_reply_conversable_agent)
101+
await agent1.a_initiate_chat(agent, message="hi")
102+
assert agent1.last_message(agent)["content"] == "hello conversable agent"
103+
104+
async def a_reply_a(recipient, messages, sender, config):
105+
print("hello from a_reply_a")
106+
return (True, "hello a")
107+
108+
agent.register_reply(lambda sender: sender.name.startswith("a"), a_reply_a)
109+
await agent1.a_initiate_chat(agent, message="hi")
110+
assert agent1.last_message(agent)["content"] == "hello a"
111+
112+
async def a_reply_b(recipient, messages, sender, config):
113+
print("hello from a_reply_b")
114+
return (True, "hello b")
115+
116+
agent.register_reply(lambda sender: sender.name.startswith("b"), a_reply_b)
117+
await agent1.a_initiate_chat(agent, message="hi")
118+
assert agent1.last_message(agent)["content"] == "hello a"
119+
120+
async def a_reply_agent2_or_agent1(recipient, messages, sender, config):
121+
print("hello from a_reply_agent2_or_agent1")
122+
return (True, "hello agent2 or agent1")
123+
124+
agent.register_reply(["agent2", agent1], a_reply_agent2_or_agent1)
125+
await agent1.a_initiate_chat(agent, message="hi")
126+
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"
127+
128+
async def a_reply_agent2_or_agent3(recipient, messages, sender, config):
129+
print("hello from a_reply_agent2_or_agent3")
130+
return (True, "hello agent2 or agent3")
131+
132+
agent.register_reply(["agent2", "agent3"], a_reply_agent2_or_agent3)
133+
await agent1.a_initiate_chat(agent, message="hi")
134+
assert agent1.last_message(agent)["content"] == "hello agent2 or agent1"
135+
136+
with pytest.raises(ValueError):
137+
agent.register_reply(1, a_reply)
138+
139+
with pytest.raises(ValueError):
140+
agent._match_trigger(1, agent1)
141+
142+
143+
def test_async_trigger_in_sync_chat():
144+
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
145+
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
146+
agent2 = ConversableAgent("a2", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
147+
148+
reply_mock = unittest.mock.MagicMock()
149+
150+
async def a_reply(recipient, messages, sender, config):
151+
reply_mock()
152+
print("hello from a_reply")
153+
return (True, "hello from reply function")
154+
155+
agent.register_reply(agent1, a_reply)
156+
157+
with pytest.raises(RuntimeError) as e:
158+
agent1.initiate_chat(agent, message="hi")
159+
160+
assert (
161+
e.value.args[0] == "Async reply functions can only be used with ConversableAgent.a_initiate_chat(). "
162+
"The following async reply functions are found: a_reply"
163+
)
164+
165+
agent2.register_reply(agent1, a_reply, ignore_async_in_sync_chat=True)
166+
reply_mock.assert_not_called()
167+
168+
169+
@pytest.mark.asyncio
170+
async def test_sync_trigger_in_async_chat():
171+
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
172+
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
173+
174+
def a_reply(recipient, messages, sender, config):
175+
print("hello from a_reply")
176+
return (True, "hello from reply function")
177+
178+
agent.register_reply(agent1, a_reply)
179+
await agent1.a_initiate_chat(agent, message="hi")
180+
assert agent1.last_message(agent)["content"] == "hello from reply function"
181+
182+
75183
def test_context():
76184
agent = ConversableAgent("a0", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")
77185
agent1 = ConversableAgent("a1", max_consecutive_auto_reply=0, llm_config=False, human_input_mode="NEVER")

0 commit comments

Comments
 (0)