diff --git a/autogen/agentchat/groupchat.py b/autogen/agentchat/groupchat.py index 5b12a97e6b17..03251eccfb71 100644 --- a/autogen/agentchat/groupchat.py +++ b/autogen/agentchat/groupchat.py @@ -64,8 +64,10 @@ def agent_by_name(self, name: str) -> Agent: """Returns the agent with a given name.""" return self.agents[self.agent_names.index(name)] - def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent: + def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agent: """Return the next agent in the list.""" + if agents is None: + agents = self.agents # What index is the agent? (-1 if not present) idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1 @@ -79,20 +81,26 @@ def next_agent(self, agent: Agent, agents: List[Agent]) -> Agent: if self.agents[(offset + i) % len(self.agents)] in agents: return self.agents[(offset + i) % len(self.agents)] - def select_speaker_msg(self, agents: List[Agent]) -> str: + def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str: """Return the system message for selecting the next speaker. This is always the *first* message in the context.""" + if agents is None: + agents = self.agents return f"""You are in a role play game. The following roles are available: {self._participant_roles(agents)}. Read the following conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role.""" - def select_speaker_prompt(self, agents: List[Agent]) -> str: + def select_speaker_prompt(self, agents: Optional[List[Agent]] = None) -> str: """Return the floating system prompt selecting the next speaker. This is always the *last* message in the context.""" + if agents is None: + agents = self.agents return f"Read the above conversation. Then select the next role from {[agent.name for agent in agents]} to play. Only return the role." - def manual_select_speaker(self, agents: List[Agent]) -> Union[Agent, None]: + def manual_select_speaker(self, agents: Optional[List[Agent]] = None) -> Union[Agent, None]: """Manually select the next speaker.""" + if agents is None: + agents = self.agents print("Please select the next speaker from the following list:") _n_agents = len(agents) diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py index 6d592ae3fa3d..038f68bf1624 100644 --- a/test/agentchat/test_groupchat.py +++ b/test/agentchat/test_groupchat.py @@ -421,6 +421,10 @@ def test_next_agent(): assert groupchat.next_agent(agent2, [agent1, agent2, agent3]) == agent3 assert groupchat.next_agent(agent3, [agent1, agent2, agent3]) == agent1 + assert groupchat.next_agent(agent1) == agent2 + assert groupchat.next_agent(agent2) == agent3 + assert groupchat.next_agent(agent3) == agent1 + assert groupchat.next_agent(agent1, [agent1, agent3]) == agent3 assert groupchat.next_agent(agent3, [agent1, agent3]) == agent1 @@ -429,6 +433,48 @@ def test_next_agent(): assert groupchat.next_agent(agent4, [agent1, agent2, agent3]) == agent1 +def test_selection_helpers(): + agent1 = autogen.ConversableAgent( + "alice", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is alice speaking.", + description="Alice is an AI agent.", + ) + agent2 = autogen.ConversableAgent( + "bob", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + description="Bob is an AI agent.", + ) + agent3 = autogen.ConversableAgent( + "sam", + max_consecutive_auto_reply=10, + human_input_mode="NEVER", + llm_config=False, + default_auto_reply="This is sam speaking.", + system_message="Sam is an AI agent.", + ) + + # Test empty is_termination_msg function + groupchat = autogen.GroupChat( + agents=[agent1, agent2, agent3], messages=[], speaker_selection_method="round_robin", max_round=10 + ) + + select_speaker_msg = groupchat.select_speaker_msg() + select_speaker_prompt = groupchat.select_speaker_prompt() + + assert "Alice is an AI agent." in select_speaker_msg + assert "Bob is an AI agent." in select_speaker_msg + assert "Sam is an AI agent." in select_speaker_msg + assert str(["Alice", "Bob", "Sam"]).lower() in select_speaker_prompt.lower() + + with mock.patch.object(builtins, "input", lambda _: "1"): + groupchat.manual_select_speaker() + + if __name__ == "__main__": # test_func_call_groupchat() # test_broadcast()