Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using a more robust "reflection_with_llm" summary method #1575

Merged
merged 11 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 29 additions & 40 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class ConversableAgent(Agent):
DEFAULT_CONFIG = {} # An empty configuration
MAX_CONSECUTIVE_AUTO_REPLY = 100 # maximum number of consecutive auto replies (subject to future change)

DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases. If the intended request is NOT properly addressed, please point it out."
DEFAULT_summary_prompt = "Summarize the takeaway from the conversation. Do not add any introductory phrases."
llm_config: Union[Dict, Literal[False]]

def __init__(
Expand Down Expand Up @@ -822,66 +822,51 @@ def _summarize_chat(
"""
agent = self if agent is None else agent
summary = ""
if method == "last_msg":
try:
summary = agent.last_message(self)["content"]
summary = summary.replace("TERMINATE", "")
except (IndexError, AttributeError):
warnings.warn("Cannot extract summary from last message.", UserWarning)
elif method == "reflection_with_llm":
if method == "reflection_with_llm":
prompt = ConversableAgent.DEFAULT_summary_prompt if prompt is None else prompt
if not isinstance(prompt, str):
raise ValueError("The summary_prompt must be a string.")
msg_list = agent._groupchat.messages if hasattr(agent, "_groupchat") else agent.chat_messages[self]
sonichi marked this conversation as resolved.
Show resolved Hide resolved
try:
summary = self._llm_response_preparer(prompt, msg_list, llm_agent=agent, cache=cache)
summary = self._reflection_with_llm(prompt, msg_list, llm_agent=agent, cache=cache)
except BadRequestError as e:
warnings.warn(f"Cannot extract summary using reflection_with_llm: {e}", UserWarning)
elif method == "last_msg" or method is None:
try:
summary = agent.last_message(self)["content"].replace("TERMINATE", "")
except (IndexError, AttributeError) as e:
warnings.warn(f"Cannot extract summary using last_msg: {e}", UserWarning)
else:
warnings.warn("No summary_method provided or summary_method is not supported: ")
warnings.warn(f"Unsupported summary method: {method}", UserWarning)
return summary

def _reflection_with_llm(
    self, prompt, messages, llm_agent: Optional[Agent] = None, cache: Optional[Cache] = None
) -> str:
    """Get a chat summary using reflection with an llm client based on the conversation history.

    Args:
        prompt (str): The prompt (in this method it is used as a system prompt) used to get the summary.
        messages (list): The messages generated as part of a chat conversation.
        llm_agent (Agent or None): The agent with an llm client; falls back to ``self`` when None
            or when the agent has no client.
        cache (Cache or None): The cache client to be used for this conversation.

    Returns:
        str: The summary extracted from the llm response.

    Raises:
        ValueError: If neither ``llm_agent`` nor ``self`` has an OpenAIWrapper client.
    """
    # Append the summary request as a trailing system message so the model
    # reflects on the full transcript. Use concatenation (not append) to
    # avoid mutating the caller's message list.
    system_msg = [
        {
            "role": "system",
            "content": prompt,
        }
    ]
    messages = messages + system_msg

    # Prefer the supplied agent's client; fall back to our own client.
    if llm_agent and llm_agent.client is not None:
        llm_client = llm_agent.client
    elif self.client is not None:
        llm_client = self.client
    else:
        raise ValueError("No OpenAIWrapper client is found.")

    # Delegate the actual completion + text extraction to the shared helper.
    response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache)
    return response

def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[Agent, ChatResult]:
"""(Experimental) Initiate chats with multiple agents.
Expand Down Expand Up @@ -1021,7 +1006,12 @@ def generate_oai_reply(
return False, None
if messages is None:
messages = self._oai_messages[sender]
extracted_response = self._generate_oai_reply_from_client(
client, self._oai_system_message + messages, self.client_cache
)
return True, extracted_response

def _generate_oai_reply_from_client(self, llm_client, messages, cache):
# unroll tool_responses
all_messages = []
for message in messages:
Expand All @@ -1035,13 +1025,12 @@ def generate_oai_reply(
all_messages.append(message)

# TODO: #1143 handle token limit exceeded error
response = client.create(
response = llm_client.create(
context=messages[-1].pop("context", None),
messages=self._oai_system_message + all_messages,
cache=self.client_cache,
messages=all_messages,
cache=cache,
)

extracted_response = client.extract_text_or_completion_object(response)[0]
extracted_response = llm_client.extract_text_or_completion_object(response)[0]

if extracted_response is None:
warnings.warn("Extracted_response is None.", UserWarning)
Expand All @@ -1056,7 +1045,7 @@ def generate_oai_reply(
)
for tool_call in extracted_response.get("tool_calls") or []:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
return True, extracted_response
return extracted_response

async def a_generate_oai_reply(
self,
Expand Down
141 changes: 90 additions & 51 deletions notebook/agentchat_auto_feedback_from_code_execution.ipynb

Large diffs are not rendered by default.

90 changes: 67 additions & 23 deletions notebook/agentchat_function_call_currency_calculator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"id": "dca301a4",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -122,7 +122,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 2,
"id": "9fb85afb",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -249,7 +249,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"id": "d5518947",
"metadata": {},
"outputs": [
Expand All @@ -264,9 +264,9 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_ubo7cKE3TKumGHkqGjQtZisy): currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_Ak49uR4cwLWyPKs5T2gK9bMg): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base_amount\":123.45,\"base_currency\":\"USD\",\"quote_currency\":\"EUR\"}\n",
"{\"base_amount\":123.45}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
Expand All @@ -276,7 +276,7 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_ubo7cKE3TKumGHkqGjQtZisy\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_Ak49uR4cwLWyPKs5T2gK9bMg\" *****\u001b[0m\n",
"112.22727272727272 EUR\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
Expand All @@ -302,12 +302,29 @@
"source": [
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 USD in EUR?\",\n",
" res = user_proxy.initiate_chat(\n",
" chatbot, message=\"How much is 123.45 USD in EUR?\", summary_method=\"reflection_with_llm\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4b5a0edc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat summary: 123.45 USD is equivalent to approximately 112.23 EUR.\n"
]
}
],
"source": [
"print(\"Chat summary:\", res.summary)"
]
},
{
"cell_type": "markdown",
"id": "bd9d61cf",
Expand All @@ -326,7 +343,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"id": "7b3d8b58",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -432,7 +449,7 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_0VuU2rATuOgYrGmcBnXzPXlh): currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_G64JQKQBT2rI4vnuA4iz1vmE): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base\":{\"currency\":\"EUR\",\"amount\":112.23},\"quote_currency\":\"USD\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
Expand All @@ -444,14 +461,14 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_0VuU2rATuOgYrGmcBnXzPXlh\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_G64JQKQBT2rI4vnuA4iz1vmE\" *****\u001b[0m\n",
"{\"currency\":\"USD\",\"amount\":123.45300000000002}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"112.23 Euros is approximately 123.45 US Dollars.\n",
"112.23 Euros is equivalent to approximately 123.45 US Dollars.\n",
"\n",
"--------------------------------------------------------------------------------\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
Expand All @@ -470,15 +487,32 @@
"source": [
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 112.23 Euros in US Dollars?\",\n",
" res = user_proxy.initiate_chat(\n",
" chatbot, message=\"How much is 112.23 Euros in US Dollars?\", summary_method=\"reflection_with_llm\"\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "4799f60c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat summary: 112.23 Euros is approximately 123.45 US Dollars.\n"
]
}
],
"source": [
"print(\"Chat summary:\", res.summary)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "0064d9cd",
"metadata": {},
"outputs": [
Expand All @@ -493,7 +527,7 @@
"--------------------------------------------------------------------------------\n",
"\u001b[33mchatbot\u001b[0m (to user_proxy):\n",
"\n",
"\u001b[32m***** Suggested tool Call (call_A6lqMu7s5SyDvftTSeQTtPcj): currency_calculator *****\u001b[0m\n",
"\u001b[32m***** Suggested tool Call (call_qv2SwJHpKrG73btxNzUnYBoR): currency_calculator *****\u001b[0m\n",
"Arguments: \n",
"{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}\n",
"\u001b[32m************************************************************************************\u001b[0m\n",
Expand All @@ -505,7 +539,7 @@
"\n",
"\u001b[33muser_proxy\u001b[0m (to chatbot):\n",
"\n",
"\u001b[32m***** Response from calling tool \"call_A6lqMu7s5SyDvftTSeQTtPcj\" *****\u001b[0m\n",
"\u001b[32m***** Response from calling tool \"call_qv2SwJHpKrG73btxNzUnYBoR\" *****\u001b[0m\n",
"{\"currency\":\"EUR\",\"amount\":112.22727272727272}\n",
"\u001b[32m**********************************************************************\u001b[0m\n",
"\n",
Expand All @@ -531,19 +565,29 @@
"source": [
"with Cache.disk():\n",
" # start the conversation\n",
" user_proxy.initiate_chat(\n",
" res = user_proxy.initiate_chat(\n",
" chatbot,\n",
" message=\"How much is 123.45 US Dollars in Euros?\",\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "06137f23",
"execution_count": 15,
"id": "80b2b42c",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Chat history: [{'content': 'How much is 123.45 US Dollars in Euros?', 'role': 'assistant'}, {'tool_calls': [{'id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'function': {'arguments': '{\"base\":{\"currency\":\"USD\",\"amount\":123.45},\"quote_currency\":\"EUR\"}', 'name': 'currency_calculator'}, 'type': 'function'}], 'content': None, 'role': 'assistant'}, {'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}', 'tool_responses': [{'tool_call_id': 'call_qv2SwJHpKrG73btxNzUnYBoR', 'role': 'tool', 'content': '{\"currency\":\"EUR\",\"amount\":112.22727272727272}'}], 'role': 'tool'}, {'content': '123.45 US Dollars is approximately 112.23 Euros.', 'role': 'user'}, {'content': '', 'role': 'assistant'}, {'content': 'TERMINATE', 'role': 'user'}]\n"
]
}
],
"source": [
"print(\"Chat history:\", res.chat_history)"
]
}
],
"metadata": {
Expand Down
Loading
Loading