Skip to content

Commit f11a884

Browse files
aswnyjoshkyh
authored andcommitted
allow function to remove termination string in groupchat (#2804)
* allow function to remove termination string in groupchat * improve docstring Co-authored-by: Joshua Kim <[email protected]> * improve docstring Co-authored-by: Joshua Kim <[email protected]> * improve test case description Co-authored-by: Joshua Kim <[email protected]> --------- Co-authored-by: Joshua Kim <[email protected]>
1 parent 160bdcc commit f11a884

File tree

2 files changed

+72
-11
lines changed

2 files changed

+72
-11
lines changed

autogen/agentchat/groupchat.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -1160,15 +1160,17 @@ async def a_run_chat(
11601160
def resume(
11611161
self,
11621162
messages: Union[List[Dict], str],
1163-
remove_termination_string: str = None,
1163+
remove_termination_string: Union[str, Callable[[str], str]] = None,
11641164
silent: Optional[bool] = False,
11651165
) -> Tuple[ConversableAgent, Dict]:
11661166
"""Resumes a group chat using the previous messages as a starting point. Requires the agents, group chat, and group chat manager to be established
11671167
as per the original group chat.
11681168
11691169
Args:
11701170
- messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
1171-
- remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
1171+
- remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
1172+
If a string is provided, this string will be removed from last message.
1173+
If a function is provided, the last message will be passed to this function.
11721174
- silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
11731175
11741176
Returns:
@@ -1263,15 +1265,17 @@ def resume(
12631265
async def a_resume(
12641266
self,
12651267
messages: Union[List[Dict], str],
1266-
remove_termination_string: str = None,
1268+
remove_termination_string: Union[str, Callable[[str], str]],
12671269
silent: Optional[bool] = False,
12681270
) -> Tuple[ConversableAgent, Dict]:
12691271
"""Resumes a group chat using the previous messages as a starting point, asynchronously. Requires the agents, group chat, and group chat manager to be established
12701272
as per the original group chat.
12711273
12721274
Args:
12731275
- messages Union[List[Dict], str]: The content of the previous chat's messages, either as a Json string or a list of message dictionaries.
1274-
- remove_termination_string str: Remove the provided string from the last message to prevent immediate termination
1276+
- remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
1277+
If a string is provided, this string will be removed from last message.
1278+
If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
12751279
- silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
12761280
12771281
Returns:
@@ -1390,11 +1394,15 @@ def _valid_resume_messages(self, messages: List[Dict]):
13901394
):
13911395
raise Exception(f"Agent name in message doesn't exist as agent in group chat: {message['name']}")
13921396

1393-
def _process_resume_termination(self, remove_termination_string: str, messages: List[Dict]):
1397+
def _process_resume_termination(
1398+
self, remove_termination_string: Union[str, Callable[[str], str]], messages: List[Dict]
1399+
):
13941400
"""Removes termination string, if required, and checks if termination may occur.
13951401
13961402
args:
1397-
remove_termination_string (str): termination string to remove from the last message
1403+
remove_termination_string (str or function): Remove the termination string from the last message to prevent immediate termination
1404+
If a string is provided, this string will be removed from last message.
1405+
If a function is provided, the last message will be passed to this function, and the function returns the string after processing.
13981406
13991407
returns:
14001408
None
@@ -1403,9 +1411,17 @@ def _process_resume_termination(self, remove_termination_string: str, messages:
14031411
last_message = messages[-1]
14041412

14051413
# Replace any given termination string in the last message
1406-
if remove_termination_string:
1407-
if messages[-1].get("content") and remove_termination_string in messages[-1]["content"]:
1408-
messages[-1]["content"] = messages[-1]["content"].replace(remove_termination_string, "")
1414+
if isinstance(remove_termination_string, str):
1415+
1416+
def _remove_termination_string(content: str) -> str:
1417+
return content.replace(remove_termination_string, "")
1418+
1419+
else:
1420+
_remove_termination_string = remove_termination_string
1421+
1422+
if _remove_termination_string:
1423+
if messages[-1].get("content"):
1424+
messages[-1]["content"] = _remove_termination_string(messages[-1]["content"])
14091425

14101426
# Check if the last message meets termination (if it has one)
14111427
if self._is_termination_msg:

test/agentchat/test_groupchat.py

+47-2
Original file line numberDiff line numberDiff line change
@@ -1916,6 +1916,51 @@ def test_manager_resume_functions():
19161916
# TERMINATE should be removed
19171917
assert messages[-1]["content"] == final_msg.replace("TERMINATE", "")
19181918

1919+
# Tests termination message replacement with function
1920+
def termination_func(x: str) -> str:
1921+
if "APPROVED" in x:
1922+
x = x.replace("APPROVED", "")
1923+
else:
1924+
x = x.replace("TERMINATE", "")
1925+
return x
1926+
1927+
final_msg1 = "Product_Manager has created 3 new product ideas. APPROVED"
1928+
messages1 = [
1929+
{
1930+
"content": "You are an expert at finding the next speaker.",
1931+
"role": "system",
1932+
},
1933+
{
1934+
"content": final_msg1,
1935+
"name": "Coder",
1936+
"role": "assistant",
1937+
},
1938+
]
1939+
1940+
manager._process_resume_termination(remove_termination_string=termination_func, messages=messages1)
1941+
1942+
# APPROVED should be removed
1943+
assert messages1[-1]["content"] == final_msg1.replace("APPROVED", "")
1944+
1945+
final_msg2 = "Idea has been approved. TERMINATE"
1946+
messages2 = [
1947+
{
1948+
"content": "You are an expert at finding the next speaker.",
1949+
"role": "system",
1950+
},
1951+
{
1952+
"content": final_msg2,
1953+
"name": "Coder",
1954+
"role": "assistant",
1955+
},
1956+
]
1957+
1958+
manager._process_resume_termination(remove_termination_string=termination_func, messages=messages2)
1959+
1960+
# TERMINATE should be removed, "approved" should still be present as the termination_func only replaces upper-cased "APPROVED".
1961+
assert messages2[-1]["content"] == final_msg2.replace("TERMINATE", "")
1962+
assert "approved" in messages2[-1]["content"]
1963+
19191964
# Check if the termination string doesn't exist there's no replacing of content
19201965
final_msg = (
19211966
"Let's get this meeting started. First the Product_Manager will create 3 new product ideas. TERMINATE this."
@@ -2027,7 +2072,7 @@ def test_manager_resume_messages():
20272072
# test_clear_agents_history()
20282073
# test_custom_speaker_selection_overrides_transition_graph()
20292074
# test_role_for_select_speaker_messages()
2030-
test_select_speaker_message_and_prompt_templates()
2075+
# test_select_speaker_message_and_prompt_templates()
20312076
# test_speaker_selection_agent_name_match()
20322077
# test_role_for_reflection_summary()
20332078
# test_speaker_selection_auto_process_result()
@@ -2036,7 +2081,7 @@ def test_manager_resume_messages():
20362081
# test_select_speaker_auto_messages()
20372082
# test_manager_messages_to_string()
20382083
# test_manager_messages_from_string()
2039-
# test_manager_resume_functions()
2084+
test_manager_resume_functions()
20402085
# test_manager_resume_returns()
20412086
# test_manager_resume_messages()
20422087
pass

0 commit comments

Comments
 (0)