Skip to content

Commit

Permalink
Async version of multiple sequential chat (microsoft#1724)
Browse files Browse the repository at this point in the history
* async_initiate_chats init commit

* Fix a_get_human_input bug

* Add agentchat_multi_task_async_chats.ipynb with concurrent exampls.

* Addess the comments, Update unit test

* Add agentchat_multi_task_async_chats.ipynb to Examples.md

* Fix type for Python 3.8

---------

Co-authored-by: Qingyun Wu <[email protected]>
  • Loading branch information
randombet and qingyun-wu authored Feb 21, 2024
1 parent 7c9f099 commit 41fefbb
Show file tree
Hide file tree
Showing 5 changed files with 2,152 additions and 35 deletions.
178 changes: 145 additions & 33 deletions autogen/agentchat/chat.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import logging
from typing import Dict, List, Any
from collections import defaultdict
from typing import Dict, List, Any, Set, Tuple
from dataclasses import dataclass
from .utils import consolidate_chat_info
import warnings
Expand All @@ -13,12 +15,15 @@ def colored(x, *args, **kwargs):


logger = logging.getLogger(__name__)
Prerequisite = Tuple[int, int]


@dataclass
class ChatResult:
"""(Experimental) The result of a chat. Almost certain to be changed."""

chat_id: int = None
"""chat id"""
chat_history: List[Dict[str, any]] = None
"""The chat history."""
summary: str = None
Expand All @@ -29,6 +34,103 @@ class ChatResult:
"""A list of human input solicited during the chat."""


def _validate_recipients(chat_queue: List[Dict[str, Any]]) -> None:
"""
Validate recipients exits and warn repetitive recipients.
"""
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)


def __create_async_prerequisites(chat_queue: List[Dict[str, Any]]) -> List[Prerequisite]:
"""
Create list of Prerequisite (prerequisite_chat_id, chat_id)
"""
prerequisites = []
for chat_info in chat_queue:
if "chat_id" not in chat_info:
raise ValueError("Each chat must have a unique id for async multi-chat execution.")
chat_id = chat_info["chat_id"]
pre_chats = chat_info.get("prerequisites", [])
for pre_chat_id in pre_chats:
if not isinstance(pre_chat_id, int):
raise ValueError("Prerequisite chat id is not int.")
prerequisites.append((chat_id, pre_chat_id))
return prerequisites


def __find_async_chat_order(chat_ids: Set[int], prerequisites: List[Prerequisite]) -> List[int]:
"""Find chat order for async execution based on the prerequisite chats
args:
num_chats: number of chats
prerequisites: List of Prerequisite (prerequisite_chat_id, chat_id)
returns:
list: a list of chat_id in order.
"""
edges = defaultdict(set)
indegree = defaultdict(int)
for pair in prerequisites:
chat, pre = pair[0], pair[1]
if chat not in edges[pre]:
indegree[chat] += 1
edges[pre].add(chat)
bfs = [i for i in chat_ids if i not in indegree]
chat_order = []
steps = len(indegree)
for _ in range(steps + 1):
if not bfs:
break
chat_order.extend(bfs)
nxt = []
for node in bfs:
if node in edges:
for course in edges[node]:
indegree[course] -= 1
if indegree[course] == 0:
nxt.append(course)
indegree.pop(course)
edges.pop(node)
bfs = nxt

if indegree:
return []
return chat_order


def __post_carryover_processing(chat_info: Dict[str, Any]):
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")


