|
6 | 6 | import logging
|
7 | 7 | import re
|
8 | 8 | from collections import defaultdict
|
| 9 | +from functools import partial |
9 | 10 | from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union
|
10 | 11 | import warnings
|
11 | 12 | from openai import BadRequestError
|
@@ -325,6 +326,80 @@ def reply_func(
|
325 | 326 | if ignore_async_in_sync_chat and inspect.iscoroutinefunction(reply_func):
|
326 | 327 | self._ignore_async_func_in_sync_chat_list.append(reply_func)
|
327 | 328 |
|
| 329 | + @staticmethod |
| 330 | + def _summary_from_nested_chats( |
| 331 | + chat_queue: List[Dict[str, Any]], recipient: Agent, messages: Union[str, Callable], sender: Agent, config: Any |
| 332 | + ) -> Tuple[bool, str]: |
| 333 | + """A simple chat reply function. |
| 334 | + This function initiate one or a sequence of chats between the "recipient" and the agents in the |
| 335 | + chat_queue. |
| 336 | +
|
| 337 | + It extracts and returns a summary from the nested chat based on the "summary_method" in each chat in chat_queue. |
| 338 | +
|
| 339 | + Returns: |
| 340 | + Tuple[bool, str]: A tuple where the first element indicates the completion of the chat, and the second element contains the summary of the last chat if any chats were initiated. |
| 341 | + """ |
| 342 | + last_msg = messages[-1].get("content") |
| 343 | + chat_to_run = [] |
| 344 | + for i, c in enumerate(chat_queue): |
| 345 | + current_c = c.copy() |
| 346 | + message = current_c.get("message") |
| 347 | + # If message is not provided in chat_queue, we by default use the last message from the original chat history as the first message in this nested chat (for the first chat in the chat queue). |
| 348 | + # NOTE: This setting is prone to change. |
| 349 | + if message is None and i == 0: |
| 350 | + message = last_msg |
| 351 | + if callable(message): |
| 352 | + message = message(recipient, messages, sender, config) |
| 353 | + # We only run chat that has a valid message. NOTE: This is prone to change dependin on applications. |
| 354 | + if message: |
| 355 | + current_c["message"] = message |
| 356 | + chat_to_run.append(current_c) |
| 357 | + if not chat_to_run: |
| 358 | + return True, None |
| 359 | + res = recipient.initiate_chats(chat_to_run) |
| 360 | + return True, res[-1].summary |
| 361 | + |
| 362 | + def register_nested_chats( |
| 363 | + self, |
| 364 | + chat_queue: List[Dict[str, Any]], |
| 365 | + trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List] = [Agent, None], |
| 366 | + reply_func_from_nested_chats: Union[str, Callable] = "summary_from_nested_chats", |
| 367 | + position: int = 2, |
| 368 | + **kwargs, |
| 369 | + ) -> None: |
| 370 | + """Register a nested chat reply function. |
| 371 | + Args: |
| 372 | + chat_queue (list): a list of chat objects to be initiated. |
| 373 | + trigger (Agent class, str, Agent instance, callable, or list): Default to [Agent, None]. Ref to `register_reply` for details. |
| 374 | + reply_func_from_nested_chats (Callable, str): the reply function for the nested chat. |
| 375 | + The function takes a chat_queue for nested chat, recipient agent, a list of messages, a sender agent and a config as input and returns a reply message. |
| 376 | + Default to "summary_from_nested_chats", which corresponds to a built-in reply function that get summary from the nested chat_queue. |
| 377 | + ```python |
| 378 | + def reply_func_from_nested_chats( |
| 379 | + chat_queue: List[Dict], |
| 380 | + recipient: ConversableAgent, |
| 381 | + messages: Optional[List[Dict]] = None, |
| 382 | + sender: Optional[Agent] = None, |
| 383 | + config: Optional[Any] = None, |
| 384 | + ) -> Tuple[bool, Union[str, Dict, None]]: |
| 385 | + ``` |
| 386 | + position (int): Ref to `register_reply` for details. Default to 2. It means we first check the termination and human reply, then check the registered nested chat reply. |
| 387 | + kwargs: Ref to `register_reply` for details. |
| 388 | + """ |
| 389 | + if reply_func_from_nested_chats == "summary_from_nested_chats": |
| 390 | + reply_func_from_nested_chats = self._summary_from_nested_chats |
| 391 | + if not callable(reply_func_from_nested_chats): |
| 392 | + raise ValueError("reply_func_from_nested_chats must be a callable") |
| 393 | + reply_func = partial(reply_func_from_nested_chats, chat_queue) |
| 394 | + self.register_reply( |
| 395 | + trigger, |
| 396 | + reply_func, |
| 397 | + position, |
| 398 | + kwargs.get("config"), |
| 399 | + kwargs.get("reset_config"), |
| 400 | + ignore_async_in_sync_chat=kwargs.get("ignore_async_in_sync_chat"), |
| 401 | + ) |
| 402 | + |
328 | 403 | @property
|
329 | 404 | def system_message(self) -> str:
|
330 | 405 | """Return the system message."""
|
@@ -477,7 +552,7 @@ def _process_message_before_send(
|
477 | 552 | """Process the message before sending it to the recipient."""
|
478 | 553 | hook_list = self.hook_lists["process_message_before_send"]
|
479 | 554 | for hook in hook_list:
|
480 |
| - message = hook(message, recipient, silent) |
| 555 | + message = hook(sender=self, message=message, recipient=recipient, silent=silent) |
481 | 556 | return message
|
482 | 557 |
|
483 | 558 | def send(
|
@@ -2054,15 +2129,18 @@ async def a_generate_init_message(self, **context) -> Union[str, Dict]:
|
2054 | 2129 | self._process_carryover(context)
|
2055 | 2130 | return context["message"]
|
2056 | 2131 |
|
2057 |
| - def register_function(self, function_map: Dict[str, Callable]): |
| 2132 | + def register_function(self, function_map: Dict[str, Union[Callable, None]]): |
2058 | 2133 | """Register functions to the agent.
|
2059 | 2134 |
|
2060 | 2135 | Args:
|
2061 |
| - function_map: a dictionary mapping function names to functions. |
| 2136 | + function_map: a dictionary mapping function names to functions. if function_map[name] is None, the function will be removed from the function_map. |
2062 | 2137 | """
|
2063 |
| - for name in function_map.keys(): |
| 2138 | + for name, func in function_map.items(): |
2064 | 2139 | self._assert_valid_name(name)
|
| 2140 | + if func is None and name not in self._function_map.keys(): |
| 2141 | + warnings.warn(f"The function {name} to remove doesn't exist", name) |
2065 | 2142 | self._function_map.update(function_map)
|
| 2143 | + self._function_map = {k: v for k, v in self._function_map.items() if v is not None} |
2066 | 2144 |
|
2067 | 2145 | def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None):
|
2068 | 2146 | """update a function_signature in the LLM configuration for function_call.
|
|
0 commit comments