Skip to content

Commit

Permalink
General Enhancements in agentchat 2.0 (microsoft#1906)
Browse files Browse the repository at this point in the history
* work in progress

* wip

* groupchat type hints

* clean up

* formatting

* formatting

* clean up

* address comments

* better comment

* updates docstring a_send

* resolve comments

* agent.py back to original format

* resolve more comments

* rename carryover type exception

* revert next_agent changes + keeping UndefinedNextagent

* fixed ciruclar dependencies?

* fix cache tests

---------

Co-authored-by: Eric Zhu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
3 people authored Mar 9, 2024
1 parent b1e00f0 commit 139618b
Show file tree
Hide file tree
Showing 9 changed files with 201 additions and 144 deletions.
38 changes: 22 additions & 16 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import warnings
from openai import BadRequestError

from autogen.exception_utils import InvalidCarryOverType, SenderRequired

from ..coding.base import CodeExecutor
from ..coding.factory import CodeExecutorFactory

Expand Down Expand Up @@ -77,7 +79,7 @@ def __init__(
system_message: Optional[Union[str, List]] = "You are a helpful AI Assistant.",
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "TERMINATE",
human_input_mode: Literal["ALWAYS", "NEVER", "TERMINATE"] = "TERMINATE",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Union[Dict, Literal[False]] = False,
llm_config: Optional[Union[Dict, Literal[False]]] = None,
Expand Down Expand Up @@ -576,7 +578,7 @@ def send(
recipient: Agent,
request_reply: Optional[bool] = None,
silent: Optional[bool] = False,
) -> ChatResult:
):
"""Send a message to another agent.
Args:
Expand Down Expand Up @@ -608,9 +610,6 @@ def send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
Returns:
ChatResult: a ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
Expand All @@ -629,7 +628,7 @@ async def a_send(
recipient: Agent,
request_reply: Optional[bool] = None,
silent: Optional[bool] = False,
) -> ChatResult:
):
"""(async) Send a message to another agent.
Args:
Expand Down Expand Up @@ -661,9 +660,6 @@ async def a_send(
Raises:
ValueError: if the message can't be converted into a valid ChatCompletion message.
Returns:
ChatResult: an ChatResult object.
"""
message = self._process_message_before_send(message, recipient, silent)
# When the agent composes and sends the message, the role of the message is "assistant"
Expand Down Expand Up @@ -857,7 +853,7 @@ def _raise_exception_on_async_reply_functions(self) -> None:
def initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
clear_history: bool = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
max_turns: Optional[int] = None,
Expand Down Expand Up @@ -946,7 +942,7 @@ def my_summary_method(
async def a_initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
clear_history: bool = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
max_turns: Optional[int] = None,
Expand Down Expand Up @@ -1524,8 +1520,6 @@ def check_termination_and_human_reply(
- Tuple[bool, Union[str, Dict, None]]: A tuple containing a boolean indicating if the conversation
should be terminated, and a human reply which can be a string, a dictionary, or None.
"""
# Function implementation...

if config is None:
config = self
if messages is None:
Expand Down Expand Up @@ -1839,6 +1833,7 @@ async def a_generate_reply(
reply_func = reply_func_tuple["reply_func"]
if "exclude" in kwargs and reply_func in kwargs["exclude"]:
continue

if self._match_trigger(reply_func_tuple["trigger"], sender):
if inspect.iscoroutinefunction(reply_func):
final, reply = await reply_func(
Expand All @@ -1850,7 +1845,7 @@ async def a_generate_reply(
return reply
return self._default_auto_reply

def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Agent) -> bool:
def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List], sender: Optional[Agent]) -> bool:
"""Check if the sender matches the trigger.
Args:
Expand All @@ -1867,6 +1862,8 @@ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List],
if trigger is None:
return sender is None
elif isinstance(trigger, str):
if sender is None:
raise SenderRequired()
return trigger == sender.name
elif isinstance(trigger, type):
return isinstance(sender, trigger)
Expand All @@ -1875,7 +1872,7 @@ def _match_trigger(self, trigger: Union[None, str, type, Agent, Callable, List],
return trigger == sender
elif isinstance(trigger, Callable):
rst = trigger(sender)
assert rst in [True, False], f"trigger {trigger} must return a boolean value."
assert isinstance(rst, bool), f"trigger {trigger} must return a boolean value."
return rst
elif isinstance(trigger, list):
return any(self._match_trigger(t, sender) for t in trigger)
Expand Down Expand Up @@ -2154,7 +2151,7 @@ def _process_carryover(self, context):
elif isinstance(carryover, list):
context["message"] = context["message"] + "\nContext: \n" + ("\n").join([t for t in carryover])
else:
raise warnings.warn(
raise InvalidCarryOverType(
"Carryover should be a string or a list of strings. Not adding carryover to the message."
)

Expand Down Expand Up @@ -2212,6 +2209,11 @@ def update_function_signature(self, func_sig: Union[str, Dict], is_remove: None)
func for func in self.llm_config["functions"] if func["name"] != func_sig
]
else:
if not isinstance(func_sig, dict):
raise ValueError(
f"The function signature must be of the type dict. Received function signature type {type(func_sig)}"
)

self._assert_valid_name(func_sig["name"])
if "functions" in self.llm_config.keys():
self.llm_config["functions"] = [
Expand Down Expand Up @@ -2248,6 +2250,10 @@ def update_tool_signature(self, tool_sig: Union[str, Dict], is_remove: None):
tool for tool in self.llm_config["tools"] if tool["function"]["name"] != tool_sig
]
else:
if not isinstance(tool_sig, dict):
raise ValueError(
f"The tool signature must be of the type dict. Received tool signature type {type(tool_sig)}"
)
self._assert_valid_name(tool_sig["function"]["name"])
if "tools" in self.llm_config.keys():
self.llm_config["tools"] = [
Expand Down
39 changes: 17 additions & 22 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,19 @@
import re
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union, Tuple, Callable
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union

from autogen.agentchat.agent import Agent
from autogen.agentchat.conversable_agent import ConversableAgent

from ..code_utils import content_str
from ..exception_utils import AgentNameConflict
from .agent import Agent
from .conversable_agent import ConversableAgent
from ..runtime_logging import logging_enabled, log_new_agent
from ..exception_utils import AgentNameConflict, NoEligibleSpeaker, UndefinedNextAgent
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
from ..runtime_logging import log_new_agent, logging_enabled

logger = logging.getLogger(__name__)


class NoEligibleSpeakerException(Exception):
"""Exception raised for early termination of a GroupChat."""

def __init__(self, message="No eligible speakers."):
self.message = message
super().__init__(self.message)


@dataclass
class GroupChat:
"""(In preview) A group chat class that contains the following data fields:
Expand Down Expand Up @@ -76,10 +68,10 @@ def custom_speaker_selection_func(
max_round: Optional[int] = 10
admin_name: Optional[str] = "Admin"
func_call_filter: Optional[bool] = True
speaker_selection_method: Optional[Union[str, Callable]] = "auto"
speaker_selection_method: Union[Literal["auto", "manual", "random", "round_robin"], Callable] = "auto"
allow_repeat_speaker: Optional[Union[bool, List[Agent]]] = None
allowed_or_disallowed_speaker_transitions: Optional[Dict] = None
speaker_transitions_type: Optional[str] = None
speaker_transitions_type: Literal["allowed", "disallowed", None] = None
enable_clear_history: Optional[bool] = False
send_introductions: Optional[bool] = False

Expand Down Expand Up @@ -212,6 +204,10 @@ def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agen
if agents is None:
agents = self.agents

# Ensure the provided list of agents is a subset of self.agents
if not set(agents).issubset(set(self.agents)):
raise UndefinedNextAgent()

# What index is the agent? (-1 if not present)
idx = self.agent_names.index(agent.name) if agent.name in self.agent_names else -1

Expand All @@ -224,6 +220,9 @@ def next_agent(self, agent: Agent, agents: Optional[List[Agent]] = None) -> Agen
if self.agents[(offset + i) % len(self.agents)] in agents:
return self.agents[(offset + i) % len(self.agents)]

# Explicitly handle cases where no valid next agent exists in the provided subset.
raise UndefinedNextAgent()

def select_speaker_msg(self, agents: Optional[List[Agent]] = None) -> str:
"""Return the system message for selecting the next speaker. This is always the *first* message in the context."""
if agents is None:
Expand Down Expand Up @@ -295,9 +294,7 @@ def _prepare_and_select_agents(
if isinstance(self.speaker_selection_method, Callable):
selected_agent = self.speaker_selection_method(last_speaker, self)
if selected_agent is None:
raise NoEligibleSpeakerException(
"Custom speaker selection function returned None. Terminating conversation."
)
raise NoEligibleSpeaker("Custom speaker selection function returned None. Terminating conversation.")
elif isinstance(selected_agent, Agent):
if selected_agent in self.agents:
return selected_agent, self.agents, None
Expand Down Expand Up @@ -378,9 +375,7 @@ def _prepare_and_select_agents(

# this condition means last_speaker is a sink in the graph, then no agents are eligible
if last_speaker not in self.allowed_speaker_transitions_dict and is_last_speaker_in_group:
raise NoEligibleSpeakerException(
f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict."
)
raise NoEligibleSpeaker(f"Last speaker {last_speaker.name} is not in the allowed_speaker_transitions_dict.")
# last_speaker is not in the group, so all agents are eligible
elif last_speaker not in self.allowed_speaker_transitions_dict and not is_last_speaker_in_group:
graph_eligible_agents = []
Expand Down Expand Up @@ -618,7 +613,7 @@ def run_chat(
else:
# admin agent is not found in the participants
raise
except NoEligibleSpeakerException:
except NoEligibleSpeaker:
# No eligible speaker, terminate the conversation
break

Expand Down
2 changes: 1 addition & 1 deletion autogen/agentchat/user_proxy_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
name: str,
is_termination_msg: Optional[Callable[[Dict], bool]] = None,
max_consecutive_auto_reply: Optional[int] = None,
human_input_mode: Optional[str] = "ALWAYS",
human_input_mode: Literal["ALWAYS", "TERMINATE", "NEVER"] = "ALWAYS",
function_map: Optional[Dict[str, Callable]] = None,
code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
default_auto_reply: Optional[Union[str, Dict, None]] = "",
Expand Down
4 changes: 2 additions & 2 deletions autogen/agentchat/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Dict, Tuple, Callable
from typing import Any, List, Dict, Tuple, Callable
from .agent import Agent


Expand Down Expand Up @@ -53,7 +53,7 @@ def gather_usage_summary(agents: List[Agent]) -> Tuple[Dict[str, any], Dict[str,
If none of the agents incurred any cost (not having a client), then the total_usage_summary and actual_usage_summary will be `{'total_cost': 0}`.
"""

def aggregate_summary(usage_summary: Dict[str, any], agent_summary: Dict[str, any]) -> None:
def aggregate_summary(usage_summary: Dict[str, Any], agent_summary: Dict[str, Any]) -> None:
if agent_summary is None:
return
usage_summary["total_cost"] += agent_summary.get("total_cost", 0)
Expand Down
34 changes: 34 additions & 0 deletions autogen/exception_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,37 @@
class AgentNameConflict(Exception):
def __init__(self, msg="Found multiple agents with the same name.", *args, **kwargs):
super().__init__(msg, *args, **kwargs)


class NoEligibleSpeaker(Exception):
"""Exception raised for early termination of a GroupChat."""

def __init__(self, message="No eligible speakers."):
self.message = message
super().__init__(self.message)


class SenderRequired(Exception):
"""Exception raised when the sender is required but not provided."""

def __init__(self, message="Sender is required but not provided."):
self.message = message
super().__init__(self.message)


class InvalidCarryOverType(Exception):
"""Exception raised when the carryover type is invalid."""

def __init__(
self, message="Carryover should be a string or a list of strings. Not adding carryover to the message."
):
self.message = message
super().__init__(self.message)


class UndefinedNextAgent(Exception):
"""Exception raised when the provided next agents list does not overlap with agents in the group."""

def __init__(self, message="The provided agents list does not overlap with agents in the group."):
self.message = message
super().__init__(self.message)
2 changes: 1 addition & 1 deletion autogen/graph_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Union
from typing import Dict, List
import logging

from autogen.agentchat.groupchat import Agent
Expand Down
Loading

0 comments on commit 139618b

Please sign in to comment.