Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async version of multiple sequential chat #1724

Merged
merged 8 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 145 additions & 33 deletions autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from typing import Dict, List, Any
from collections import defaultdict
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import warnings
Expand All @@ -13,12 +15,15 @@ def colored(x, *args, **kwargs):


logger = logging.getLogger(__name__)
Prerequisite = Tuple[int, int]


@dataclass
class ChatResult:
"""(Experimental) The result of a chat. Almost certain to be changed."""

chat_id: int = None
"""chat id"""
chat_history: List[Dict[str, any]] = None
"""The chat history."""
summary: str = None
Expand All @@ -29,6 +34,103 @@ class ChatResult:
"""A list of human input solicited during the chat."""


def _validate_recipients(chat_queue: List[Dict[str, Any]]) -> None:
"""
Validate recipients exits and warn repetitive recipients.
"""
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)


def __create_async_prerequisites(chat_queue: List[Dict[str, Any]]) -> List[Prerequisite]:
"""
Create list of Prerequisite (prerequisite_chat_id, chat_id)
"""
prerequisites = []
for chat_info in chat_queue:
if "chat_id" not in chat_info:
raise ValueError("Each chat must have a unique id for async multi-chat execution.")
chat_id = chat_info["chat_id"]
pre_chats = chat_info.get("prerequisites", [])
for pre_chat_id in pre_chats:
if not isinstance(pre_chat_id, int):
raise ValueError("Prerequisite chat id is not int.")
prerequisites.append((chat_id, pre_chat_id))
return prerequisites


def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite]) -> List[int]:
"""Find chat order for async execution based on the prerequisite chats
args:
num_chats: number of chats
prerequisites: List of Prerequisite (prerequisite_chat_id, chat_id)
returns:
list: a list of chat_id in order.
"""
edges = defaultdict(set)
indegree = defaultdict(int)
for pair in prerequisites:
chat, pre = pair[0], pair[1]
if chat not in edges[pre]:
indegree[chat] += 1
edges[pre].add(chat)
bfs = [i for i in chat_ids if i not in indegree]
chat_order = []
steps = len(indegree)
for _ in range(steps + 1):
if not bfs:
break
chat_order.extend(bfs)
nxt = []
for node in bfs:
if node in edges:
for course in edges[node]:
indegree[course] -= 1
if indegree[course] == 0:
nxt.append(course)
indegree.pop(course)
edges.pop(node)
bfs = nxt

if indegree:
return []
return chat_order


def __post_carryover_processing(chat_info: Dict[str, Any]):
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")


def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
Expand Down Expand Up @@ -71,15 +173,7 @@ def my_summary_method(
"""

consolidate_chat_info(chat_queue)
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)
_validate_recipients(chat_queue)
current_chat_queue = chat_queue.copy()
finished_chats = []
while current_chat_queue:
Expand All @@ -88,30 +182,48 @@ def my_summary_method(
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats]
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
chat_info["recipient"]
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
finished_chats.append(chat_res)
return finished_chats


async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
"""(async) Initiate a list of chats.
args:
Please refer to `initiate_chats`.
returns:
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""

consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
num_chats = chat_book.keys()
prerequisites = __create_async_prerequisites(chat_queue)
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
finished_chats = dict()
for chat_id in chat_order_by_id:
chat_info = chat_book[chat_id]
condition = asyncio.Condition()
prerequisite_chat_ids = chat_info.get("prerequisites", [])
async with condition:
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
# Do the actual work here.
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = await sender.a_initiate_chat(**chat_info)
chat_res.chat_id = chat_id
finished_chats[chat_id] = chat_res

return finished_chats
11 changes: 9 additions & 2 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
infer_lang,
)
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats
from .chat import ChatResult, initiate_chats, a_initiate_chats


from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
Expand Down Expand Up @@ -985,6 +985,13 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats

async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
_chat_queue = chat_queue.copy()
for chat_info in _chat_queue:
chat_info["sender"] = self
self._finished_chats = await a_initiate_chats(_chat_queue)
return self._finished_chats

def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]:
"""A summary from the finished chats of particular agents."""
if chat_index is not None:
Expand Down Expand Up @@ -1766,7 +1773,7 @@ async def a_get_human_input(self, prompt: str) -> str:
str: human input.
"""
reply = input(prompt)
self._human_inputs.append(reply)
self._human_input.append(reply)
return reply

def run_code(self, code, **kwargs):
Expand Down
Loading
Loading