From deb5f94ffd6430a680428e4767343f6a0d2832d6 Mon Sep 17 00:00:00 2001
From: Qingyun Wu
Date: Tue, 6 Feb 2024 20:32:27 -0500
Subject: [PATCH] Error handling in getting LLM-based summary (#1567)

* summary exception

* badrequest error

* test

* skip reason

* error
---
 autogen/agentchat/conversable_agent.py |  7 ++-
 test/agentchat/test_chats.py           | 69 +++++++++++++++++++++++++-
 2 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py
index 32ccf5ea5817..287476387068 100644
--- a/autogen/agentchat/conversable_agent.py
+++ b/autogen/agentchat/conversable_agent.py
@@ -8,6 +8,7 @@
 from collections import defaultdict
 from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
 import warnings
+from openai import BadRequestError
 
 from ..oai.client import OpenAIWrapper, ModelClient
 from ..cache.cache import Cache
@@ -832,8 +833,10 @@ def _summarize_chat(
             if not isinstance(prompt, str):
                 raise ValueError("The summary_prompt must be a string.")
             msg_list = agent._groupchat.messages if hasattr(agent, "_groupchat") else agent.chat_messages[self]
-
-            summary = self._llm_response_preparer(prompt, msg_list, llm_agent=agent, cache=cache)
+            try:
+                summary = self._llm_response_preparer(prompt, msg_list, llm_agent=agent, cache=cache)
+            except BadRequestError as e:
+                warnings.warn(f"Cannot extract summary using reflection_with_llm: {e}", UserWarning)
         else:
             warnings.warn("No summary_method provided or summary_method is not supported: ")
         return summary
diff --git a/test/agentchat/test_chats.py b/test/agentchat/test_chats.py
index b6fd39c695e9..36de8d120e35 100644
--- a/test/agentchat/test_chats.py
+++ b/test/agentchat/test_chats.py
@@ -4,6 +4,10 @@
 import pytest
 from conftest import skip_openai
 import autogen
+from typing import Literal
+
+from pydantic import BaseModel, Field
+from typing_extensions import Annotated
 
 
 @pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
@@ -127,6 +131,7 @@ def test_chats():
 
     financial_tasks = [
         """What are the full names of NVDA and TESLA.""",
+        """Investigate the reasons.""",
         """Pros and cons of the companies I'm interested in. Keep it short.""",
     ]
 
@@ -197,6 +202,68 @@ def test_chats():
     # print(blogpost.summary, insights_and_blogpost)
 
 
+@pytest.mark.skipif(skip_openai, reason="requested to skip openai tests")
+def test_chats_w_func():
+    config_list = autogen.config_list_from_json(
+        OAI_CONFIG_LIST,
+        file_location=KEY_LOC,
+    )
+
+    llm_config = {
+        "config_list": config_list,
+        "timeout": 120,
+    }
+
+    chatbot = autogen.AssistantAgent(
+        name="chatbot",
+        system_message="For currency exchange tasks, only use the functions you have been provided with. Reply TERMINATE when the task is done.",
+        llm_config=llm_config,
+    )
+
+    # create a UserProxyAgent instance named "user_proxy"
+    user_proxy = autogen.UserProxyAgent(
+        name="user_proxy",
+        is_termination_msg=lambda x: x.get("content", "") and x.get("content", "").rstrip().endswith("TERMINATE"),
+        human_input_mode="NEVER",
+        max_consecutive_auto_reply=10,
+        code_execution_config={
+            "last_n_messages": 1,
+            "work_dir": "tasks",
+            "use_docker": False,
+        },
+    )
+
+    CurrencySymbol = Literal["USD", "EUR"]
+
+    def exchange_rate(base_currency: CurrencySymbol, quote_currency: CurrencySymbol) -> float:
+        if base_currency == quote_currency:
+            return 1.0
+        elif base_currency == "USD" and quote_currency == "EUR":
+            return 1 / 1.1
+        elif base_currency == "EUR" and quote_currency == "USD":
+            return 1.1
+        else:
+            raise ValueError(f"Unknown currencies {base_currency}, {quote_currency}")
+
+    @user_proxy.register_for_execution()
+    @chatbot.register_for_llm(description="Currency exchange calculator.")
+    def currency_calculator(
+        base_amount: Annotated[float, "Amount of currency in base_currency"],
+        base_currency: Annotated[CurrencySymbol, "Base currency"] = "USD",
+        quote_currency: Annotated[CurrencySymbol, "Quote currency"] = "EUR",
+    ) -> str:
+        quote_amount = exchange_rate(base_currency, quote_currency) * base_amount
+        return f"{quote_amount} {quote_currency}"
+
+    res = user_proxy.initiate_chat(
+        chatbot,
+        message="How much is 123.45 USD in EUR?",
+        summary_method="reflection_with_llm",
+    )
+    print(res.summary, res.cost, res.chat_history)
+
+
 if __name__ == "__main__":
     # test_chats()
-    test_chats_group()
+    # test_chats_group()
+    test_chats_w_func()
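
For context, the behavior this patch introduces can be observed from a caller as in the minimal sketch below. The sketch is not part of the patch: it assumes a valid OAI_CONFIG_LIST file and API access, and the agent setup loosely mirrors the test above. If the reflection_with_llm summary call raises openai.BadRequestError (for example, when the accumulated messages exceed the model's context limit), initiate_chat now completes and emits a UserWarning instead of propagating the error.

import warnings

import autogen

# Assumption: a valid OAI_CONFIG_LIST file is available in the working directory.
config_list = autogen.config_list_from_json("OAI_CONFIG_LIST")

assistant = autogen.AssistantAgent(name="assistant", llm_config={"config_list": config_list})
user_proxy = autogen.UserProxyAgent(
    name="user_proxy",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=2,
    code_execution_config=False,
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    res = user_proxy.initiate_chat(
        assistant,
        message="How much is 123.45 USD in EUR?",
        summary_method="reflection_with_llm",
    )

# With this patch, a BadRequestError raised while preparing the summary no longer
# aborts the chat: initiate_chat returns its result and the failure shows up here
# as a recorded UserWarning ("Cannot extract summary using reflection_with_llm: ...").
for w in caught:
    if issubclass(w.category, UserWarning):
        print("warning:", w.message)
print("summary:", res.summary)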