Skip to content

Commit

Permalink
Support async nested chats (#3309)
Browse files Browse the repository at this point in the history
* Allow async nested chats in agent chat

* Fix pre-comit

* Minor fix

* Fix

* Address feedback

* Update

* Fix build error

---------

Co-authored-by: Qingyun Wu <[email protected]>
  • Loading branch information
heyitsaamir and qingyun-wu authored Aug 9, 2024
1 parent 4dab28c commit aac6f05
Show file tree
Hide file tree
Showing 2 changed files with 297 additions and 10 deletions.
77 changes: 67 additions & 10 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,9 @@ def replace_reply_func(self, old_reply_func: Callable, new_reply_func: Callable)
f["reply_func"] = new_reply_func

@staticmethod
def _summary_from_nested_chats(
def _get_chats_to_run(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, str]:
) -> List[Dict[str, Any]]:
"""A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
Expand All @@ -406,22 +406,59 @@ def _summary_from_nested_chats(
if message:
current_c["message"] = message
chat_to_run.append(current_c)
return chat_to_run

@staticmethod
def _summary_from_nested_chats(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, Union[str, None]]:
"""A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
Returns:
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""
chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = initiate_chats(chat_to_run)
return True, res[-1].summary

@staticmethod
async def _a_summary_from_nested_chats(
chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any
) -> Tuple[bool, Union[str, None]]:
"""A simple chat reply function.
This function initiate one or a sequence of chats between the "recipient" and the agents in the
chat_queue.
It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue.
Returns:
Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated.
"""
chat_to_run = ConversableAgent._get_chats_to_run(chat_queue, recipient, messages, sender, config)
if not chat_to_run:
return True, None
res = await a_initiate_chats(chat_to_run)
index_of_last_chat = chat_to_run[-1]["chat_id"]
return True, res[index_of_last_chat].summary

def register_nested_chats(
self,
chat_queue: List[Dict[str, Any]],
trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats",
position: int = 2,
use_async: Union[bool, None] = None,
**kwargs,
) -> None:
"""Register a nested chat reply function.
Args:
chat_queue (list): a list of chat objects to be initiated.
chat_queue (list): a list of chat objects to be initiated. If use_async is used, then all messages in chat_queue must have a chat-id associated with them.
trigger (Agent class, str, Agent instance, callable, or list): refer to `register_reply` for details.
reply_func_from_nested_chats (Callable, str): the reply function for the nested chat.
The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message.
Expand All @@ -436,15 +473,33 @@ def reply_func_from_nested_chats(
) -> Tuple[bool, Union[str, Dict, None]]:
```
position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply.
use_async: Uses a_initiate_chats internally to start nested chats. If the original chat is initiated with a_initiate_chats, you may set this to true so nested chats do not run in sync.
kwargs: Ref to `register_reply` for details.
"""
if reply_func_from_nested_chats == "summary_from_nested_chats":
reply_func_from_nested_chats = self._summary_from_nested_chats
if not callable(reply_func_from_nested_chats):
raise ValueError("reply_func_from_nested_chats must be a callable")
if use_async:
for chat in chat_queue:
if chat.get("chat_id") is None:
raise ValueError("chat_id is required for async nested chats")

if use_async:
if reply_func_from_nested_chats == "summary_from_nested_chats":
reply_func_from_nested_chats = self._a_summary_from_nested_chats
if not callable(reply_func_from_nested_chats) or not inspect.iscoroutinefunction(
reply_func_from_nested_chats
):
raise ValueError("reply_func_from_nested_chats must be a callable and a coroutine")

def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)
async def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
return await reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)

else:
if reply_func_from_nested_chats == "summary_from_nested_chats":
reply_func_from_nested_chats = self._summary_from_nested_chats
if not callable(reply_func_from_nested_chats):
raise ValueError("reply_func_from_nested_chats must be a callable")

def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
return reply_func_from_nested_chats(chat_queue, recipient, messages, sender, config)

functools.update_wrapper(wrapped_reply_func, reply_func_from_nested_chats)

Expand All @@ -454,7 +509,9 @@ def wrapped_reply_func(recipient, messages=None, sender=None, config=None):
position,
kwargs.get("config"),
kwargs.get("reset_config"),
ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"),
ignore_async_in_sync_chat=(
not use_async if use_async is not None else kwargs.get("ignore_async_in_sync_chat")
),
)

@property
Expand Down
230 changes: 230 additions & 0 deletions test/agentchat/test_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,36 @@

import os
import sys
from typing import List

import pytest

import autogen
from autogen.agentchat.contrib.capabilities.agent_capability import AgentCapability

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(os.path.join(os.path.dirname(__file__), "../.."))
from conftest import reason, skip_openai # noqa: E402
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST # noqa: E402