def initiate_chats(chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
"""Initiate a list of chats.
Expand Down Expand Up @@ -71,15 +173,7 @@ def my_summary_method(
"""

consolidate_chat_info(chat_queue)
receipts_set = set()
for chat_info in chat_queue:
assert "recipient" in chat_info, "recipient must be provided."
receipts_set.add(chat_info["recipient"])
if len(receipts_set) < len(chat_queue):
warnings.warn(
"Repetitive recipients detected: The chat history will be cleared by default if a recipient appears more than once. To retain the chat history, please set 'clear_history=False' in the configuration of the repeating agent.",
UserWarning,
)
_validate_recipients(chat_queue)
current_chat_queue = chat_queue.copy()
finished_chats = []
while current_chat_queue:
Expand All @@ -88,30 +182,48 @@ def my_summary_method(
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [r.summary for r in finished_chats]
if "message" not in chat_info:
warnings.warn(
"message is not provided in a chat_queue entry. input() will be called to get the initial message.",
UserWarning,
)
chat_info["recipient"]
print_carryover = (
("\n").join([t for t in chat_info["carryover"]])
if isinstance(chat_info["carryover"], list)
else chat_info["carryover"]
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
print(
colored(
"Start a new chat with the following message: \n"
+ chat_info.get("message")
+ "\n\nWith the following carryover: \n"
+ print_carryover,
"blue",
),
flush=True,
)
print(colored("\n" + "*" * 80, "blue"), flush=True, sep="")
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = sender.initiate_chat(**chat_info)
finished_chats.append(chat_res)
return finished_chats


async def a_initiate_chats(chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
"""(async) Initiate a list of chats.
args:
Please refer to `initiate_chats`.
returns:
(Dict): a dict of ChatId: ChatResult corresponding to the finished chats in the chat_queue.
"""

consolidate_chat_info(chat_queue)
_validate_recipients(chat_queue)
chat_book = {chat_info["chat_id"]: chat_info for chat_info in chat_queue}
num_chats = chat_book.keys()
prerequisites = __create_async_prerequisites(chat_queue)
chat_order_by_id = __find_async_chat_order(num_chats, prerequisites)
finished_chats = dict()
for chat_id in chat_order_by_id:
chat_info = chat_book[chat_id]
condition = asyncio.Condition()
prerequisite_chat_ids = chat_info.get("prerequisites", [])
async with condition:
await condition.wait_for(lambda: all([id in finished_chats for id in prerequisite_chat_ids]))
# Do the actual work here.
_chat_carryover = chat_info.get("carryover", [])
if isinstance(_chat_carryover, str):
_chat_carryover = [_chat_carryover]
chat_info["carryover"] = _chat_carryover + [
finished_chats[pre_id].summary for pre_id in prerequisite_chat_ids
]
__post_carryover_processing(chat_info)
sender = chat_info["sender"]
chat_res = await sender.a_initiate_chat(**chat_info)
chat_res.chat_id = chat_id
finished_chats[chat_id] = chat_res

return finished_chats
11 changes: 9 additions & 2 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
infer_lang,
)
from .utils import gather_usage_summary, consolidate_chat_info
from .chat import ChatResult, initiate_chats
from .chat import ChatResult, initiate_chats, a_initiate_chats


from ..function_utils import get_function_schema, load_basemodels_if_needed, serialize_to_str
Expand Down Expand Up @@ -985,6 +985,13 @@ def initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> List[ChatResult]:
self._finished_chats = initiate_chats(_chat_queue)
return self._finished_chats

async def a_initiate_chats(self, chat_queue: List[Dict[str, Any]]) -> Dict[int, ChatResult]:
_chat_queue = chat_queue.copy()
for chat_info in _chat_queue:
chat_info["sender"] = self
self._finished_chats = await a_initiate_chats(_chat_queue)
return self._finished_chats

def get_chat_results(self, chat_index: Optional[int] = None) -> Union[List[ChatResult], ChatResult]:
"""A summary from the finished chats of particular agents."""
if chat_index is not None:
Expand Down Expand Up @@ -1766,7 +1773,7 @@ async def a_get_human_input(self, prompt: str) -> str:
str: human input.
"""
reply = input(prompt)
self._human_inputs.append(reply)
self._human_input.append(reply)
return reply

def run_code(self, code, **kwargs):
Expand Down
Loading

0 comments on commit 41fefbb

Please sign in to comment.