From 6465cad43f12523bce5a762902a6367148d45a59 Mon Sep 17 00:00:00 2001 From: Adam Fourney Date: Mon, 13 Nov 2023 12:37:38 -0800 Subject: [PATCH 1/3] Makes select_speaker more robust by checking for agents mentioned anywhere in the selection string. Addresses 663. --- autogen/agentchat/groupchat.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 604eb5c209db..4c1da90c69d1 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -1,6 +1,7 @@ from dataclasses import dataclass import sys from typing import Dict, List, Optional, Union +import re from .agent import Agent from .conversable_agent import ConversableAgent import logging @@ -100,6 +101,20 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent): if not final: # i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id return self.next_agent(last_speaker, agents) + + # Find mentions of any agents + mentions = dict() + for agent in self.agents: + regex = r"\b" + re.escape(agent.name) + r"\b" # Finds agent mentions, taking word boundaries into account + count = len(re.findall(regex, name)) + if count > 0: + mentions[agent.name] = count + + # If exactly one agent is found, use it. Otherwise, leave the OAI response unmodified + if len(mentions) == 1: + name = next(iter(mentions)) + + # Return the result try: return self.agent_by_name(name) except ValueError: From 4eb5a1e64adfddc8168361f01378cf47f0642b6c Mon Sep 17 00:00:00 2001 From: Adam Fourney Date: Thu, 16 Nov 2023 12:41:56 -0800 Subject: [PATCH 2/3] Added test coverage for group chat mentions. Refactored mention counter to own function. --- autogen/agentchat/groupchat.py | 25 +++++++++------ test/agentchat/test_groupchat.py | 52 ++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 9 deletions(-) diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 4c1da90c69d1..1e231efa573e 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -102,15 +102,8 @@ def select_speaker(self, last_speaker: Agent, selector: ConversableAgent): # i = self._random.randint(0, len(self._agent_names) - 1) # randomly pick an id return self.next_agent(last_speaker, agents) - # Find mentions of any agents - mentions = dict() - for agent in self.agents: - regex = r"\b" + re.escape(agent.name) + r"\b" # Finds agent mentions, taking word boundaries into account - count = len(re.findall(regex, name)) - if count > 0: - mentions[agent.name] = count - - # If exactly one agent is found, use it. Otherwise, leave the OAI response unmodified + # If exactly one agent is mentioned, use it. Otherwise, leave the OAI response unmodified + mentions = _mentioned_agents(name, agents) if len(mentions) == 1: name = next(iter(mentions)) @@ -133,6 +126,20 @@ def _participant_roles(self): roles.append(f"{agent.name}: {agent.system_message}") return "\n".join(roles) + def _mentioned_agents(self, message_content: str, agents: List[Agent]) -> Dict: + """ + Finds and counts agent mentions in the string message_content, taking word boundaries into account. + + Returns: A dictionary mapping agent names to mention counts (to be included, at least one mention must occur) + """ + mentions = dict() + for agent in agents: + regex = r"(?<=\W)" + re.escape(agent.name) + r"(?=\W)" # Finds agent mentions, taking word boundaries into account + count = len(re.findall(regex, " " + message_content + " ")) # Pad the message to help with matching + if count > 0: + mentions[agent.name] = count + return mentions + class GroupChatManager(ConversableAgent): """(In preview) A chat manager agent that can manage a group chat of multiple agents.""" diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index c50ef45cdcca..2a5a6d46ea90 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -1,5 +1,6 @@ import pytest import autogen +import json def test_func_call_groupchat(): @@ -111,9 +112,60 @@ def test_plugin(): assert len(agent1.chat_messages[group_chat_manager]) == 2 assert len(groupchat.messages) == 2 +def test_agent_mentions(): + agent1 = autogen.ConversableAgent( + "alice", + max_consecutive_auto_reply=2, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is alice sepaking.", + ) + agent2 = autogen.ConversableAgent( + "bob", + max_consecutive_auto_reply=2, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is bob speaking.", + ) + agent3 = autogen.ConversableAgent( + "sam", + max_consecutive_auto_reply=2, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is sam speaking.", + ) + groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=2) + + # Basic counting + assert json.dumps(groupchat._mentioned_agents("", [agent1, agent2, agent3]), sort_keys=True) == "{}" + assert json.dumps(groupchat._mentioned_agents("alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 1}' + assert json.dumps(groupchat._mentioned_agents("alice bob alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 2, "bob": 1}' + assert json.dumps(groupchat._mentioned_agents("alice bob alice sam", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 2, "bob": 1, "sam": 1}' + assert json.dumps(groupchat._mentioned_agents("alice bob alice sam robert", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 2, "bob": 1, "sam": 1}' + + # Substring + assert json.dumps(groupchat._mentioned_agents("sam samantha basam asami", [agent1, agent2, agent3]), sort_keys=True) == '{"sam": 1}' + + # Word boundaries + assert json.dumps(groupchat._mentioned_agents("alice! .alice. .alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 3}' + + # Special characters in agent names + agent4 = autogen.ConversableAgent( + ".*", + max_consecutive_auto_reply=2, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="Match everything.", + ) + + groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3, agent4], messages=[], max_round=2) + assert json.dumps(groupchat._mentioned_agents("alice bob alice sam robert .*", [agent1, agent2, agent3, agent4]), sort_keys=True) == '{".*": 1, "alice": 2, "bob": 1, "sam": 1}' + + if __name__ == "__main__": test_func_call_groupchat() # test_broadcast() test_chat_manager() # test_plugin() + # test_agent_mentions() From 52e5df034a1fde13158dd2689765bbf0a6635a28 Mon Sep 17 00:00:00 2001 From: Adam Fourney Date: Thu, 16 Nov 2023 12:45:29 -0800 Subject: [PATCH 3/3] Fixed pre-commit formatting. --- test/agentchat/test_groupchat.py | 37 +++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 2a5a6d46ea90..6f634fd8677f 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -112,6 +112,7 @@ def test_plugin(): assert len(agent1.chat_messages[group_chat_manager]) == 2 assert len(groupchat.messages) == 2 + def test_agent_mentions(): agent1 = autogen.ConversableAgent( "alice", @@ -135,19 +136,34 @@ def test_agent_mentions(): default_auto_reply="This is sam speaking.", ) groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3], messages=[], max_round=2) - + # Basic counting assert json.dumps(groupchat._mentioned_agents("", [agent1, agent2, agent3]), sort_keys=True) == "{}" assert json.dumps(groupchat._mentioned_agents("alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 1}' - assert json.dumps(groupchat._mentioned_agents("alice bob alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 2, "bob": 1}' - assert json.dumps(groupchat._mentioned_agents("alice bob alice sam", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 2, "bob": 1, "sam": 1}' - assert json.dumps(groupchat._mentioned_agents("alice bob alice sam robert", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 2, "bob": 1, "sam": 1}' + assert ( + json.dumps(groupchat._mentioned_agents("alice bob alice", [agent1, agent2, agent3]), sort_keys=True) + == '{"alice": 2, "bob": 1}' + ) + assert ( + json.dumps(groupchat._mentioned_agents("alice bob alice sam", [agent1, agent2, agent3]), sort_keys=True) + == '{"alice": 2, "bob": 1, "sam": 1}' + ) + assert ( + json.dumps(groupchat._mentioned_agents("alice bob alice sam robert", [agent1, agent2, agent3]), sort_keys=True) + == '{"alice": 2, "bob": 1, "sam": 1}' + ) # Substring - assert json.dumps(groupchat._mentioned_agents("sam samantha basam asami", [agent1, agent2, agent3]), sort_keys=True) == '{"sam": 1}' + assert ( + json.dumps(groupchat._mentioned_agents("sam samantha basam asami", [agent1, agent2, agent3]), sort_keys=True) + == '{"sam": 1}' + ) # Word boundaries - assert json.dumps(groupchat._mentioned_agents("alice! .alice. .alice", [agent1, agent2, agent3]), sort_keys=True) == '{"alice": 3}' + assert ( + json.dumps(groupchat._mentioned_agents("alice! .alice. .alice", [agent1, agent2, agent3]), sort_keys=True) + == '{"alice": 3}' + ) # Special characters in agent names agent4 = autogen.ConversableAgent( @@ -159,8 +175,13 @@ def test_agent_mentions(): ) groupchat = autogen.GroupChat(agents=[agent1, agent2, agent3, agent4], messages=[], max_round=2) - assert json.dumps(groupchat._mentioned_agents("alice bob alice sam robert .*", [agent1, agent2, agent3, agent4]), sort_keys=True) == '{".*": 1, "alice": 2, "bob": 1, "sam": 1}' - + assert ( + json.dumps( + groupchat._mentioned_agents("alice bob alice sam robert .*", [agent1, agent2, agent3, agent4]), + sort_keys=True, + ) + == '{".*": 1, "alice": 2, "bob": 1, "sam": 1}' + ) if __name__ == "__main__":