Skip to content

Commit

Permalink
max_turns
Browse files Browse the repository at this point in the history
  • Loading branch information
qingyun-wu committed Feb 16, 2024
1 parent bf3cc0f commit 18d7f07
Show file tree
Hide file tree
Showing 5 changed files with 129 additions and 60 deletions.
54 changes: 30 additions & 24 deletions autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,31 +34,37 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
args:
chat_queue (List[Dict]): a list of dictionaries containing the information of the chats.
Each dictionary should contain the following fields:
Each dictionary should contain the input arguments for `ConversableAgent.initiate_chat`.
More specifically, each dictionary could include the following fields:
recipient: the recipient agent.
- "sender": the sender agent.
- "recipient": the recipient agent.
- "context": any context information, e.g., the request message. The following fields are reserved:
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- Supported strings are "last_msg" and "reflection_with_llm":
when set "last_msg", it returns the last message of the dialog as the summary.
when set "reflection_with_llm", it returns a summary extracted using an llm client.
`llm_config` must be set in either the recipient or sender.
"reflection_with_llm" requires the llm_config to be set in either the sender or the recipient.
- A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g.,
```python
def my_summary_method(
sender: ConversableAgent,
recipient: ConversableAgent,
):
return recipient.last_message(sender)["content"]
```
"summary_prompt" can be used to specify the prompt used to extract a summary when summary_method is "reflection_with_llm".
Default is None and the following default prompt will be used when "summary_method" is set to "reflection_with_llm":
"Identify and extract the final solution to the originally asked question based on the conversation."
"carryover" can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
- clear_history (bool): whether to clear the chat history with the agent. Default is True.
- silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
- cache (Cache or None): the cache client to be used for this conversation. Default is None.
- max_turns (int or None): the maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
- "message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
- "summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- Supported strings are "last_msg" and "reflection_with_llm":
when set "last_msg", it returns the last message of the dialog as the summary.
when set "reflection_with_llm", it returns a summary extracted using an llm client.
`llm_config` must be set in either the recipient or sender.
"reflection_with_llm" requires the llm_config to be set in either the sender or the recipient.
- A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g,
```python
def my_summary_method(
sender: ConversableAgent,
recipient: ConversableAgent,
):
return recipient.last_message(sender)["content"]
```
"summary_prompt" can be used to specify the prompt used to extract a summary when summary_method is "reflection_with_llm".
Default is None and the following default prompt will be used when "summary_method" is set to "reflection_with_llm":
"Identify and extract the final solution to the originally asked question based on the conversation."
"carryover" can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
returns:
Expand Down
63 changes: 33 additions & 30 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,7 @@ def initiate_chat(
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
max_turns: Optional[int] = None,
**context,
) -> ChatResult:
"""Initiate a chat with the recipient agent.
Expand All @@ -773,9 +774,10 @@ def initiate_chat(
Args:
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
clear_history (bool): whether to clear the chat history with the agent. Default is True.
silent (bool or None): (Experimental) whether to print the messages for this conversation. Default is False.
cache (Cache or None): the cache client to be used for this conversation. Default is None.
max_turns (int or None): the maximum number of turns for the chat. If None, the chat will continue until a termination condition is met. Default is None.
**context: any context information. It has the following reserved fields:
"message": a str of message. Needs to be provided. Otherwise, input() will be called to get the initial message.
"summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
Expand Down Expand Up @@ -813,7 +815,19 @@ def my_summary_method(
agent.previous_cache = agent.client_cache
agent.client_cache = cache
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)
if isinstance(max_turns, int):
msg2send = self.generate_init_message(**context)
for _ in range(max_turns):
if msg2send is None:
break
self.send(msg2send, recipient, request_reply=False, silent=silent)
msg2send = recipient.generate_reply(messages=recipient.chat_messages[self], sender=self)
if msg2send is None:
break
recipient.send(msg2send, self, request_reply=False, silent=silent)
msg2send = self.generate_reply(messages=self.chat_messages[recipient], sender=recipient)
else:
self.send(self.generate_init_message(**context), recipient, silent=silent)
summary = self._summarize_chat(
context.get("summary_method", ConversableAgent.DEFAULT_summary_method),
recipient,
Expand All @@ -837,6 +851,7 @@ async def a_initiate_chat(
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
max_turns: Optional[int] = None,
**context,
) -> ChatResult:
"""(async) Initiate a chat with the recipient agent.
Expand All @@ -857,7 +872,19 @@ async def a_initiate_chat(
for agent in [self, recipient]:
agent.previous_cache = agent.client_cache
agent.client_cache = cache
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
if isinstance(max_turns, int):
msg2send = await self.a_generate_init_message(**context)
for _ in range(max_turns):
if msg2send is None:
break
await self.a_send(msg2send, recipient, request_reply=False, silent=silent)
msg2send = await recipient.a_generate_reply(messages=recipient.chat_messages[self], sender=self)
if msg2send is None:
break
await recipient.a_send(msg2send, self, request_reply=False, silent=silent)
msg2send = await self.a_generate_reply(messages=self.chat_messages[recipient], sender=recipient)
else:
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
summary = self._summarize_chat(
context.get("summary_method", ConversableAgent.DEFAULT_summary_method),
recipient,
Expand Down Expand Up @@ -956,31 +983,7 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
Args:
chat_queue (List[Dict]): a list of dictionaries containing the information of the chats.
Each dictionary should contain the following fields:
- "recipient": the recipient agent.
- "context": any context information, e.g., the request message. The following fields are reserved:
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"summary_method": a string or callable specifying the method to get a summary from the chat. Default is DEFAULT_summary_method, i.e., "last_msg".
- Supported strings are "last_msg" and "reflection_with_llm":
when set "last_msg", it returns the last message of the dialog as the summary.
when set "reflection_with_llm", it returns a summary extracted using an llm client.
`llm_config` must be set in either the recipient or sender.
"reflection_with_llm" requires the llm_config to be set in either the sender or the recipient.
- A callable summary_method should take the recipient and sender agent in a chat as input and return a string of summary. E.g,
```python
def my_summary_method(
sender: ConversableAgent,
recipient: ConversableAgent,
):
return recipient.last_message(sender)["content"]
```
"summary_prompt" can be used to specify the prompt used to extract a summary when summary_method is "reflection_with_llm".
Default is None and the following default prompt will be used when "summary_method" is set to "reflection_with_llm":
"Identify and extract the final solution to the originally asked question based on the conversation."
"carryover" can be used to specify the carryover information to be passed to this chat.
If provided, we will combine this carryover with the "message" content when generating the initial chat
message in `generate_init_message`.
Each dictionary should contain the input arguments for `initiate_chat`.
Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue.
"""
Expand Down
29 changes: 29 additions & 0 deletions test/agentchat/test_async_get_human_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,34 @@ async def test_async_get_human_input():
print("Human input:", res.human_input)


@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
@pytest.mark.asyncio
async def test_async_max_turn():
    """Check that `max_turns=3` caps an async chat at exactly three round trips.

    Human input is mocked so the test never blocks on stdin; the assistant's
    auto-reply cap (10) is deliberately higher than `max_turns` so the turn
    limit is what actually ends the conversation.
    """
    oai_config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, KEY_LOC)

    # LLM-backed assistant that the user proxy will converse with.
    assistant = autogen.AssistantAgent(
        name="assistant",
        max_consecutive_auto_reply=10,
        llm_config={"seed": 41, "config_list": oai_config_list},
    )

    # Proxy for the human side; its async input hook is replaced with a mock
    # that always returns the same canned reply.
    user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
    user_proxy.a_get_human_input = AsyncMock(return_value="Not funny. Try again.")

    result = await user_proxy.a_initiate_chat(
        assistant, clear_history=True, max_turns=3, message="Hello, make a joke about AI."
    )

    print("Result summary:", result.summary)
    print("Human input:", result.human_input)
    print("chat history:", result.chat_history)
    # Three turns -> three (message, reply) pairs recorded in the history.
    assert len(result.chat_history) == 6


if __name__ == "__main__":
    # Allow running these async tests directly (outside pytest); each call
    # gets its own fresh event loop via asyncio.run.
    asyncio.run(test_async_get_human_input())
    asyncio.run(test_async_max_turn())
14 changes: 9 additions & 5 deletions test/agentchat/test_chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def my_summary_method(recipient, sender):
{
"recipient": financial_assistant_2,
"message": financial_tasks[1],
"silent": True,
"silent": False,
"max_turns": 1,
"summary_method": "reflection_with_llm",
},
{
Expand Down Expand Up @@ -228,6 +229,7 @@ def my_summary_method(recipient, sender):
print(all_res[0].summary)
print(all_res[0].chat_history)
print(all_res[1].summary)
assert len(all_res[1].chat_history) <= 2
# print(blogpost.summary, insights_and_blogpost)


Expand Down Expand Up @@ -305,7 +307,8 @@ def my_summary_method(recipient, sender):
"sender": user_2,
"recipient": financial_assistant_2,
"message": financial_tasks[1],
"silent": True,
"silent": False,
"max_turns": 3,
"summary_method": "reflection_with_llm",
},
{
Expand All @@ -332,6 +335,7 @@ def my_summary_method(recipient, sender):
print(chat_res[0].summary)
print(chat_res[0].chat_history)
print(chat_res[1].summary)
assert len(chat_res[1].chat_history) <= 6
# print(blogpost.summary, insights_and_blogpost)


Expand Down Expand Up @@ -485,7 +489,7 @@ def currency_calculator(
if __name__ == "__main__":
    # Direct (non-pytest) entry point; commented-out calls are kept as a
    # menu of the other tests in this module for quick local runs.
    test_chats()
    test_chats_general()
    test_chats_exceptions()
    test_chats_group()
    test_chats_w_func()
    # test_chats_exceptions()
    # test_chats_group()
    # test_chats_w_func()
    # test_chat_messages_for_summary()
29 changes: 28 additions & 1 deletion test/agentchat/test_conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any, Callable, Dict, Literal
import unittest
import inspect
from unittest.mock import MagicMock

import pytest
from unittest.mock import patch
Expand Down Expand Up @@ -1028,10 +1029,36 @@ def stopwatch(num_seconds: Annotated[str, "Number of seconds in the stopwatch."]
stopwatch_mock.assert_called_once_with(num_seconds="5")


@pytest.mark.skipif(
    skip or not sys.version.startswith("3.10"),
    reason="do not run if openai is not installed or py!=3.10",
)
def test_max_turn():
    """Check that `max_turns=3` bounds a synchronous chat to at most six messages.

    The human side is mocked via MagicMock so initiate_chat never waits on
    stdin; the assistant's auto-reply cap (10) exceeds `max_turns`, so the
    turn limit is what stops the exchange.
    """
    oai_config_list = autogen.config_list_from_json(OAI_CONFIG_LIST, KEY_LOC)

    # LLM-backed assistant for the user proxy to talk to.
    assistant = autogen.AssistantAgent(
        name="assistant",
        max_consecutive_auto_reply=10,
        llm_config={"timeout": 600, "cache_seed": 41, "config_list": oai_config_list},
    )

    # Human proxy whose input hook is stubbed with a fixed canned reply.
    user_proxy = autogen.UserProxyAgent(name="user", human_input_mode="ALWAYS", code_execution_config=False)
    user_proxy.get_human_input = MagicMock(return_value="Not funny. Try again.")

    result = user_proxy.initiate_chat(assistant, clear_history=True, max_turns=3, message="Hello, make a joke about AI.")

    print("Result summary:", result.summary)
    print("Human input:", result.human_input)
    print("history", result.chat_history)
    # At most three (message, reply) pairs; the chat may also end early.
    assert len(result.chat_history) <= 6


if __name__ == "__main__":
    # Direct (non-pytest) entry point; commented-out calls form a menu of
    # the module's other tests for quick local runs.
    # test_trigger()
    # test_context()
    # test_max_consecutive_auto_reply()
    test_generate_code_execution_reply()
    # test_generate_code_execution_reply()
    # test_conversable_agent()
    # test_no_llm_config()
    test_max_turn()

0 comments on commit 18d7f07

Please sign in to comment.