Fix some type annotations and edge cases (microsoft#572)
* Fix some type annotations in agents

This fixes some errors in the type annotations of `ConversableAgent`,
`UserProxyAgent`, `GroupChat`, and `AssistantAgent` by adjusting the
signatures to match the actual implementation (see the sketch after this
list). There should be no change in runtime behavior.

* Fix agent types in `GroupChat`

Some parameters annotated as `Agent` are actually required to be
`ConversableAgent` instances, because they are used as such.

* Convert `str` messages to `dict` before printing them (sketched after the `conversable_agent.py` diff below)

* Revert to `Agent` for `GroupChat`

* Follow-up update to the `GroupChat` revert
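
For reference, here is a minimal standalone sketch of the annotation pattern these changes adopt: `llm_config` and `code_execution_config` are either a configuration dict or the sentinel `False` (never `True`), so `Literal[False]` describes the accepted values more precisely than `bool`. The function name and the sample config values below are made up for illustration.

```python
from typing import Dict, Literal, Optional, Union


def configure_agent(
    # Either a config dict, an explicit False to disable, or None for the default.
    llm_config: Optional[Union[Dict, Literal[False]]] = None,
    code_execution_config: Optional[Union[Dict, Literal[False]]] = False,
) -> None:
    if llm_config is False:
        print("LLM calls disabled")
    elif llm_config is None:
        print("Using the default LLM config")
    else:
        print(f"LLM config: {llm_config}")
    print(f"Code execution config: {code_execution_config}")


configure_agent(llm_config=False)                # OK: explicitly disabled
configure_agent(llm_config={"model": "gpt-4"})   # OK: a config dict
# configure_agent(llm_config=True)               # flagged by a type checker: True is not Literal[False]
```

Under the previous `Union[Dict, bool]` annotation, a type checker would have accepted `True`, which the implementation does not meaningfully handle.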

---------

Co-authored-by: Beibin Li <[email protected]>
Co-authored-by: Beibin Li <[email protected]>
3 people authored on Nov 16, 2023
1 parent 39fb143 · commit 37064b1
Showing 4 changed files with 26 additions and 19 deletions.
6 changes: 3 additions & 3 deletions autogen/agentchat/assistant_agent.py
@@ -1,5 +1,5 @@
 from .conversable_agent import ConversableAgent
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, Literal, Optional, Union
 
 
 class AssistantAgent(ConversableAgent):
@@ -30,11 +30,11 @@ def __init__(
         self,
         name: str,
         system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
-        llm_config: Optional[Union[Dict, bool]] = None,
+        llm_config: Optional[Union[Dict, Literal[False]]] = None,
         is_termination_msg: Optional[Callable[[Dict], bool]] = None,
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Optional[str] = "NEVER",
-        code_execution_config: Optional[Union[Dict, bool]] = False,
+        code_execution_config: Optional[Union[Dict, Literal[False]]] = False,
         **kwargs,
     ):
         """
28 changes: 17 additions & 11 deletions autogen/agentchat/conversable_agent.py
@@ -3,7 +3,7 @@
 import copy
 import json
 import logging
-from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Type, Union
 from autogen import OpenAIWrapper
 from .agent import Agent
 from autogen.code_utils import (
@@ -45,6 +45,8 @@ class ConversableAgent(Agent):
     }
     MAX_CONSECUTIVE_AUTO_REPLY = 100  # maximum number of consecutive auto replies (subject to future change)
 
+    llm_config: Union[Dict, Literal[False]]
+
     def __init__(
         self,
         name: str,
@@ -53,8 +55,8 @@ def __init__(
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Optional[str] = "TERMINATE",
         function_map: Optional[Dict[str, Callable]] = None,
-        code_execution_config: Optional[Union[Dict, bool]] = None,
-        llm_config: Optional[Union[Dict, bool]] = None,
+        code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
+        llm_config: Optional[Union[Dict, Literal[False]]] = None,
         default_auto_reply: Optional[Union[str, Dict, None]] = "",
     ):
         """
@@ -114,7 +116,9 @@ def __init__(
             self.llm_config.update(llm_config)
         self.client = OpenAIWrapper(**self.llm_config)
 
-        self._code_execution_config = {} if code_execution_config is None else code_execution_config
+        self._code_execution_config: Union[Dict, Literal[False]] = (
+            {} if code_execution_config is None else code_execution_config
+        )
         self.human_input_mode = human_input_mode
         self._max_consecutive_auto_reply = (
             max_consecutive_auto_reply if max_consecutive_auto_reply is not None else self.MAX_CONSECUTIVE_AUTO_REPLY
@@ -135,7 +139,7 @@ def register_reply(
         self,
         trigger: Union[Type[Agent], str, Agent, Callable[[Agent], bool], List],
         reply_func: Callable,
-        position: Optional[int] = 0,
+        position: int = 0,
         config: Optional[Any] = None,
         reset_config: Optional[Callable] = None,
     ):
@@ -162,7 +166,7 @@ def reply_func(
             messages: Optional[List[Dict]] = None,
             sender: Optional[Agent] = None,
             config: Optional[Any] = None,
-        ) -> Union[str, Dict, None]:
+        ) -> Tuple[bool, Union[str, Dict, None]]:
         ```
             position (int): the position of the reply function in the reply function list.
                 The function registered later will be checked earlier by default.
@@ -221,7 +225,7 @@ def chat_messages(self) -> Dict[Agent, List[Dict]]:
         """A dictionary of conversations from agent to list of messages."""
         return self._oai_messages
 
-    def last_message(self, agent: Optional[Agent] = None) -> Dict:
+    def last_message(self, agent: Optional[Agent] = None) -> Optional[Dict]:
         """The last message exchanged with the agent.
 
         Args:
@@ -304,7 +308,7 @@ def send(
         recipient: Agent,
         request_reply: Optional[bool] = None,
         silent: Optional[bool] = False,
-    ) -> bool:
+    ):
         """Send a message to another agent.
 
         Args:
@@ -353,7 +357,7 @@ async def a_send(
         recipient: Agent,
         request_reply: Optional[bool] = None,
         silent: Optional[bool] = False,
-    ) -> bool:
+    ):
         """(async) Send a message to another agent.
 
         Args:
@@ -399,6 +403,8 @@ async def a_send(
     def _print_received_message(self, message: Union[Dict, str], sender: Agent):
         # print the message received
         print(colored(sender.name, "yellow"), "(to", f"{self.name}):\n", flush=True)
+        message = self._message_to_dict(message)
+
         if message.get("role") == "function":
             func_print = f"***** Response from calling function \"{message['name']}\" *****"
             print(colored(func_print, "green"), flush=True)
@@ -606,7 +612,7 @@ def generate_oai_reply(
         self,
         messages: Optional[List[Dict]] = None,
         sender: Optional[Agent] = None,
-        config: Optional[Any] = None,
+        config: Optional[OpenAIWrapper] = None,
     ) -> Tuple[bool, Union[str, Dict, None]]:
         """Generate a reply using autogen.oai."""
         client = self.client if config is None else config
@@ -625,7 +631,7 @@ def generate_code_execution_reply(
         self,
         messages: Optional[List[Dict]] = None,
         sender: Optional[Agent] = None,
-        config: Optional[Any] = None,
+        config: Optional[Union[Dict, Literal[False]]] = None,
     ):
         """Generate a reply using code execution."""
         code_execution_config = config if config is not None else self._code_execution_config
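
To make the corrected `register_reply` docstring above concrete, here is a hedged usage sketch: the registered reply function returns a `(final, reply)` tuple rather than a bare reply. The name of the first parameter is not visible in the hunk above and is assumed here.

```python
from typing import Any, Dict, List, Optional, Tuple, Union

from autogen.agentchat.agent import Agent
from autogen.agentchat.conversable_agent import ConversableAgent


def echo_reply(
    recipient: ConversableAgent,  # parameter name assumed; not shown in the diff above
    messages: Optional[List[Dict]] = None,
    sender: Optional[Agent] = None,
    config: Optional[Any] = None,
) -> Tuple[bool, Union[str, Dict, None]]:
    # First element: whether this reply is final; second element: the reply itself.
    last = messages[-1].get("content", "") if messages else ""
    return True, f"echo: {last}"


# llm_config=False is exactly the case the Literal[False] annotations describe.
agent = ConversableAgent("echoer", llm_config=False, human_input_mode="NEVER")
agent.register_reply(Agent, echo_reply, position=0)
```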
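
The `_print_received_message` change above converts a plain-string message to a dict before any field access. A minimal sketch of that conversion, using a stand-in helper rather than the real `ConversableAgent._message_to_dict` (whose exact behavior is not shown in this diff):

```python
from typing import Dict, Union


def message_to_dict(message: Union[Dict, str]) -> Dict:
    # Stand-in for ConversableAgent._message_to_dict: wrap a bare string as a
    # content-only message so that .get("role") and message["name"] work uniformly.
    if isinstance(message, str):
        return {"content": message}
    return message


def print_received(message: Union[Dict, str], sender_name: str, recipient_name: str) -> None:
    print(f"{sender_name} (to {recipient_name}):\n", flush=True)
    message = message_to_dict(message)  # convert str to dict before inspecting fields
    if message.get("role") == "function":
        print(f"***** Response from calling function \"{message['name']}\" *****", flush=True)
    else:
        print(message.get("content"), flush=True)


print_received("hello there", "assistant", "user_proxy")
print_received({"role": "function", "name": "add", "content": "3"}, "assistant", "user_proxy")
```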
5 changes: 3 additions & 2 deletions autogen/agentchat/groupchat.py
@@ -1,9 +1,10 @@
-from dataclasses import dataclass
+import logging
 import sys
+from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
+
 from .agent import Agent
 from .conversable_agent import ConversableAgent
-import logging
 
 logger = logging.getLogger(__name__)
 
6 changes: 3 additions & 3 deletions autogen/agentchat/user_proxy_agent.py
@@ -1,5 +1,5 @@
 from .conversable_agent import ConversableAgent
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, Literal, Optional, Union
 
 
 class UserProxyAgent(ConversableAgent):
@@ -22,9 +22,9 @@ def __init__(
         max_consecutive_auto_reply: Optional[int] = None,
         human_input_mode: Optional[str] = "ALWAYS",
         function_map: Optional[Dict[str, Callable]] = None,
-        code_execution_config: Optional[Union[Dict, bool]] = None,
+        code_execution_config: Optional[Union[Dict, Literal[False]]] = None,
         default_auto_reply: Optional[Union[str, Dict, None]] = "",
-        llm_config: Optional[Union[Dict, bool]] = False,
+        llm_config: Optional[Union[Dict, Literal[False]]] = False,
         system_message: Optional[str] = "",
     ):
         """
