Skip to content

Commit

Permalink
Fix #2643 - groupchat model registration (#2696)
Browse files Browse the repository at this point in the history
* remove unused import statement

* fix #2643: register custom model clients within GroupChat

* add docs for fix #2643

* Update website/docs/topics/groupchat/using_custom_models.md

Co-authored-by: Chi Wang <[email protected]>

* Update website/docs/topics/groupchat/using_custom_models.md

Co-authored-by: Chi Wang <[email protected]>

* fix: removed unnecessary llm_config from checking agent

* fix: handle missing config or "config_list" key in config

* fix: code formatting

* Isolate method for internal agents creation

* Add unit test to verify that internal agents' client actually registers ModelClient class

* fix: function arguments formatting

* chore: prepend "select_speaker_auto_" to llm_config and model_client_cls attributes in GroupChat

* feat: use selector's llm_config for speaker selection agent if none is passed to GroupChat

* Update test/agentchat/test_groupchat.py

* Update groupchat.py - moved class parameters around, added to docstring

* Update groupchat.py - added selector to async select speaker functions

* Update test_groupchat.py - Corrected test cases for custom model client class

* Update test_groupchat.py pre-commit tidy

---------

Co-authored-by: Matteo Frattaroli <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
Co-authored-by: Eric Zhu <[email protected]>
Co-authored-by: Mark Sze <[email protected]>
  • Loading branch information
5 people authored Oct 11, 2024
1 parent 32022b2 commit ec4f3c0
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 44 deletions.
123 changes: 82 additions & 41 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from ..formatting_utils import colored
from ..graph_utils import check_graph_validity, invert_disallowed_to_allowed
from ..io.base import IOStream
from ..oai.client import ModelClient
from ..runtime_logging import log_new_agent, logging_enabled
from .agent import Agent
from .chat import ChatResult
from .conversable_agent import ConversableAgent

try:
Expand Down Expand Up @@ -105,6 +105,8 @@ def custom_speaker_selection_func(
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
- send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False)
- select_speaker_auto_model_client_cls: Custom model client class for the internal speaker select agent used during 'auto' speaker selection (optional)
- select_speaker_auto_llm_config: LLM config for the internal speaker select agent used during 'auto' speaker selection (optional)
- role_for_select_speaker_messages: sets the role name for speaker selection when in 'auto' mode, typically 'user' or 'system'. (default: 'system')
"""

Expand Down Expand Up @@ -142,6 +144,8 @@ def custom_speaker_selection_func(
Respond with ONLY the name of the speaker and DO NOT provide a reason."""
select_speaker_transform_messages: Optional[Any] = None
select_speaker_auto_verbose: Optional[bool] = False
select_speaker_auto_model_client_cls: Optional[Union[ModelClient, List[ModelClient]]] = None
select_speaker_auto_llm_config: Optional[Union[Dict, Literal[False]]] = None
role_for_select_speaker_messages: Optional[str] = "system"

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
Expand Down Expand Up @@ -591,6 +595,79 @@ def _finalize_speaker(self, last_speaker: Agent, final: bool, name: str, agents:
agent = self.agent_by_name(name)
return agent if agent else self.next_agent(last_speaker, agents)

def _register_client_from_config(self, agent: Agent, config: Dict):
model_client_cls_to_match = config.get("model_client_cls")
if model_client_cls_to_match:
if not self.select_speaker_auto_model_client_cls:
raise ValueError(
"A custom model was detected in the config but no 'model_client_cls' "
"was supplied for registration in GroupChat."
)

if isinstance(self.select_speaker_auto_model_client_cls, list):
# Register the first custom model client class matching the name specified in the config
matching_model_cls = [
client_cls
for client_cls in self.select_speaker_auto_model_client_cls
if client_cls.__name__ == model_client_cls_to_match
]
if len(set(matching_model_cls)) > 1:
raise RuntimeError(
f"More than one unique 'model_client_cls' with __name__ '{model_client_cls_to_match}'."
)
if not matching_model_cls:
raise ValueError(
"No model's __name__ matches the model client class "
f"'{model_client_cls_to_match}' specified in select_speaker_auto_llm_config."
)
select_speaker_auto_model_client_cls = matching_model_cls[0]
else:
# Register the only custom model client
select_speaker_auto_model_client_cls = self.select_speaker_auto_model_client_cls

agent.register_model_client(select_speaker_auto_model_client_cls)

def _register_custom_model_clients(self, agent: ConversableAgent):
if not self.select_speaker_auto_llm_config:
return

config_format_is_list = "config_list" in self.select_speaker_auto_llm_config.keys()
if config_format_is_list:
for config in self.select_speaker_auto_llm_config["config_list"]:
self._register_client_from_config(agent, config)
elif not config_format_is_list:
self._register_client_from_config(agent, self.select_speaker_auto_llm_config)

def _create_internal_agents(
self, agents, max_attempts, messages, validate_speaker_name, selector: Optional[ConversableAgent] = None
):
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# Override the selector's config if one was passed as a parameter to this class
speaker_selection_llm_config = self.select_speaker_auto_llm_config or selector.llm_config

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages={checking_agent: messages},
llm_config=speaker_selection_llm_config,
human_input_mode="NEVER",
# Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
)