class MockAgentReplies(AgentCapability):
def __init__(self, mock_messages: List[str]):
self.mock_messages = mock_messages
self.mock_message_index = 0

def add_to_agent(self, agent: autogen.ConversableAgent):
def mock_reply(recipient, messages, sender, config):
if self.mock_message_index < len(self.mock_messages):
reply_msg = self.mock_messages[self.mock_message_index]
self.mock_message_index += 1
return [True, reply_msg]
else:
raise ValueError(f"No more mock messages available for {sender.name} to reply to {recipient.name}")

agent.register_reply([autogen.Agent, None], mock_reply, position=2)


@pytest.mark.skipif(skip_openai, reason=reason)
def test_nested():
config_list = autogen.config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC)
Expand Down Expand Up @@ -142,5 +161,216 @@ def writing_message(recipient, messages, sender, config):
)


def test_sync_nested_chat():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False

inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)

inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)

assistant = autogen.AssistantAgent(
"Assistant",
)
user = autogen.UserProxyAgent(
"User",
human_input_mode="NEVER",
is_termination_msg=is_termination,
)
assistant.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}], trigger=user
)
chat_result = user.initiate_chat(assistant, message="Start chat")
assert len(chat_result.chat_history) == 2
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "FINAL_RESULT"]


@pytest.mark.asyncio
async def test_async_nested_chat():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False

inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)

inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)

assistant = autogen.AssistantAgent(
"Assistant",
)
user = autogen.UserProxyAgent(
"User",
human_input_mode="NEVER",
is_termination_msg=is_termination,
)
assistant.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
trigger=user,
use_async=True,
)
chat_result = await user.a_initiate_chat(assistant, message="Start chat")
assert len(chat_result.chat_history) == 2
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "FINAL_RESULT"]


@pytest.mark.asyncio
async def test_async_nested_chat_chat_id_validation():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False

inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)

inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)

assistant = autogen.AssistantAgent(
"Assistant",
)
user = autogen.UserProxyAgent(
"User",
human_input_mode="NEVER",
is_termination_msg=is_termination,
)
with pytest.raises(ValueError, match="chat_id is required for async nested chats"):
assistant.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
trigger=user,
use_async=True,
)


def test_sync_nested_chat_in_group():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False

inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)

inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)

assistant = autogen.AssistantAgent(
"Assistant_In_Group_1",
)
MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
assistant2 = autogen.AssistantAgent(
"Assistant_In_Group_2",
)
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
group = autogen.GroupChat(
agents=[assistant, assistant2, user],
messages=[],
speaker_selection_method="round_robin",
)
group_manager = autogen.GroupChatManager(groupchat=group)
assistant2.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg"}],
trigger=group_manager,
)

chat_result = user.initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
assert len(chat_result.chat_history) == 3
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]


@pytest.mark.asyncio
async def test_async_nested_chat_in_group():
def is_termination(msg):
if isinstance(msg, str) and msg == "FINAL_RESULT":
return True
elif isinstance(msg, dict) and msg.get("content") == "FINAL_RESULT":
return True
return False

inner_assistant = autogen.AssistantAgent(
"Inner-assistant",
is_termination_msg=is_termination,
)
MockAgentReplies(["Inner-assistant message 1", "Inner-assistant message 2"]).add_to_agent(inner_assistant)

inner_assistant_2 = autogen.AssistantAgent(
"Inner-assistant-2",
)
MockAgentReplies(["Inner-assistant-2 message 1", "Inner-assistant-2 message 2", "FINAL_RESULT"]).add_to_agent(
inner_assistant_2
)

assistant = autogen.AssistantAgent(
"Assistant_In_Group_1",
)
MockAgentReplies(["Assistant_In_Group_1 message 1"]).add_to_agent(assistant)
assistant2 = autogen.AssistantAgent(
"Assistant_In_Group_2",
)
user = autogen.UserProxyAgent("User", human_input_mode="NEVER", is_termination_msg=is_termination)
group = autogen.GroupChat(
agents=[assistant, assistant2, user],
messages=[],
speaker_selection_method="round_robin",
)
group_manager = autogen.GroupChatManager(groupchat=group)
assistant2.register_nested_chats(
[{"sender": inner_assistant, "recipient": inner_assistant_2, "summary_method": "last_msg", "chat_id": 1}],
trigger=group_manager,
use_async=True,
)

chat_result = await user.a_initiate_chat(group_manager, message="Start chat", summary_method="last_msg")
assert len(chat_result.chat_history) == 3
chat_messages = [msg["content"] for msg in chat_result.chat_history]
assert chat_messages == ["Start chat", "Assistant_In_Group_1 message 1", "FINAL_RESULT"]


if __name__ == "__main__":
test_nested()

0 comments on commit aac6f05

Please sign in to comment.