diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index bfd38a54d609..89b2dd94345a 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -937,6 +937,7 @@ def my_summary_method(
                 One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
                 on the conversation and extract a summary when summary_method is "reflection_with_llm".
                 The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+                Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
             message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
                 - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
                     If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
@@ -1168,8 +1169,13 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
             raise ValueError("The summary_prompt must be a string.")
         msg_list = recipient.chat_messages_for_summary(sender)
         agent = sender if recipient is None else recipient
+        role = summary_args.get("summary_role", None)
+        if role and not isinstance(role, str):
+            raise ValueError("The summary_role in summary_arg must be a string.")
         try:
-            summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"))
+            summary = sender._reflection_with_llm(
+                prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
+            )
         except BadRequestError as e:
             warnings.warn(
                 f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
@@ -1178,7 +1184,12 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
         return summary

     def _reflection_with_llm(
-        self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
+        self,
+        prompt,
+        messages,
+        llm_agent: Optional[Agent] = None,
+        cache: Optional[AbstractCache] = None,
+        role: Union[str, None] = None,
     ) -> str:
         """Get a chat summary using reflection with an llm client based on the conversation history.

@@ -1187,10 +1198,14 @@ def _reflection_with_llm(
             messages (list): The messages generated as part of a chat conversation.
             llm_agent: the agent with an llm client.
             cache (AbstractCache or None): the cache client to be used for this conversation.
+            role (str): the role of the message, usually "system" or "user". Default is "system".
         """
+        if not role:
+            role = "system"
+
         system_msg = [
             {
-                "role": "system",
+                "role": role,
                 "content": prompt,
             }
         ]
diff --git a/test/agentchat/test_groupchat.py b/test/agentchat/test_groupchat.py
index 8dc3dc77746f..0176c9e89901 100755
--- a/test/agentchat/test_groupchat.py
+++ b/test/agentchat/test_groupchat.py
@@ -5,9 +5,10 @@
 import json
 import logging
 from typing import Any, Dict, List, Optional
-from unittest import mock
+from unittest import TestCase, mock

 import pytest
+from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST

 import autogen
 from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager
@@ -1446,6 +1447,46 @@ def test_speaker_selection_agent_name_match():
     assert result == {}


+def test_role_for_reflection_summary():
+    llm_config = {"config_list": [{"model": "mock", "api_key": "mock"}]}
+    agent1 = autogen.ConversableAgent(
+        "alice",
+        max_consecutive_auto_reply=10,
+        human_input_mode="NEVER",
+        llm_config=False,
+        default_auto_reply="This is alice speaking.",
+    )
+    agent2 = autogen.ConversableAgent(
+        "bob",
+        max_consecutive_auto_reply=10,
+        human_input_mode="NEVER",
+        llm_config=False,
+        default_auto_reply="This is bob speaking.",
+    )
+    groupchat = autogen.GroupChat(
+        agents=[agent1, agent2], messages=[], max_round=3, speaker_selection_method="round_robin"
+    )
+    group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+    role_name = "user"
+    with mock.patch.object(
+        autogen.ConversableAgent, "_generate_oai_reply_from_client"
+    ) as mock_generate_oai_reply_from_client:
+        mock_generate_oai_reply_from_client.return_value = "Mocked summary"
+
+        agent1.initiate_chat(
+            group_chat_manager,
+            max_turns=2,
+            message="hello",
+            summary_method="reflection_with_llm",
+            summary_args={"summary_role": role_name},
+        )
+
+        mock_generate_oai_reply_from_client.assert_called_once()
+        args, kwargs = mock_generate_oai_reply_from_client.call_args
+        assert kwargs["messages"][-1]["role"] == role_name
+
+
 def test_speaker_selection_auto_process_result():
     """
     Tests the return result of the 2-agent chat used for speaker selection for the auto method.
@@ -1984,12 +2025,16 @@ def test_manager_resume_messages():
     # test_role_for_select_speaker_messages()
     # test_select_speaker_message_and_prompt_templates()
     # test_speaker_selection_agent_name_match()
+    # test_role_for_reflection_summary()
+    # test_speaker_selection_auto_process_result()
+    # test_speaker_selection_validate_speaker_name()
+    # test_select_speaker_auto_messages()
     # test_speaker_selection_auto_process_result()
     # test_speaker_selection_validate_speaker_name()
     # test_select_speaker_auto_messages()
-    test_manager_messages_to_string()
-    test_manager_messages_from_string()
-    test_manager_resume_functions()
-    test_manager_resume_returns()
-    test_manager_resume_messages()
-    # pass
+    # test_manager_messages_to_string()
+    # test_manager_messages_from_string()
+    # test_manager_resume_functions()
+    # test_manager_resume_returns()
+    # test_manager_resume_messages()
+    pass