# Register any custom model passed in select_speaker_auto_llm_config with the speaker_selection_agent
self._register_custom_model_clients(speaker_selection_agent)

return checking_agent, speaker_selection_agent

def _auto_select_speaker(
self,
last_speaker: Agent,
Expand Down Expand Up @@ -644,28 +721,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
# Two-agent chat for speaker selection

# Agent for checking the response from the speaker_select_agent
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages=(
{checking_agent: messages}
if self.select_speaker_prompt_template is not None
else {checking_agent: messages[:-1]}
),
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
checking_agent, speaker_selection_agent = self._create_internal_agents(
agents, max_attempts, messages, validate_speaker_name, selector
)

# Create the starting message
Expand Down Expand Up @@ -747,24 +804,8 @@ def validate_speaker_name(recipient, messages, sender, config) -> Tuple[bool, Un
# Two-agent chat for speaker selection

# Agent for checking the response from the speaker_select_agent
checking_agent = ConversableAgent("checking_agent", default_auto_reply=max_attempts)

# Register the speaker validation function with the checking agent
checking_agent.register_reply(
[ConversableAgent, None],
reply_func=validate_speaker_name, # Validate each response
remove_other_reply_funcs=True,
)

# NOTE: Do we have a speaker prompt (select_speaker_prompt_template is not None)? If we don't, we need to feed in the last message to start the nested chat

# Agent for selecting a single agent name from the response
speaker_selection_agent = ConversableAgent(
"speaker_selection_agent",
system_message=self.select_speaker_msg(agents),
chat_messages={checking_agent: messages},
llm_config=selector.llm_config,
human_input_mode="NEVER", # Suppresses some extra terminal outputs, outputs will be handled by select_speaker_auto_verbose
checking_agent, speaker_selection_agent = self._create_internal_agents(
agents, max_attempts, messages, validate_speaker_name, selector
)

# Create the starting message
Expand Down
61 changes: 58 additions & 3 deletions test/agentchat/test_groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
import io
import json
import logging
from types import SimpleNamespace
from typing import Any, Dict, List, Optional
from unittest import TestCase, mock
from unittest import mock

import pytest
from test_assistant_agent import KEY_LOC, OAI_CONFIG_LIST

import autogen
from autogen import Agent, AssistantAgent, GroupChat, GroupChatManager
Expand Down Expand Up @@ -2062,6 +2062,60 @@ def test_manager_resume_messages():
return_agent, return_message = manager.resume(messages="Let's get this conversation started.")


def test_custom_model_client():
class CustomModelClient:
def __init__(self, config, **kwargs):
print(f"CustomModelClient config: {config}")

def create(self, params):
num_of_responses = params.get("n", 1)

response = SimpleNamespace()
response.choices = []
response.model = "test_model_name"

for _ in range(num_of_responses):
text = "this is a dummy text response"
choice = SimpleNamespace()
choice.message = SimpleNamespace()
choice.message.content = text
choice.message.function_call = None
response.choices.append(choice)
return response

def message_retrieval(self, response):
choices = response.choices
return [choice.message.content for choice in choices]

def cost(self, response) -> float:
response.cost = 0
return 0

@staticmethod
def get_usage(response):
return {}

llm_config = {"config_list": [{"model": "test_model_name", "model_client_cls": "CustomModelClient"}]}

group_chat = autogen.GroupChat(
agents=[],
messages=[],
max_round=3,
select_speaker_auto_llm_config=llm_config,
select_speaker_auto_model_client_cls=CustomModelClient,
)

checking_agent, speaker_selection_agent = group_chat._create_internal_agents(
agents=[], messages=[], max_attempts=3, validate_speaker_name=(True, "test")
)

# Check that the custom model client is assigned to the speaker selection agent
assert isinstance(speaker_selection_agent.client._clients[0], CustomModelClient)

# Check that the LLM Config is assigned
assert speaker_selection_agent.client._config_list == llm_config["config_list"]


def test_select_speaker_transform_messages():
"""Tests adding transform messages to a GroupChat for speaker selection when in 'auto' mode"""

Expand Down Expand Up @@ -2127,8 +2181,9 @@ def test_select_speaker_transform_messages():
# test_select_speaker_auto_messages()
# test_manager_messages_to_string()
# test_manager_messages_from_string()
test_manager_resume_functions()
# test_manager_resume_functions()
# test_manager_resume_returns()
# test_manager_resume_messages()
# test_custom_model_client()
# test_select_speaker_transform_messages()
pass
79 changes: 79 additions & 0 deletions website/docs/topics/groupchat/using_custom_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Using Custom Models

When using `GroupChatManager` we need to pass a `GroupChat` object in the constructor, a dataclass responsible for
gathering agents, preparing messages from prompt templates and selecting speakers
(eventually using `speaker_selection_method` as described [here](customized_speaker_selection)).

To do so GroupChat internally initializes two instances of ConversableAgent.
In order to control the model clients used by the agents instantiated within the GroupChat, which already receives the
`llm_config` passed to GroupChatManager, the optional `model_client_cls` attribute can be set.


## Example
First we need to define an `llm_config` and define some agents that will partake in the group chat:
```python
from autogen import GroupChat, ConversableAgent, GroupChatManager, UserProxyAgent
from somewhere import MyModelClient


# Define the custom model configuration
llm_config = {
"config_list": [
{
"model": "gpt-3.5-turbo",
"model_client_cls": "MyModelClient"
}
]
}

# Initialize the agents with the custom model
agent1 = ConversableAgent(
name="Agent 1",
llm_config=llm_config
)
agent1.register_model_client(model_client_cls=MyModelClient)

agent2 = ConversableAgent(
name="Agent 2",
llm_config=llm_config
)
agent2.register_model_client(model_client_cls=MyModelClient)

agent3 = ConversableAgent(
name="Agent 2",
llm_config=llm_config
)
agent3.register_model_client(model_client_cls=MyModelClient)

user_proxy = UserProxyAgent(name="user", llm_config=llm_config, code_execution_config={"use_docker": False})
user_proxy.register_model_client(MyModelClient)
```

Note that the agents definition illustrated here is minimal and might not suit your needs. The only aim is to show a
basic setup for a group chat scenario.

We then create a `GroupChat` and, if we want the underlying agents used by GroupChat to use our
custom client, we will pass it in the `model_client_cls` attribute.

Finally we create an instance of `GroupChatManager` and pass the config to it. This same config will be forwarded to
the GroupChat, that (if needed) will automatically handle registration of custom models only.

```python
# Create a GroupChat instance and add the agents
group_chat = GroupChat(agents=[agent1, agent2, agent3], messages=[], model_client_cls=MyModelClient)

# Create the GroupChatManager with the GroupChat, UserProxy, and model configuration
chat_manager = GroupChatManager(groupchat=group_chat, llm_config=llm_config)
chat_manager.register_model_client(model_client_cls=MyModelClient)

# Initiate the chat using the UserProxy
user_proxy.initiate_chat(chat_manager, initial_message="Suggest me the most trending papers in microbiology that you think might interest me")

```

This attribute can either be a class or a list of classes which adheres to the `ModelClient` protocol (see
[this link](../non-openai-models/about-using-nonopenai-models) for more info about defining a custom model client
class).

Note that it is not necessary to define a `model_client_cls` when working with Azure OpenAI, OpenAI or other non-custom
models natively supported by the library.

0 comments on commit ec4f3c0

Please sign in to comment.