
Commit bf559d7

MarianoMolina, sonichi, and ekzhu authored and committed
Add role to reflection with llm (#2527)
* Added 'role' as a summary_args and to the reflection_with_llm flow to be able to pass the role for the summarizing prompt
* Added 'role' as a summary_args and to the reflection_with_llm flow to be able to pass the role for the summarizing prompt, minor docstring adjustments
* Added test for summary prompt role assignment
* Fixed docstrings and mocked llm-config in the test
* Update autogen/agentchat/conversable_agent.py (Co-authored-by: Chi Wang <[email protected]>)
* ran pre-commit
* ran pre-commit2
* fixed old arg name
* Delete dasdaasd (no idea what this file was about)
* Fixed incorrect merge update on test_groupchat

---------

Co-authored-by: Chi Wang <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
1 parent db7bbfd · commit bf559d7

2 files changed: +70 −10 lines changed


autogen/agentchat/conversable_agent.py (+18 −3)
@@ -937,6 +937,7 @@ def my_summary_method(
                 One example key is "summary_prompt", and value is a string of text used to prompt a LLM-based agent (the sender or receiver agent) to reflect
                 on the conversation and extract a summary when summary_method is "reflection_with_llm".
                 The default summary_prompt is DEFAULT_SUMMARY_PROMPT, i.e., "Summarize takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
+                Another available key is "summary_role", which is the role of the message sent to the agent in charge of summarizing. Default is "system".
             message (str, dict or Callable): the initial message to be sent to the recipient. Needs to be provided. Otherwise, input() will be called to get the initial message.
                 - If a string or a dict is provided, it will be used as the initial message. `generate_init_message` is called to generate the initial message for the agent based on this string and the context.
                     If dict, it may contain the following reserved fields (either content or tool_calls need to be provided).
@@ -1168,8 +1169,13 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
             raise ValueError("The summary_prompt must be a string.")
         msg_list = recipient.chat_messages_for_summary(sender)
         agent = sender if recipient is None else recipient
+        role = summary_args.get("summary_role", None)
+        if role and not isinstance(role, str):
+            raise ValueError("The summary_role in summary_arg must be a string.")
         try:
-            summary = sender._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"))
+            summary = sender._reflection_with_llm(
+                prompt, msg_list, llm_agent=agent, cache=summary_args.get("cache"), role=role
+            )
         except BadRequestError as e:
             warnings.warn(
                 f"Cannot extract summary using reflection_with_llm: {e}. Using an empty str as summary.", UserWarning
@@ -1178,7 +1184,12 @@ def _reflection_with_llm_as_summary(sender, recipient, summary_args):
         return summary

     def _reflection_with_llm(
-        self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[AbstractCache] = None
+        self,
+        prompt,
+        messages,
+        llm_agent: Optional[Agent] = None,
+        cache: Optional[AbstractCache] = None,
+        role: Union[str, None] = None,
     ) -> str:
         """Get a chat summary using reflection with an llm client based on the conversation history.

@@ -1187,10 +1198,14 @@ def _reflection_with_llm(
             messages (list): The messages generated as part of a chat conversation.
             llm_agent: the agent with an llm client.
             cache (AbstractCache or None): the cache client to be used for this conversation.
+            role (str): the role of the message, usually "system" or "user". Default is "system".
         """
+        if not role:
+            role = "system"
+
         system_msg = [
             {
-                "role": "system",
+                "role": role,
                 "content": prompt,
             }
         ]
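
For illustration, here is a minimal sketch of how a caller could use the new "summary_role" key added above. The llm_config values are hypothetical placeholders; reflection_with_llm needs a working model configuration to actually produce a summary.

import autogen

# Hypothetical placeholder config: substitute a real model and API key.
llm_config = {"config_list": [{"model": "gpt-4", "api_key": "<your-key>"}]}

assistant = autogen.AssistantAgent("assistant", llm_config=llm_config)
user = autogen.UserProxyAgent("user", human_input_mode="NEVER", code_execution_config=False)

# Before this commit the summarizing prompt was always sent with role "system";
# "summary_role" lets the caller override that, e.g. with "user".
result = user.initiate_chat(
    assistant,
    message="Compare Python lists and tuples.",
    max_turns=1,
    summary_method="reflection_with_llm",
    summary_args={"summary_role": "user"},
)
print(result.summary)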

test/agentchat/test_groupchat.py (+52 −7)
@@ -5,9 +5,10 @@
 import json
 import logging
 from typing import Any, Dict, List, Optional
-from unittest import mock
+from unittest import TestCase, mock

 import pytest
+from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST

 import autogen
 from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager
@@ -1446,6 +1447,46 @@ def test_speaker_selection_agent_name_match():
     assert result == {}


+def test_role_for_reflection_summary():
+    llm_config = {"config_list": [{"model": "mock", "api_key": "mock"}]}
+    agent1 = autogen.ConversableAgent(
+        "alice",
+        max_consecutive_auto_reply=10,
+        human_input_mode="NEVER",
+        llm_config=False,
+        default_auto_reply="This is alice speaking.",
+    )
+    agent2 = autogen.ConversableAgent(
+        "bob",
+        max_consecutive_auto_reply=10,
+        human_input_mode="NEVER",
+        llm_config=False,
+        default_auto_reply="This is bob speaking.",
+    )
+    groupchat = autogen.GroupChat(
+        agents=[agent1, agent2], messages=[], max_round=3, speaker_selection_method="round_robin"
+    )
+    group_chat_manager = autogen.GroupChatManager(groupchat=groupchat, llm_config=llm_config)
+
+    role_name = "user"
+    with mock.patch.object(
+        autogen.ConversableAgent, "_generate_oai_reply_from_client"
+    ) as mock_generate_oai_reply_from_client:
+        mock_generate_oai_reply_from_client.return_value = "Mocked summary"
+
+        agent1.initiate_chat(
+            group_chat_manager,
+            max_turns=2,
+            message="hello",
+            summary_method="reflection_with_llm",
+            summary_args={"summary_role": role_name},
+        )
+
+        mock_generate_oai_reply_from_client.assert_called_once()
+        args, kwargs = mock_generate_oai_reply_from_client.call_args
+        assert kwargs["messages"][-1]["role"] == role_name
+
+
 def test_speaker_selection_auto_process_result():
     """
     Tests the return result of the 2-agent chat used for speaker selection for the auto method.
@@ -1984,12 +2025,16 @@ def test_manager_resume_messages():
     # test_role_for_select_speaker_messages()
     # test_select_speaker_message_and_prompt_templates()
     # test_speaker_selection_agent_name_match()
+    # test_role_for_reflection_summary()
+    # test_speaker_selection_auto_process_result()
+    # test_speaker_selection_validate_speaker_name()
+    # test_select_speaker_auto_messages()
     # test_speaker_selection_auto_process_result()
     # test_speaker_selection_validate_speaker_name()
     # test_select_speaker_auto_messages()
-    test_manager_messages_to_string()
-    test_manager_messages_from_string()
-    test_manager_resume_functions()
-    test_manager_resume_returns()
-    test_manager_resume_messages()
-    # pass
+    # test_manager_messages_to_string()
+    # test_manager_messages_from_string()
+    # test_manager_resume_functions()
+    # test_manager_resume_returns()
+    # test_manager_resume_messages()
+    pass
