From 89946158eccac86006d737c56eba53e00a4bea05 Mon Sep 17 00:00:00 2001 From: ShobhitVishnoi30 Date: Wed, 17 Apr 2024 09:15:22 +0530 Subject: [PATCH] initate chats enhancement --- autogen/agentchat/chat.py | 12 +++++++++++- autogen/agentchat/conversable_agent.py | 26 ++++++++++++++++++++------ 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/autogen/agentchat/chat.py b/autogen/agentchat/chat.py index 10ad8014ed95..b527f8e0bae5 100644 --- a/autogen/agentchat/chat.py +++ b/autogen/agentchat/chat.py @@ -171,6 +171,9 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: - `"carryover"` - It can be used to specify the carryover information to be passed to this chat. If provided, we will combine this carryover with the "message" content when generating the initial chat message in `generate_init_message`. + - `"finished_chat_indexes_to_exclude_from_carryover"` - It can be used by specifying a list of indexes of the finished_chats list, + from which to exclude the summaries for carryover. If 'finished_chat_indexes_to_exclude_from_carryover' is not provided or an empty list, + then summary from all the finished chats will be taken. Returns: (list): a list of ChatResult objects corresponding to the finished chats in the chat_queue. """ @@ -182,9 +185,16 @@ def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: while current_chat_queue: chat_info = current_chat_queue.pop(0) _chat_carryover = chat_info.get("carryover", []) + finished_chat_indexes_to_exclude_from_carryover = chat_info.get( + "finished_chat_indexes_to_exclude_from_carryover", [] + ) + if isinstance(_chat_carryover, str): _chat_carryover = [_chat_carryover] - chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats] + chat_info["carryover"] = _chat_carryover + [ + r.summary for i, r in enumerate(finished_chats) if i not in finished_chat_indexes_to_exclude_from_carryover + ] + __post_carryover_processing(chat_info) sender = chat_info["sender"] chat_res = sender.initiate_chat(**chat_info) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 3ebb2d723367..8e66b58a9bb6 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1180,6 +1180,23 @@ def _reflection_with_llm( response = self._generate_oai_reply_from_client(llm_client=llm_client, messages=messages, cache=cache) return response + def _check_chat_queue_for_sender(self, chat_queue: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Check the chat queue and add the "sender" key if it's missing. + + Args: + chat_queue (List[Dict[str, Any]]): A list of dictionaries containing chat information. + + Returns: + List[Dict[str, Any]]: A new list of dictionaries with the "sender" key added if it was missing. + """ + chat_queue_with_sender = [] + for chat_info in chat_queue: + if chat_info.get("sender") is None: + chat_info["sender"] = self + chat_queue_with_sender.append(chat_info) + return chat_queue_with_sender + def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: """(Experimental) Initiate chats with multiple agents. @@ -1189,16 +1206,13 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]: Returns: a list of ChatResult objects corresponding to the finished chats in the chat_queue. """ - _chat_queue = chat_queue.copy() - for chat_info in _chat_queue: - chat_info["sender"] = self + _chat_queue = self._check_chat_queue_for_sender(chat_queue) self._finished_chats = initiate_chats(_chat_queue) return self._finished_chats async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]: - _chat_queue = chat_queue.copy() - for chat_info in _chat_queue: - chat_info["sender"] = self + + _chat_queue = self._check_chat_queue_for_sender(chat_queue) self._finished_chats = await a_initiate_chats(_chat_queue) return self._finished_chats