From 10a6d060b5e9398f10d72c066f032560405ad431 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 07:13:09 -0500 Subject: [PATCH 01/30] add client interface, response protocol, and move code into openai client class --- autogen/oai/client.py | 430 ++++++++++++++++++++++++------------------ 1 file changed, 249 insertions(+), 181 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 65ad14254091..7d5b327d591d 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -8,6 +8,8 @@ from flaml.automl.logger import logger_formatter from pydantic import BaseModel +from abc import ABC, abstractmethod +from typing import Protocol from autogen.oai import completion @@ -49,6 +51,221 @@ logger.addHandler(_ch) +def template_formatter( + template: str | Callable | None, + context: Optional[Dict] = None, + allow_format_str_template: Optional[bool] = False, +): + if not context or template is None: + return template + if isinstance(template, str): + return template.format(**context) if allow_format_str_template else template + return template(context) + + +class Client(ABC): + """ + A client class must implement the following methods: + - create must return a response object that implements the ClientResponseProtocol + - cost + + This class is used to create a client that can be used by OpenAIWrapper. + It mimicks the OpenAI class, but allows for custom clients to be used. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: str | None + function_call: str | None + + choices: List[Choice] + config_id: int + cost: float + pass_filter: bool + model: str + + def update(self, config: Dict): + # update with anything here + pass + + @abstractmethod + def create(self, params) -> ClientResponseProtocol: + pass + + @abstractmethod + def cost(self, response: ClientResponseProtocol) -> float: + pass + + @staticmethod + def get_usage(response: ClientResponseProtocol) -> Dict: + return None + + +class OpenAIClient(Client): + def __init__(self, client): + self.client = client + + def create(self, params: Dict[str, Any]) -> ChatCompletion: + """Create a completion for a given config using openai's client. + + Args: + client: The openai client. + params: The params for the completion. + + Returns: + The completion. + """ + completions: Completions = self.client.chat.completions if "messages" in params else self.client.completions # type: ignore [attr-defined] + # If streaming is enabled and has messages, then iterate over the chunks of the response. + if params.get("stream", False) and "messages" in params: + response_contents = [""] * params.get("n", 1) + finish_reasons = [""] * params.get("n", 1) + completion_tokens = 0 + + # Set the terminal text color to green + print("\033[32m", end="") + + # Prepare for potential function call + full_function_call: Optional[Dict[str, Any]] = None + full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None + + # Send the chat completion request to OpenAI's API and process the response in chunks + for chunk in completions.create(**params): + if chunk.choices: + for choice in chunk.choices: + content = choice.delta.content + tool_calls_chunks = choice.delta.tool_calls + finish_reasons[choice.index] = choice.finish_reason + + # todo: remove this after function calls are removed from the API + # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail + # begin block + function_call_chunk = ( + choice.delta.function_call if hasattr(choice.delta, "function_call") else None + ) + # Handle function call + if function_call_chunk: + # Handle function call + if function_call_chunk: + full_function_call, completion_tokens = self._update_function_call_from_chunk( + function_call_chunk, full_function_call, completion_tokens + ) + if not content: + continue + # end block + + # Handle tool calls + if tool_calls_chunks: + for tool_calls_chunk in tool_calls_chunks: + # the current tool call to be reconstructed + ix = tool_calls_chunk.index + if full_tool_calls is None: + full_tool_calls = [] + if ix >= len(full_tool_calls): + # in case ix is not sequential + full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) + + full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk( + tool_calls_chunk, full_tool_calls[ix], completion_tokens + ) + if not content: + continue + + # End handle tool calls + + # If content is present, print it to the terminal and update response variables + if content is not None: + print(content, end="", flush=True) + response_contents[choice.index] += content + completion_tokens += 1 + else: + # print() + pass + + # Reset the terminal text color + print("\033[0m\n") + + # Prepare the final ChatCompletion object based on the accumulated data + model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API + prompt_tokens = count_token(params["messages"], model) + response = ChatCompletion( + id=chunk.id, + model=chunk.model, + created=chunk.created, + object="chat.completion", + choices=[], + usage=CompletionUsage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ), + ) + for i in range(len(response_contents)): + if OPENAIVERSION >= "1.5": # pragma: no cover + # OpenAI versions 1.5.0 and above + choice = Choice( + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, + ), + logprobs=None, + ) + else: + # OpenAI versions below 1.5.0 + choice = Choice( # type: ignore [call-arg] + index=i, + finish_reason=finish_reasons[i], + message=ChatCompletionMessage( + role="assistant", + content=response_contents[i], + function_call=full_function_call, + tool_calls=full_tool_calls, + ), + ) + + response.choices.append(choice) + else: + # If streaming is not enabled, send a regular chat completion request + params = params.copy() + params["stream"] = False + response = completions.create(**params) + + return response + + def cost(self, response: Union[ChatCompletion, Completion]) -> float: + """Calculate the cost of the response.""" + model = response.model + if model not in OAI_PRICE1K: + # TODO: add logging to warn that the model is not found + logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True) + return 0 + + n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] + n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] + tmp_price1K = OAI_PRICE1K[model] + # First value is input token rate, second value is output token rate + if isinstance(tmp_price1K, tuple): + return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] + return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] + + @staticmethod + def get_usage(response: Union[ChatCompletion, Completion]) -> Dict: + return { + "prompt_tokens": response.usage.prompt_tokens, + "completion_tokens": response.usage.completion_tokens, + "total_tokens": response.usage.total_tokens, + "cost": response.cost, + "model": response.model, + } + + class OpenAIWrapper: """A wrapper class for openai client.""" @@ -144,9 +361,9 @@ def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> Open if openai_config["azure_deployment"] is not None: openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) - client = AzureOpenAI(**openai_config) + client = OpenAIClient(AzureOpenAI(**openai_config)) else: - client = OpenAI(**openai_config) + client = OpenAIClient(OpenAI(**openai_config)) return client @classmethod @@ -239,6 +456,9 @@ def yes_or_no_filter(context, response): filter_func = extra_kwargs.get("filter_func") context = extra_kwargs.get("context") + total_usage = None + actual_usage = None + # Try to load the response from cache if cache_seed is not None: with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: @@ -251,19 +471,20 @@ def yes_or_no_filter(context, response): response.cost # type: ignore [attr-defined] except AttributeError: # update attribute if cost is not calculated - response.cost = self.cost(response) + response.cost = client.cost(response) cache.set(key, response) - self._update_usage_summary(response, use_cache=True) + total_usage = client.get_usage(response) # check the filter pass_filter = filter_func is None or filter_func(context=context, response=response) if pass_filter or i == last: # Return the response if it passes the filter or it is the last client response.config_id = i response.pass_filter = pass_filter + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) return response continue # filter is not passed; try the next config try: - response = self._completions_create(client, params) + response = client.create(params) except APIError as err: error_code = getattr(err, "code", None) if error_code == "content_filter": @@ -274,8 +495,10 @@ def yes_or_no_filter(context, response): raise else: # add cost calculation before caching no matter filter is passed or not - response.cost = self.cost(response) - self._update_usage_summary(response, use_cache=False) + response.cost = client.cost(response) + actual_usage = client.get_usage(response) + total_usage = actual_usage.copy() if actual_usage is not None else total_usage + self._update_usage(actual_usage=actual_usage, total_usage=total_usage) if cache_seed is not None: # Cache the response with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache: @@ -401,170 +624,31 @@ def _update_tool_calls_from_chunk( else: raise RuntimeError("Tool call is not found, this should not happen.") - def _completions_create(self, client: OpenAI, params: Dict[str, Any]) -> ChatCompletion: - """Create a completion for a given config using openai's client. - - Args: - client: The openai client. - params: The params for the completion. - - Returns: - The completion. - """ - completions: Completions = client.chat.completions if "messages" in params else client.completions # type: ignore [attr-defined] - # If streaming is enabled and has messages, then iterate over the chunks of the response. - if params.get("stream", False) and "messages" in params: - response_contents = [""] * params.get("n", 1) - finish_reasons = [""] * params.get("n", 1) - completion_tokens = 0 - - # Set the terminal text color to green - print("\033[32m", end="") - - # Prepare for potential function call - full_function_call: Optional[Dict[str, Any]] = None - full_tool_calls: Optional[List[Optional[Dict[str, Any]]]] = None - - # Send the chat completion request to OpenAI's API and process the response in chunks - for chunk in completions.create(**params): - if chunk.choices: - for choice in chunk.choices: - content = choice.delta.content - tool_calls_chunks = choice.delta.tool_calls - finish_reasons[choice.index] = choice.finish_reason - - # todo: remove this after function calls are removed from the API - # the code should work regardless of whether function calls are removed or not, but test_chat_functions_stream should fail - # begin block - function_call_chunk = ( - choice.delta.function_call if hasattr(choice.delta, "function_call") else None - ) - # Handle function call - if function_call_chunk: - # Handle function call - if function_call_chunk: - full_function_call, completion_tokens = self._update_function_call_from_chunk( - function_call_chunk, full_function_call, completion_tokens - ) - if not content: - continue - # end block - - # Handle tool calls - if tool_calls_chunks: - for tool_calls_chunk in tool_calls_chunks: - # the current tool call to be reconstructed - ix = tool_calls_chunk.index - if full_tool_calls is None: - full_tool_calls = [] - if ix >= len(full_tool_calls): - # in case ix is not sequential - full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) - - full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk( - tool_calls_chunk, full_tool_calls[ix], completion_tokens - ) - if not content: - continue - - # End handle tool calls - - # If content is present, print it to the terminal and update response variables - if content is not None: - print(content, end="", flush=True) - response_contents[choice.index] += content - completion_tokens += 1 - else: - # print() - pass - - # Reset the terminal text color - print("\033[0m\n") - - # Prepare the final ChatCompletion object based on the accumulated data - model = chunk.model.replace("gpt-35", "gpt-3.5") # hack for Azure API - prompt_tokens = count_token(params["messages"], model) - response = ChatCompletion( - id=chunk.id, - model=chunk.model, - created=chunk.created, - object="chat.completion", - choices=[], - usage=CompletionUsage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ), - ) - for i in range(len(response_contents)): - if OPENAIVERSION >= "1.5": # pragma: no cover - # OpenAI versions 1.5.0 and above - choice = Choice( - index=i, - finish_reason=finish_reasons[i], - message=ChatCompletionMessage( - role="assistant", - content=response_contents[i], - function_call=full_function_call, - tool_calls=full_tool_calls, - ), - logprobs=None, - ) - else: - # OpenAI versions below 1.5.0 - choice = Choice( # type: ignore [call-arg] - index=i, - finish_reason=finish_reasons[i], - message=ChatCompletionMessage( - role="assistant", - content=response_contents[i], - function_call=full_function_call, - tool_calls=full_tool_calls, - ), - ) - - response.choices.append(choice) - else: - # If streaming is not enabled, send a regular chat completion request - params = params.copy() - params["stream"] = False - response = completions.create(**params) - - return response + def _update_usage(self, actual_usage, total_usage): + def update_usage(usage_summary, response_usage): + model = response_usage["model"] + cost = response_usage["cost"] + prompt_tokens = response_usage["prompt_tokens"] + completion_tokens = response_usage["completion_tokens"] + total_tokens = response_usage["total_tokens"] - def _update_usage_summary(self, response: Union[ChatCompletion, Completion], use_cache: bool) -> None: - """Update the usage summary. - - Usage is calculated no matter filter is passed or not. - """ - try: - usage = response.usage - assert usage is not None - usage.prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens - usage.completion_tokens = 0 if usage.completion_tokens is None else usage.completion_tokens - usage.total_tokens = 0 if usage.total_tokens is None else usage.total_tokens - except (AttributeError, AssertionError): - logger.debug("Usage attribute is not found in the response.", exc_info=True) - return - - def update_usage(usage_summary: Optional[Dict[str, Any]]) -> Dict[str, Any]: if usage_summary is None: - usage_summary = {"total_cost": response.cost} # type: ignore [union-attr] + usage_summary = {"total_cost": cost} else: - usage_summary["total_cost"] += response.cost # type: ignore [union-attr] - - usage_summary[response.model] = { - "cost": usage_summary.get(response.model, {}).get("cost", 0) + response.cost, # type: ignore [union-attr] - "prompt_tokens": usage_summary.get(response.model, {}).get("prompt_tokens", 0) + usage.prompt_tokens, - "completion_tokens": usage_summary.get(response.model, {}).get("completion_tokens", 0) - + usage.completion_tokens, - "total_tokens": usage_summary.get(response.model, {}).get("total_tokens", 0) + usage.total_tokens, + usage_summary["total_cost"] += cost + + usage_summary[model] = { + "cost": usage_summary.get(model, {}).get("cost", 0) + cost, + "prompt_tokens": usage_summary.get(model, {}).get("prompt_tokens", 0) + prompt_tokens, + "completion_tokens": usage_summary.get(model, {}).get("completion_tokens", 0) + completion_tokens, + "total_tokens": usage_summary.get(model, {}).get("total_tokens", 0) + total_tokens, } return usage_summary - self.total_usage_summary = update_usage(self.total_usage_summary) - if not use_cache: - self.actual_usage_summary = update_usage(self.actual_usage_summary) + if total_usage is not None: + self.total_usage_summary = update_usage(self.total_usage_summary, total_usage) + if actual_usage is not None: + self.actual_usage_summary = update_usage(self.actual_usage_summary, actual_usage) def print_usage_summary(self, mode: Union[str, List[str]] = ["actual", "total"]) -> None: """Print the usage summary.""" @@ -623,22 +707,6 @@ def clear_usage_summary(self) -> None: self.total_usage_summary = None self.actual_usage_summary = None - def cost(self, response: Union[ChatCompletion, Completion]) -> float: - """Calculate the cost of the response.""" - model = response.model - if model not in OAI_PRICE1K: - # TODO: add logging to warn that the model is not found - logger.debug(f"Model {model} is not found. The cost will be 0.", exc_info=True) - return 0 - - n_input_tokens = response.usage.prompt_tokens # type: ignore [union-attr] - n_output_tokens = response.usage.completion_tokens # type: ignore [union-attr] - tmp_price1K = OAI_PRICE1K[model] - # First value is input token rate, second value is output token rate - if isinstance(tmp_price1K, tuple): - return (tmp_price1K[0] * n_input_tokens + tmp_price1K[1] * n_output_tokens) / 1000 # type: ignore [no-any-return] - return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000 # type: ignore [operator] - @classmethod def extract_text_or_completion_object( cls, response: Union[ChatCompletion, Completion] From d313fa308243ed0cdff1b640a644771fcb0d04f3 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 09:32:49 -0500 Subject: [PATCH 02/30] add ability to register custom client --- autogen/agentchat/conversable_agent.py | 5 +- autogen/oai/__init__.py | 3 +- autogen/oai/client.py | 43 ++++++++++--- test/oai/test_custom_client.py | 85 ++++++++++++++++++++++++++ 4 files changed, 125 insertions(+), 11 deletions(-) create mode 100644 test/oai/test_custom_client.py diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index e96cf36953e6..6d5c78d122fe 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -8,7 +8,7 @@ from collections import defaultdict from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union -from .. import OpenAIWrapper +from .. import OpenAIWrapper, Client from ..code_utils import ( DEFAULT_MODEL, UNKNOWN, @@ -1876,6 +1876,9 @@ def _decorator(func: F) -> F: return _decorator + def register_custom_client(self, ClientClass: Client, **kwargs): + self.client.register_custom_client(ClientClass, **kwargs) + def register_hook(self, hookable_method: Callable, hook: Callable): """ Registers a hook to be called by a hookable method, in order to add a capability to the agent. diff --git a/autogen/oai/__init__.py b/autogen/oai/__init__.py index dbcd2f796074..c5f5bb231a8b 100644 --- a/autogen/oai/__init__.py +++ b/autogen/oai/__init__.py @@ -1,4 +1,4 @@ -from autogen.oai.client import OpenAIWrapper +from autogen.oai.client import OpenAIWrapper, Client from autogen.oai.completion import Completion, ChatCompletion from autogen.oai.openai_utils import ( get_config_list, @@ -11,6 +11,7 @@ __all__ = [ "OpenAIWrapper", + "Client", "Completion", "ChatCompletion", "get_config_list", diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 7d5b327d591d..1568f1a3fa0b 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -322,15 +322,21 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base logger.warning("openai client was provided with an empty config_list, which may not be intended.") if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying - self._clients: List[OpenAI] = [ - self._client(config, openai_config) for config in config_list - ] # could modify the config + self._clients: List[OpenAI] = [] # could modify the config + for config in config_list: + c = self._client(config, openai_config) + if c is not None: + self._clients.append(c) + self._config_list = [ {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} for config in config_list ] else: - self._clients = [self._client(extra_kwargs, openai_config)] + self._clients = [] + c = self._client(extra_kwargs, openai_config) + if c is not None: + self._clients.append(c) self._config_list = [extra_kwargs] def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -345,7 +351,7 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAI: + def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAIClient: """Create a client with the given config to override openai_config, after removing extra kwargs. @@ -361,10 +367,23 @@ def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> Open if openai_config["azure_deployment"] is not None: openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) - client = OpenAIClient(AzureOpenAI(**openai_config)) - else: - client = OpenAIClient(OpenAI(**openai_config)) - return client + return OpenAIClient(AzureOpenAI(**openai_config)) + elif api_type is None: + return OpenAIClient(OpenAI(**openai_config)) + # config for a custom client is set + # skipping until the register_custom_client is called with the appropriate class + return None + + def register_custom_client(self, ClientClass: Client, **kwargs): + """Register a custom client. + + Args: + client: A custom client that follows the Client interface + """ + for config in self._config_list: + if config["api_type"] is not None and config["api_type"] == ClientClass.__name__: + client = ClientClass(config, **kwargs) + self._clients.append(client) @classmethod def instantiate( @@ -723,6 +742,12 @@ def extract_text_or_completion_object( if isinstance(response, Completion): return [choice.text for choice in choices] # type: ignore [union-attr] + if not isinstance(response, ChatCompletion) and not isinstance(response, Completion): + return [ + choice.message if choice.message.function_call is not None else choice.message.content + for choice in choices + ] + if TOOL_ENABLED: return [ # type: ignore [return-value] choice.message # type: ignore [union-attr] diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py new file mode 100644 index 000000000000..7760aff9d40c --- /dev/null +++ b/test/oai/test_custom_client.py @@ -0,0 +1,85 @@ +import pytest +from autogen import OpenAIWrapper +from autogen.oai import Client +from typing import Dict + +try: + from openai import OpenAI +except ImportError: + skip = True +else: + skip = False + + +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_custom_client(): + TEST_COST = 20000000 + TEST_CUSTOM_RESPONSE = "This is a custom response." + TEST_DEVICE = "cpu" + TEST_LOCAL_MODEL_NAME = "local_model_name" + TEST_OTHER_PARAMS_VAL = "other_params" + TEST_MAX_LENGTH = 1000 + + class CustomClient(Client): + def __init__(self, config: Dict, test_hook): + self.test_hook = test_hook + self.device = config["device"] + self.model = config["model"] + self.other_params = config["params"]["other_params"] + self.max_length = config["params"]["max_length"] + self.test_hook["called"] = True + # set all params to test hook + self.test_hook["device"] = self.device + self.test_hook["model"] = self.model + self.test_hook["other_params"] = self.other_params + self.test_hook["max_length"] = self.max_length + + def create(self, params): + if params.get("stream", False) and "messages" in params and "functions" not in params: + raise NotImplementedError("Custom Client does not support streaming or functions") + else: + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + choice.message.function_call = None + response.choices.append(choice) + response.model = self.model + return response + + def cost(self, response) -> float: + """Calculate the cost of the response.""" + response.cost = TEST_COST + return TEST_COST + + config_list = [ + { + "model": TEST_LOCAL_MODEL_NAME, + "device": TEST_DEVICE, + "api_type": "CustomClient", + "params": { + "max_length": TEST_MAX_LENGTH, + "other_params": TEST_OTHER_PARAMS_VAL, + }, + }, + ] + + test_hook = {"called": False} + + client = OpenAIWrapper(config_list=config_list) + client.register_custom_client(CustomClient, test_hook=test_hook) + + response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE + assert response.choices[0].message.function_call is None + assert response.cost == TEST_COST + + assert test_hook["called"] + assert test_hook["device"] == TEST_DEVICE + assert test_hook["model"] == TEST_LOCAL_MODEL_NAME + assert test_hook["other_params"] == TEST_OTHER_PARAMS_VAL + assert test_hook["max_length"] == TEST_MAX_LENGTH \ No newline at end of file From 340508ba643190a65f741f7a8529aaa5a5e384b9 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 10:10:27 -0500 Subject: [PATCH 03/30] tidy up code --- autogen/oai/client.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 1568f1a3fa0b..999b9196bcbe 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -320,23 +320,17 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base openai_config, extra_kwargs = self._separate_openai_config(base_config) if type(config_list) is list and len(config_list) == 0: logger.warning("openai client was provided with an empty config_list, which may not be intended.") + self._clients: List[Client] = [] if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying - self._clients: List[OpenAI] = [] # could modify the config for config in config_list: - c = self._client(config, openai_config) - if c is not None: - self._clients.append(c) - + self._register_client(config, openai_config) # could modify the config self._config_list = [ {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} for config in config_list ] else: - self._clients = [] - c = self._client(extra_kwargs, openai_config) - if c is not None: - self._clients.append(c) + self._register_client(extra_kwargs, openai_config) self._config_list = [extra_kwargs] def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -351,7 +345,7 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAIClient: + def _register_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAIClient: """Create a client with the given config to override openai_config, after removing extra kwargs. @@ -367,18 +361,17 @@ def _client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> Open if openai_config["azure_deployment"] is not None: openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) - return OpenAIClient(AzureOpenAI(**openai_config)) + self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) elif api_type is None: - return OpenAIClient(OpenAI(**openai_config)) - # config for a custom client is set + self._clients.append(OpenAIClient(OpenAI(**openai_config))) + # else a config for a custom client is set # skipping until the register_custom_client is called with the appropriate class - return None def register_custom_client(self, ClientClass: Client, **kwargs): """Register a custom client. Args: - client: A custom client that follows the Client interface + client: A custom client that follows the Client interface """ for config in self._config_list: if config["api_type"] is not None and config["api_type"] == ClientClass.__name__: From 655e972e2f605d3f8a930fc4d4b2a150553fdb34 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 10:32:59 -0500 Subject: [PATCH 04/30] adding checks and errors, and more unit tests --- autogen/oai/client.py | 10 +- notebook/agentchat_custom_model.ipynb | 716 ++++++++++++++++++++++++++ test/oai/test_custom_client.py | 44 +- 3 files changed, 767 insertions(+), 3 deletions(-) create mode 100644 notebook/agentchat_custom_model.ipynb diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 999b9196bcbe..ec234d06275d 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -70,7 +70,7 @@ class Client(ABC): - cost This class is used to create a client that can be used by OpenAIWrapper. - It mimicks the OpenAI class, but allows for custom clients to be used. + It mimics the OpenAI class, but allows for custom clients to be used. """ RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] @@ -324,7 +324,7 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying for config in config_list: - self._register_client(config, openai_config) # could modify the config + self._register_client(config, openai_config) # could modify the config self._config_list = [ {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} for config in config_list @@ -373,10 +373,14 @@ def register_custom_client(self, ClientClass: Client, **kwargs): Args: client: A custom client that follows the Client interface """ + found = False for config in self._config_list: if config["api_type"] is not None and config["api_type"] == ClientClass.__name__: client = ClientClass(config, **kwargs) self._clients.append(client) + found = True + if not found: + raise ValueError(f'Custom client "{ClientClass.__name__}" was not found in the config_list.') @classmethod def instantiate( @@ -453,6 +457,8 @@ def yes_or_no_filter(context, response): if ERROR: raise ERROR last = len(self._clients) - 1 + if len(self._clients) == 0: + raise RuntimeError("No client model is registered. Please register a model client first.") for i, client in enumerate(self._clients): # merge the input config with the i-th config in the config list full_config = {**config, **self._config_list[i]} diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb new file mode 100644 index 000000000000..bf4a3950e712 --- /dev/null +++ b/notebook/agentchat_custom_model.ipynb @@ -0,0 +1,716 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\"Open" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": { + "slideshow": { + "slide_type": "slide" + } + }, + "source": [ + "# Agent Chat with custom model loading\n", + "\n", + "In this notebook, we demonstrate how a custom model can be defined and loaded, and what interface it needs to comply to.\n", + "\n", + "## Requirements\n", + "\n", + "AutoGen requires `Python>=3.8`. To run this notebook example, please install:\n", + "```bash\n", + "pip install pyautogen torch transformers sentencepiece\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "execution": { + "iopub.execute_input": "2023-02-13T23:40:52.317406Z", + "iopub.status.busy": "2023-02-13T23:40:52.316561Z", + "iopub.status.idle": "2023-02-13T23:40:52.321193Z", + "shell.execute_reply": "2023-02-13T23:40:52.320628Z" + } + }, + "outputs": [], + "source": [ + "# %pip install pyautogen~=0.2.0b4 torch git+https://github.com/huggingface/transformers sentencepiece" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import autogen\n", + "from autogen.oai import Client\n", + "from autogen import AssistantAgent, UserProxyAgent\n", + "from transformers import AutoTokenizer, GenerationConfig, AutoModelForCausalLM\n", + "from types import SimpleNamespace" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create and configure the custom client / model loader\n", + "\n", + "A custom client can be created in many ways, but needs to adhere to the `Client` interface and response structure which is defined in client.py\n", + "\n", + "```python\n", + "\n", + "class Client(ABC):\n", + " \"\"\"\n", + " A client class must implement the following methods:\n", + " - create must return a response object that implements the ClientResponseProtocol\n", + " - cost\n", + "\n", + " This class is used to create a client that can be used by OpenAIWrapper.\n", + " It mimicks the OpenAI class, but allows for custom clients to be used.\n", + " \"\"\"\n", + "\n", + " RESPONSE_USAGE_KEYS = [\"prompt_tokens\", \"completion_tokens\", \"total_tokens\", \"cost\", \"model\"]\n", + "\n", + " class ClientResponseProtocol(Protocol):\n", + " class Choice(Protocol):\n", + " class Message(Protocol):\n", + " content: str | None\n", + " function_call: str | None\n", + "\n", + " choices: List[Choice]\n", + " config_id: int\n", + " cost: float\n", + " pass_filter: bool\n", + " model: str\n", + "\n", + " def update(self, config: Dict):\n", + " pass\n", + "\n", + " @abstractmethod\n", + " def create(self, params) -> ClientResponseProtocol:\n", + " pass\n", + "\n", + " @abstractmethod\n", + " def cost(self, response: ClientResponseProtocol) -> float:\n", + " pass\n", + "\n", + " @staticmethod\n", + " def get_usage(response: ClientResponseProtocol) -> Dict:\n", + " return None\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Example of simple custom client\n", + "\n", + "Following the huggingface example for using [Mistral's Open-Orca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# custom client with custom model loader\n", + "\n", + "\n", + "class CustomClient(Client):\n", + " def __init__(self, config, **kwargs):\n", + " print(f\"CustomClient config: {config}\")\n", + " if \"print_message\" in kwargs:\n", + " print(f\"print message from kwargs: {kwargs['print_message']}\")\n", + " self.device = config.get(\"device\", \"cpu\")\n", + " self.model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(self.device)\n", + " self.model_name = config[\"model\"]\n", + " self.tokenizer = AutoTokenizer.from_pretrained(config[\"model\"], use_fast=False)\n", + " self.tokenizer.pad_token_id = self.tokenizer.eos_token_id\n", + "\n", + " # params are set by the user and consumed by the user since they are providing a custom model\n", + " # so anything can be done here\n", + " gen_config_params = config.get(\"params\", {})\n", + " self.max_length = gen_config_params.get(\"max_length\", 256)\n", + "\n", + " print(f\"Loaded model {config['model']} to {self.device}\")\n", + "\n", + " def create(self, params):\n", + " if params.get(\"stream\", False) and \"messages\" in params and \"functions\" not in params:\n", + " raise NotImplementedError(\"Local models do not support streaming or functions\")\n", + " else:\n", + " num_of_responses = params.get(\"n\", 1)\n", + "\n", + " # can create my own data response class\n", + " # here using SimpleNamespace for simplicity\n", + " # as long as it adheres to the ClientResponseProtocol\n", + "\n", + " response = SimpleNamespace()\n", + "\n", + " inputs = self.tokenizer.apply_chat_template(\n", + " params[\"messages\"], return_tensors=\"pt\", add_generation_prompt=True\n", + " ).to(self.device)\n", + " inputs_length = inputs.shape[-1]\n", + "\n", + " # add inputs_length to max_length\n", + " max_length = self.max_length + inputs_length\n", + " generation_config = GenerationConfig(\n", + " max_length=max_length,\n", + " eos_token_id=self.tokenizer.eos_token_id,\n", + " pad_token_id=self.tokenizer.eos_token_id,\n", + " )\n", + "\n", + " response.choices = []\n", + " response.model = self.model_name\n", + "\n", + " for _ in range(num_of_responses):\n", + " outputs = self.model.generate(inputs, generation_config=generation_config)\n", + " # Decode only the newly generated text, excluding the prompt\n", + " text = self.tokenizer.decode(outputs[0, inputs_length:])\n", + " choice = SimpleNamespace()\n", + " choice.message = SimpleNamespace()\n", + " choice.message.content = text\n", + " choice.message.function_call = None\n", + " response.choices.append(choice)\n", + "\n", + " return response\n", + "\n", + " def cost(self, response) -> float:\n", + " \"\"\"Calculate the cost of the response.\"\"\"\n", + " response.cost = 0\n", + " return 0\n", + "\n", + " def get_usage(self, response):\n", + " # returns a dict of prompt_tokens, completion_tokens, total_tokens, cost, model\n", + " # if usage needs to be tracked, else None\n", + " return None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set your API Endpoint\n", + "\n", + "The [`config_list_from_json`](https://microsoft.github.io/autogen/docs/reference/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file.\n", + "\n", + "It first looks for an environment variable of a specified name (\"OAI_CONFIG_LIST\" in this example), which needs to be a valid json string. If that variable is not found, it looks for a json file with the same name. It filters the configs by models (you can filter by other keys as well).\n", + "\n", + "The json looks like the following:\n", + "```json\n", + "[\n", + " {\n", + " \"model\": \"gpt-4\",\n", + " \"api_key\": \"\"\n", + " },\n", + " {\n", + " \"model\": \"gpt-4\",\n", + " \"api_key\": \"\",\n", + " \"base_url\": \"\",\n", + " \"api_type\": \"azure\",\n", + " \"api_version\": \"2023-08-01-preview\"\n", + " },\n", + " {\n", + " \"model\": \"gpt-4-32k\",\n", + " \"api_key\": \"\",\n", + " \"base_url\": \"\",\n", + " \"api_type\": \"azure\",\n", + " \"api_version\": \"2023-08-01-preview\"\n", + " }\n", + "]\n", + "```\n", + "\n", + "You can set the value of config_list in any way you prefer. Please refer to this [notebook](https://github.com/microsoft/autogen/blob/main/notebook/oai_openai_utils.ipynb) for full code examples of the different methods." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Set the config for the custom model\n", + "\n", + "You can add any paramteres that are needed for the custom model loading in the same configuration list, as long as `model` is specified\n", + "\n", + "It is important to add the `api_type` with the name of the new client class\n", + "\n", + "```json\n", + "{\n", + " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", + " \"api_type\": \"CustomClient\",\n", + " \"device\": \"cuda\",\n", + " \"n\": 1,\n", + " \"params\": {\n", + " \"max_length\": 1000,\n", + " }\n", + "},\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_list_custom = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " filter_dict={\"model\": [\"Open-Orca/Mistral-7B-OpenOrca\"]},\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Construct Agents\n", + "\n", + "Consturct a simple conversation between a User proxy and an Assistent agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n", + "user_proxy = UserProxyAgent(\"user_proxy\", code_execution_config={\"work_dir\": \"coding\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register the custom client class to the assistant agent" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant.register_custom_client(\n", + " CustomClient,\n", + " print_message=\"print this message on creation to demonstrate passing other args to CustomClient constructor\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user_proxy.initiate_chat(assistant, message=\"Plot a chart of NVDA and TESLA stock price change YTD.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.5" + }, + "vscode": { + "interpreter": { + "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1" + } + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": { + "2d910cfd2d2a4fc49fc30fbbdc5576a7": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "454146d0f7224f038689031002906e6f": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HBoxModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_e4ae2b6f5a974fd4bafb6abb9d12ff26", + "IPY_MODEL_577e1e3cc4db4942b0883577b3b52755", + "IPY_MODEL_b40bdfb1ac1d4cffb7cefcb870c64d45" + ], + "layout": "IPY_MODEL_dc83c7bff2f241309537a8119dfc7555", + "tabbable": null, + "tooltip": null + } + }, + "577e1e3cc4db4942b0883577b3b52755": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "FloatProgressModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_2d910cfd2d2a4fc49fc30fbbdc5576a7", + "max": 1, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_74a6ba0c3cbc4051be0a83e152fe1e62", + "tabbable": null, + "tooltip": null, + "value": 1 + } + }, + "6086462a12d54bafa59d3c4566f06cb2": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "74a6ba0c3cbc4051be0a83e152fe1e62": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "ProgressStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "7d3f3d9e15894d05a4d188ff4f466554": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "b40bdfb1ac1d4cffb7cefcb870c64d45": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_f1355871cc6f4dd4b50d9df5af20e5c8", + "placeholder": "​", + "style": "IPY_MODEL_ca245376fd9f4354af6b2befe4af4466", + "tabbable": null, + "tooltip": null, + "value": " 1/1 [00:00<00:00, 44.69it/s]" + } + }, + "ca245376fd9f4354af6b2befe4af4466": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLStyleModel", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "StyleView", + "background": null, + "description_width": "", + "font_size": null, + "text_color": null + } + }, + "dc83c7bff2f241309537a8119dfc7555": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e4ae2b6f5a974fd4bafb6abb9d12ff26": { + "model_module": "@jupyter-widgets/controls", + "model_module_version": "2.0.0", + "model_name": "HTMLModel", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "2.0.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "2.0.0", + "_view_name": "HTMLView", + "description": "", + "description_allow_html": false, + "layout": "IPY_MODEL_6086462a12d54bafa59d3c4566f06cb2", + "placeholder": "​", + "style": "IPY_MODEL_7d3f3d9e15894d05a4d188ff4f466554", + "tabbable": null, + "tooltip": null, + "value": "100%" + } + }, + "f1355871cc6f4dd4b50d9df5af20e5c8": { + "model_module": "@jupyter-widgets/base", + "model_module_version": "2.0.0", + "model_name": "LayoutModel", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "2.0.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "2.0.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border_bottom": null, + "border_left": null, + "border_right": null, + "border_top": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + } + }, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 7760aff9d40c..7f4c60012a47 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -82,4 +82,46 @@ def cost(self, response) -> float: assert test_hook["device"] == TEST_DEVICE assert test_hook["model"] == TEST_LOCAL_MODEL_NAME assert test_hook["other_params"] == TEST_OTHER_PARAMS_VAL - assert test_hook["max_length"] == TEST_MAX_LENGTH \ No newline at end of file + assert test_hook["max_length"] == TEST_MAX_LENGTH + + +def test_registering_with_wrong_name_missing_raises_error(): + class CustomClient(Client): + def __init__(self, config: Dict): + pass + + def create(self, params): + return None + + def cost(self, response) -> float: + return 0 + + config_list = [ + { + "model": "local_model_name", + "api_type": "CustomClientButWrongName", + }, + ] + client = OpenAIWrapper(config_list=config_list) + + with pytest.raises(ValueError): + client.register_custom_client(CustomClient) + + +def test_custom_client_not_registered_raises_error(): + config_list = [ + { + "model": "local_model_name", + "device": "cpu", + "api_type": "CustomClient", + "params": { + "max_length": 1000, + "other_params": "other_params", + }, + }, + ] + + client = OpenAIWrapper(config_list=config_list) + + with pytest.raises(RuntimeError): + client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) From 32a9ce3d8b4a567f0cfdbc319a936a190b0105c7 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 10:40:15 -0500 Subject: [PATCH 05/30] remove code --- autogen/oai/client.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index ec234d06275d..498d5e707a76 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -51,18 +51,6 @@ logger.addHandler(_ch) -def template_formatter( - template: str | Callable | None, - context: Optional[Dict] = None, - allow_format_str_template: Optional[bool] = False, -): - if not context or template is None: - return template - if isinstance(template, str): - return template.format(**context) if allow_format_str_template else template - return template(context) - - class Client(ABC): """ A client class must implement the following methods: From 29ca1797ccb2b37f80d1e542d50699fcf0c1c2d0 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 10:41:35 -0500 Subject: [PATCH 06/30] fix error msg --- autogen/oai/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 498d5e707a76..b9b84ff48ec4 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -446,7 +446,7 @@ def yes_or_no_filter(context, response): raise ERROR last = len(self._clients) - 1 if len(self._clients) == 0: - raise RuntimeError("No client model is registered. Please register a model client first.") + raise RuntimeError("No model client is registered. Please register a model client first.") for i, client in enumerate(self._clients): # merge the input config with the i-th config in the config list full_config = {**config, **self._config_list[i]} From 64863f2e45e617e3cfbca351952b32a582c26f37 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 10:47:14 -0500 Subject: [PATCH 07/30] add use_docer False in notebook --- notebook/agentchat_custom_model.ipynb | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index bf4a3950e712..8f8d61a0a6c3 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -287,7 +287,13 @@ "outputs": [], "source": [ "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n", - "user_proxy = UserProxyAgent(\"user_proxy\", code_execution_config={\"work_dir\": \"coding\"})" + "user_proxy = UserProxyAgent(\n", + " \"user_proxy\",\n", + " code_execution_config={\n", + " \"work_dir\": \"coding\",\n", + " \"use_docker\": False, # Please set use_docker=True if docker is available to run the generated code. Using docker is safer than running the generated code directly.\n", + " },\n", + ")" ] }, { From 7bd1026c7e0c736db311a182f6389fff829649a2 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 12:30:49 -0500 Subject: [PATCH 08/30] better error message --- autogen/oai/client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index c85997f9e5fe..bdaa53cff1c6 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -372,7 +372,9 @@ def register_custom_client(self, ClientClass: Client, **kwargs): self._clients.append(client) found = True if not found: - raise ValueError(f'Custom client "{ClientClass.__name__}" was not found in the config_list.') + raise ValueError( + f'Custom client "{ClientClass.__name__}" was not found in the config_list. Please make sure to include an entry in the config_list with api_type = "{ClientClass.__name__}"' + ) @classmethod def instantiate( From 8e0e276854ad34cf6baacad814dde5d1a8a91ad7 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 18 Jan 2024 13:10:17 -0500 Subject: [PATCH 09/30] add another example to custom model notebook --- notebook/agentchat_custom_model.ipynb | 113 ++++++++++++++++++++++++-- 1 file changed, 107 insertions(+), 6 deletions(-) diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 8f8d61a0a6c3..b9fbf377eddc 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -131,8 +131,6 @@ "class CustomClient(Client):\n", " def __init__(self, config, **kwargs):\n", " print(f\"CustomClient config: {config}\")\n", - " if \"print_message\" in kwargs:\n", - " print(f\"print message from kwargs: {kwargs['print_message']}\")\n", " self.device = config.get(\"device\", \"cpu\")\n", " self.model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(self.device)\n", " self.model_name = config[\"model\"]\n", @@ -266,7 +264,7 @@ "source": [ "config_list_custom = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", - " filter_dict={\"model\": [\"Open-Orca/Mistral-7B-OpenOrca\"]},\n", + " filter_dict={\"api_type\": [\"CustomClient\"]},\n", ")" ] }, @@ -309,12 +307,115 @@ "metadata": {}, "outputs": [], "source": [ - "assistant.register_custom_client(\n", - " CustomClient,\n", - " print_message=\"print this message on creation to demonstrate passing other args to CustomClient constructor\",\n", + "assistant.register_custom_client(CustomClient)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user_proxy.initiate_chat(assistant, message=\"Plot a chart of NVDA and TESLA stock price change YTD.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Register a custom client class with a pre-loaded model\n", + "\n", + "If you want to have more control over when the model gets loaded, you can load the model yourself and pass it as an argument to the CustomClient during registration" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# custom client with custom model loader\n", + "\n", + "\n", + "class CustomClientWithArguments(CustomClient):\n", + " def __init__(self, config, model, tokenizer, **kwargs):\n", + " print(f\"CustomClientWithArguments config: {config}\")\n", + "\n", + " self.model_name = config[\"model\"]\n", + " self.model = model\n", + " self.tokenizer = tokenizer\n", + "\n", + " self.device = config.get(\"device\", \"cpu\")\n", + "\n", + " gen_config_params = config.get(\"params\", {})\n", + " self.max_length = gen_config_params.get(\"max_length\", 256)\n", + " print(f\"Loaded model {config['model']} to {self.device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Update the config list to include the CustomClientWithArguments\n", + "\n", + "```json\n", + "{\n", + " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", + " \"api_type\": \"CustomClientWithArguments\",\n", + " \"device\": \"cuda\",\n", + " \"n\": 1,\n", + " \"params\": {\n", + " \"max_length\": 1000,\n", + " }\n", + "},\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_list_custom = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " filter_dict={\"api_type\": [\"CustomClientWithArguments\"]},\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load model here\n", + "\n", + "config = config_list_custom[0]\n", + "device = config.get(\"device\", \"cpu\")\n", + "model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n", + "tokenizer = AutoTokenizer.from_pretrained(config[\"model\"], use_fast=False)\n", + "tokenizer.pad_token_id = tokenizer.eos_token_id" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assistant.register_custom_client(CustomClientWithArguments, model=model, tokenizer=tokenizer)" + ] + }, { "cell_type": "code", "execution_count": null, From 865fa5d6ab937159bf2b01bff34883951bdc703f Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 19 Jan 2024 05:31:13 -0500 Subject: [PATCH 10/30] renames --- autogen/agentchat/conversable_agent.py | 4 ++-- autogen/oai/client.py | 14 +++++++------- notebook/agentchat_custom_model.ipynb | 4 ++-- test/oai/test_custom_client.py | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 315b518fd07d..5a23f0bf7dc5 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1900,8 +1900,8 @@ def _decorator(func: F) -> F: return _decorator - def register_custom_client(self, ClientClass: Client, **kwargs): - self.client.register_custom_client(ClientClass, **kwargs) + def register_client(self, ClientClass: Client, **kwargs): + self.client.register_client(ClientClass, **kwargs) def register_hook(self, hookable_method: Callable, hook: Callable): """ diff --git a/autogen/oai/client.py b/autogen/oai/client.py index bdaa53cff1c6..b5e02bdb66bc 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -97,7 +97,7 @@ def get_usage(response: ClientResponseProtocol) -> Dict: class OpenAIClient(Client): def __init__(self, client): - self.client = client + self._oai_client = client def create(self, params: Dict[str, Any]) -> ChatCompletion: """Create a completion for a given config using openai's client. @@ -109,7 +109,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: Returns: The completion. """ - completions: Completions = self.client.chat.completions if "messages" in params else self.client.completions # type: ignore [attr-defined] + completions: Completions = self._oai_client.chat.completions if "messages" in params else self._oai_client.completions # type: ignore [attr-defined] # If streaming is enabled and has messages, then iterate over the chunks of the response. if params.get("stream", False) and "messages" in params: response_contents = [""] * params.get("n", 1) @@ -316,13 +316,13 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying for config in config_list: - self._register_client(config, openai_config) # could modify the config + self._register_openai_client(config, openai_config) # could modify the config self._config_list = [ {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} for config in config_list ] else: - self._register_client(extra_kwargs, openai_config) + self._register_openai_client(extra_kwargs, openai_config) self._config_list = [extra_kwargs] def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: @@ -337,7 +337,7 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _register_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAIClient: + def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAIClient: """Create a client with the given config to override openai_config, after removing extra kwargs. @@ -357,9 +357,9 @@ def _register_client(self, config: Dict[str, Any], openai_config: Dict[str, Any] elif api_type is None: self._clients.append(OpenAIClient(OpenAI(**openai_config))) # else a config for a custom client is set - # skipping until the register_custom_client is called with the appropriate class + # skipping until the register_client is called with the appropriate class - def register_custom_client(self, ClientClass: Client, **kwargs): + def register_client(self, ClientClass: Client, **kwargs): """Register a custom client. Args: diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index b9fbf377eddc..f6c868061e85 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "assistant.register_custom_client(CustomClient)" + "assistant.register_client(CustomClient)" ] }, { @@ -413,7 +413,7 @@ "metadata": {}, "outputs": [], "source": [ - "assistant.register_custom_client(CustomClientWithArguments, model=model, tokenizer=tokenizer)" + "assistant.register_client(CustomClientWithArguments, model=model, tokenizer=tokenizer)" ] }, { diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 7f4c60012a47..3bd1160feb74 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -71,7 +71,7 @@ def cost(self, response) -> float: test_hook = {"called": False} client = OpenAIWrapper(config_list=config_list) - client.register_custom_client(CustomClient, test_hook=test_hook) + client.register_client(CustomClient, test_hook=test_hook) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE @@ -105,7 +105,7 @@ def cost(self, response) -> float: client = OpenAIWrapper(config_list=config_list) with pytest.raises(ValueError): - client.register_custom_client(CustomClient) + client.register_client(CustomClient) def test_custom_client_not_registered_raises_error(): From c202b5d30e3a9019c1b0b935df3d31a476648e83 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 19 Jan 2024 06:08:08 -0500 Subject: [PATCH 11/30] rename and have register_client take model name too --- autogen/agentchat/conversable_agent.py | 11 ++++++++-- autogen/oai/client.py | 20 ++++++++++-------- notebook/agentchat_custom_model.ipynb | 13 +++++++----- test/oai/test_custom_client.py | 29 +++++++++++++++++++++++--- 4 files changed, 54 insertions(+), 19 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 5a23f0bf7dc5..19296bcf65ad 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1900,8 +1900,15 @@ def _decorator(func: F) -> F: return _decorator - def register_client(self, ClientClass: Client, **kwargs): - self.client.register_client(ClientClass, **kwargs) + def register_client(self, model: str, client_cls: Client, **kwargs): + """Register a custom client. + + Args: + model: The model name, as specified in the config list + client_cls: A custom client class that follows the Client interface + **kwargs: The kwargs for the custom client class to be initialized with + """ + self.client.register_client(model, client_cls, **kwargs) def register_hook(self, hookable_method: Callable, hook: Callable): """ diff --git a/autogen/oai/client.py b/autogen/oai/client.py index b5e02bdb66bc..dee089fed3af 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -78,10 +78,6 @@ class Message(Protocol): pass_filter: bool model: str - def update(self, config: Dict): - # update with anything here - pass - @abstractmethod def create(self, params) -> ClientResponseProtocol: pass @@ -359,21 +355,27 @@ def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[st # else a config for a custom client is set # skipping until the register_client is called with the appropriate class - def register_client(self, ClientClass: Client, **kwargs): + def register_client(self, model: str, client_cls: Client, **kwargs): """Register a custom client. Args: - client: A custom client that follows the Client interface + model: The model name, as specified in the config list + client_cls: A custom client class that follows the Client interface + **kwargs: The kwargs for the custom client class to be initialized with """ found = False for config in self._config_list: - if config["api_type"] is not None and config["api_type"] == ClientClass.__name__: - client = ClientClass(config, **kwargs) + if ( + config["api_type"] is not None + and config["api_type"] == client_cls.__name__ + and config["model"] == model + ): + client = client_cls(config, **kwargs) self._clients.append(client) found = True if not found: raise ValueError( - f'Custom client "{ClientClass.__name__}" was not found in the config_list. Please make sure to include an entry in the config_list with api_type = "{ClientClass.__name__}"' + f'Pair of model "{model}" and api_type "{client_cls.__name__}" was not found in the config_list. Please make sure to include an entry in the config_list with "model": "{model}" and "api_type": "{client_cls.__name__}"' ) @classmethod diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index f6c868061e85..2501f36bdd1a 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -307,7 +307,8 @@ "metadata": {}, "outputs": [], "source": [ - "assistant.register_client(CustomClient)" + "config = config_list_custom[0]\n", + "assistant.register_client(model=config[\"model\"], client_cls=CustomClient)" ] }, { @@ -338,11 +339,11 @@ "\n", "\n", "class CustomClientWithArguments(CustomClient):\n", - " def __init__(self, config, model, tokenizer, **kwargs):\n", + " def __init__(self, config, loaded_model, tokenizer, **kwargs):\n", " print(f\"CustomClientWithArguments config: {config}\")\n", "\n", " self.model_name = config[\"model\"]\n", - " self.model = model\n", + " self.model = loaded_model\n", " self.tokenizer = tokenizer\n", "\n", " self.device = config.get(\"device\", \"cpu\")\n", @@ -393,7 +394,7 @@ "\n", "config = config_list_custom[0]\n", "device = config.get(\"device\", \"cpu\")\n", - "model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n", + "loaded_model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(device)\n", "tokenizer = AutoTokenizer.from_pretrained(config[\"model\"], use_fast=False)\n", "tokenizer.pad_token_id = tokenizer.eos_token_id" ] @@ -413,7 +414,9 @@ "metadata": {}, "outputs": [], "source": [ - "assistant.register_client(CustomClientWithArguments, model=model, tokenizer=tokenizer)" + "assistant.register_client(\n", + " model=config[\"model\"], client_cls=CustomClientWithArguments, loaded_model=loaded_model, tokenizer=tokenizer\n", + ")" ] }, { diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 3bd1160feb74..9507b057bda3 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -71,7 +71,7 @@ def cost(self, response) -> float: test_hook = {"called": False} client = OpenAIWrapper(config_list=config_list) - client.register_client(CustomClient, test_hook=test_hook) + client.register_client(model=TEST_LOCAL_MODEL_NAME, client_cls=CustomClient, test_hook=test_hook) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE @@ -85,7 +85,7 @@ def cost(self, response) -> float: assert test_hook["max_length"] == TEST_MAX_LENGTH -def test_registering_with_wrong_name_missing_raises_error(): +def test_registering_with_wrong_cls_name_raises_error(): class CustomClient(Client): def __init__(self, config: Dict): pass @@ -105,7 +105,30 @@ def cost(self, response) -> float: client = OpenAIWrapper(config_list=config_list) with pytest.raises(ValueError): - client.register_client(CustomClient) + client.register_client(model="local_model_name", client_cls=CustomClient) + + +def test_registering_with_wrong_model_name_raises_error(): + class CustomClient(Client): + def __init__(self, config: Dict): + pass + + def create(self, params): + return None + + def cost(self, response) -> float: + return 0 + + config_list = [ + { + "model": "local_model_name_but_wrong_name", + "api_type": "CustomClient", + }, + ] + client = OpenAIWrapper(config_list=config_list) + + with pytest.raises(ValueError): + client.register_client(model="local_model_name", client_cls=CustomClient) def test_custom_client_not_registered_raises_error(): From 8e8fe930365aa5ac05dcb134d7dd5f7b5068f65a Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 19 Jan 2024 06:32:19 -0500 Subject: [PATCH 12/30] make Client protocol and remove inheritance --- autogen/oai/client.py | 29 +++++++++++++++-------- notebook/agentchat_custom_model.ipynb | 33 +++++++++++++++------------ test/oai/test_custom_client.py | 14 +++++++++--- 3 files changed, 49 insertions(+), 27 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index dee089fed3af..164a76d91a98 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -8,7 +8,6 @@ from flaml.automl.logger import logger_formatter from pydantic import BaseModel -from abc import ABC, abstractmethod from typing import Protocol from autogen.cache.cache import Cache @@ -54,11 +53,17 @@ LEGACY_CACHE_DIR = ".cache" -class Client(ABC): +class Client(Protocol): """ A client class must implement the following methods: - create must return a response object that implements the ClientResponseProtocol - - cost + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model This class is used to create a client that can be used by OpenAIWrapper. It mimics the OpenAI class, but allows for custom clients to be used. @@ -78,20 +83,21 @@ class Message(Protocol): pass_filter: bool model: str - @abstractmethod def create(self, params) -> ClientResponseProtocol: - pass + ... # pragma: no cover - @abstractmethod def cost(self, response: ClientResponseProtocol) -> float: - pass + ... # pragma: no cover @staticmethod def get_usage(response: ClientResponseProtocol) -> Dict: - return None + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... # pragma: no cover -class OpenAIClient(Client): +class OpenAIClient: + """Follows the Client protocol and wraps the OpenAI client.""" + def __init__(self, client): self._oai_client = client @@ -653,6 +659,11 @@ def _update_tool_calls_from_chunk( def _update_usage(self, actual_usage, total_usage): def update_usage(usage_summary, response_usage): + # go through RESPONSE_USAGE_KEYS and check that they are in response_usage and if not just return usage_summary + for key in Client.RESPONSE_USAGE_KEYS: + if key not in response_usage: + return usage_summary + model = response_usage["model"] cost = response_usage["cost"] prompt_tokens = response_usage["prompt_tokens"] diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 2501f36bdd1a..513dbc518d81 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -64,18 +64,24 @@ "source": [ "## Create and configure the custom client / model loader\n", "\n", - "A custom client can be created in many ways, but needs to adhere to the `Client` interface and response structure which is defined in client.py\n", + "A custom client can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in client.py\n", "\n", "```python\n", "\n", - "class Client(ABC):\n", + "class Client(Protocol):\n", " \"\"\"\n", " A client class must implement the following methods:\n", " - create must return a response object that implements the ClientResponseProtocol\n", - " - cost\n", + " - cost must return the cost of the response\n", + " - get_usage must return a dict with the following keys:\n", + " - prompt_tokens\n", + " - completion_tokens\n", + " - total_tokens\n", + " - cost\n", + " - model\n", "\n", " This class is used to create a client that can be used by OpenAIWrapper.\n", - " It mimicks the OpenAI class, but allows for custom clients to be used.\n", + " It mimics the OpenAI class, but allows for custom clients to be used.\n", " \"\"\"\n", "\n", " RESPONSE_USAGE_KEYS = [\"prompt_tokens\", \"completion_tokens\", \"total_tokens\", \"cost\", \"model\"]\n", @@ -92,20 +98,16 @@ " pass_filter: bool\n", " model: str\n", "\n", - " def update(self, config: Dict):\n", - " pass\n", - "\n", - " @abstractmethod\n", " def create(self, params) -> ClientResponseProtocol:\n", - " pass\n", + " ... # pragma: no cover\n", "\n", - " @abstractmethod\n", " def cost(self, response: ClientResponseProtocol) -> float:\n", - " pass\n", + " ... # pragma: no cover\n", "\n", " @staticmethod\n", " def get_usage(response: ClientResponseProtocol) -> Dict:\n", - " return None\n", + " \"\"\"Return usage summary of the response using RESPONSE_USAGE_KEYS.\"\"\"\n", + " ... # pragma: no cover\n", "\n", "```" ] @@ -128,7 +130,7 @@ "# custom client with custom model loader\n", "\n", "\n", - "class CustomClient(Client):\n", + "class CustomClient:\n", " def __init__(self, config, **kwargs):\n", " print(f\"CustomClient config: {config}\")\n", " self.device = config.get(\"device\", \"cpu\")\n", @@ -189,10 +191,11 @@ " response.cost = 0\n", " return 0\n", "\n", - " def get_usage(self, response):\n", + " @staticmethod\n", + " def get_usage(response):\n", " # returns a dict of prompt_tokens, completion_tokens, total_tokens, cost, model\n", " # if usage needs to be tracked, else None\n", - " return None" + " return {}" ] }, { diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 9507b057bda3..64d68e576e63 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -20,7 +20,7 @@ def test_custom_client(): TEST_OTHER_PARAMS_VAL = "other_params" TEST_MAX_LENGTH = 1000 - class CustomClient(Client): + class CustomClient: def __init__(self, config: Dict, test_hook): self.test_hook = test_hook self.device = config["device"] @@ -56,6 +56,10 @@ def cost(self, response) -> float: response.cost = TEST_COST return TEST_COST + @staticmethod + def get_usage(response) -> Dict: + return {} + config_list = [ { "model": TEST_LOCAL_MODEL_NAME, @@ -86,7 +90,7 @@ def cost(self, response) -> float: def test_registering_with_wrong_cls_name_raises_error(): - class CustomClient(Client): + class CustomClient: def __init__(self, config: Dict): pass @@ -109,7 +113,7 @@ def cost(self, response) -> float: def test_registering_with_wrong_model_name_raises_error(): - class CustomClient(Client): + class CustomClient: def __init__(self, config: Dict): pass @@ -119,6 +123,10 @@ def create(self, params): def cost(self, response) -> float: return 0 + @staticmethod + def get_usage(response) -> Dict: + return {} + config_list = [ { "model": "local_model_name_but_wrong_name", From fe57334ddd83ec9c0eb8b5dc3a14b818688959bf Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 25 Jan 2024 05:25:53 -0500 Subject: [PATCH 13/30] rename and more error checking for registered agents --- autogen/agentchat/conversable_agent.py | 8 +- autogen/oai/client.py | 102 +++++++++++++++++-------- test/oai/test_custom_client.py | 98 +++++++++++++++++++----- 3 files changed, 152 insertions(+), 56 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 19296bcf65ad..9b967e48c761 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1900,15 +1900,15 @@ def _decorator(func: F) -> F: return _decorator - def register_client(self, model: str, client_cls: Client, **kwargs): - """Register a custom client. + def register_model_client(self, model: str, model_client_cls: Client, **kwargs): + """Register a model. Args: model: The model name, as specified in the config list - client_cls: A custom client class that follows the Client interface + model_client_cls: A custom client class that follows the Client interface **kwargs: The kwargs for the custom client class to be initialized with """ - self.client.register_client(model, client_cls, **kwargs) + self.client.register_model_client(model, model_client_cls, **kwargs) def register_hook(self, hookable_method: Callable, hook: Callable): """ diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 164a76d91a98..7ccef88524dd 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -314,18 +314,23 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base openai_config, extra_kwargs = self._separate_openai_config(base_config) if type(config_list) is list and len(config_list) == 0: logger.warning("openai client was provided with an empty config_list, which may not be intended.") + self._clients: List[Client] = [] + self._config_list: List[Dict[str, Any]] = [] + if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying for config in config_list: - self._register_openai_client(config, openai_config) # could modify the config - self._config_list = [ - {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} - for config in config_list - ] + activated = self._register_openai_client(config, openai_config) # could modify the config + self._config_list.append( + { + "config": {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}, + "activated": activated, + } + ) else: - self._register_openai_client(extra_kwargs, openai_config) - self._config_list = [extra_kwargs] + activated = self._register_openai_client(extra_kwargs, openai_config) + self._config_list = [{"config": extra_kwargs, "activated": activated}] def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Separate the config into openai_config and extra_kwargs.""" @@ -339,7 +344,13 @@ def _separate_create_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs} return create_config, extra_kwargs - def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> OpenAIClient: + def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: + openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) + if openai_config["azure_deployment"] is not None: + openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") + openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) + + def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> bool: """Create a client with the given config to override openai_config, after removing extra kwargs. @@ -348,40 +359,55 @@ def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[st "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot from the name and create a client that connects to "gpt-35-turbo" Azure deployment. """ + activated = False openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} api_type = config.get("api_type") - if api_type is not None and api_type.startswith("azure"): - openai_config["azure_deployment"] = openai_config.get("azure_deployment", config.get("model")) - if openai_config["azure_deployment"] is not None: - openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") - openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) - self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) - elif api_type is None: + if api_type is None: self._clients.append(OpenAIClient(OpenAI(**openai_config))) - # else a config for a custom client is set - # skipping until the register_client is called with the appropriate class + activated = True + else: + if api_type.startswith("azure"): + self._configure_azure_openai(config, openai_config) + self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) + activated = True + elif api_type.startswith("custom"): + # else a config for a custom client is set + # skipping until the register_model_client is called with the appropriate class + logging.info( + "Detected custom model client in config, skipping registration until register_model_client is called." + ) + else: + raise ValueError( + f"api_type {api_type} is not supported, please select one from ['azure', 'custom'], or remove and let it default to openai" + ) + return activated - def register_client(self, model: str, client_cls: Client, **kwargs): - """Register a custom client. + def register_model_client(self, model: str, model_client_cls: Client, **kwargs): + """Register a model client client. Args: model: The model name, as specified in the config list - client_cls: A custom client class that follows the Client interface + model_client_cls: A custom client class that follows the Client interface **kwargs: The kwargs for the custom client class to be initialized with """ - found = False - for config in self._config_list: - if ( - config["api_type"] is not None - and config["api_type"] == client_cls.__name__ - and config["model"] == model - ): - client = client_cls(config, **kwargs) + for i, config in enumerate(self._config_list): + config_model = config.get("config", {}).get("model") + activated = config.get("activated") + + if config_model == model: + if activated and model_client_cls.__name__ == self._clients[i].__class__.__name__: + raise ValueError( + f'Model "{model}" with model client "{model_client_cls.__name__}" is already registered.' + ) + + client = model_client_cls(config.get("config", {}), **kwargs) self._clients.append(client) - found = True - if not found: + config["activated"] = True + break + else: raise ValueError( - f'Pair of model "{model}" and api_type "{client_cls.__name__}" was not found in the config_list. Please make sure to include an entry in the config_list with "model": "{model}" and "api_type": "{client_cls.__name__}"' + f'Model "{model}" is being registered but was not found in the config_list. ' + 'Please make sure to include an entry in the config_list with "model": "{model}"' ) @classmethod @@ -465,10 +491,20 @@ def yes_or_no_filter(context, response): raise ERROR last = len(self._clients) - 1 if len(self._clients) == 0: - raise RuntimeError("No model client is registered. Please register a model client first.") + raise RuntimeError( + "No model client is active. Please populate the config list or register any custom model clients." + ) + # Check if all configs in config list are activated + non_activated_confs = [ + config_["config"]["model"] for config_ in self._config_list if not config_.get("activated") + ] + if non_activated_confs: + raise RuntimeError( + f"Model client(s) {non_activated_confs} are not activated. Please register the custom model clients or filter them out form the config list." + ) for i, client in enumerate(self._clients): # merge the input config with the i-th config in the config list - full_config = {**config, **self._config_list[i]} + full_config = {**config, **self._config_list[i]["config"]} # separate the config into create_config and extra_kwargs create_config, extra_kwargs = self._separate_create_config(full_config) api_type = extra_kwargs.get("api_type") diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 64d68e576e63..88ecd8640c0f 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -12,7 +12,7 @@ @pytest.mark.skipif(skip, reason="openai>=1 not installed") -def test_custom_client(): +def test_custom_model_client(): TEST_COST = 20000000 TEST_CUSTOM_RESPONSE = "This is a custom response." TEST_DEVICE = "cpu" @@ -20,7 +20,7 @@ def test_custom_client(): TEST_OTHER_PARAMS_VAL = "other_params" TEST_MAX_LENGTH = 1000 - class CustomClient: + class CustomModel: def __init__(self, config: Dict, test_hook): self.test_hook = test_hook self.device = config["device"] @@ -63,8 +63,8 @@ def get_usage(response) -> Dict: config_list = [ { "model": TEST_LOCAL_MODEL_NAME, + "api_type": "custom", "device": TEST_DEVICE, - "api_type": "CustomClient", "params": { "max_length": TEST_MAX_LENGTH, "other_params": TEST_OTHER_PARAMS_VAL, @@ -75,7 +75,7 @@ def get_usage(response) -> Dict: test_hook = {"called": False} client = OpenAIWrapper(config_list=config_list) - client.register_client(model=TEST_LOCAL_MODEL_NAME, client_cls=CustomClient, test_hook=test_hook) + client.register_model_client(model=TEST_LOCAL_MODEL_NAME, model_client_cls=CustomModel, test_hook=test_hook) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE @@ -89,8 +89,9 @@ def get_usage(response) -> Dict: assert test_hook["max_length"] == TEST_MAX_LENGTH -def test_registering_with_wrong_cls_name_raises_error(): - class CustomClient: +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_registering_with_wrong_model_name_raises_error(): + class CustomModel: def __init__(self, config: Dict): pass @@ -100,20 +101,45 @@ def create(self, params): def cost(self, response) -> float: return 0 + @staticmethod + def get_usage(response) -> Dict: + return {} + config_list = [ { - "model": "local_model_name", - "api_type": "CustomClientButWrongName", + "model": "local_model_name_but_wrong_name", + "api_type": "custom", }, ] client = OpenAIWrapper(config_list=config_list) with pytest.raises(ValueError): - client.register_client(model="local_model_name", client_cls=CustomClient) + client.register_model_client(model="local_model_name", model_client_cls=CustomModel) -def test_registering_with_wrong_model_name_raises_error(): - class CustomClient: +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_no_client_registered_raises_error(): + config_list = [ + { + "model": "local_model_name", + "api_type": "custom", + "device": "cpu", + "params": { + "max_length": 1000, + "other_params": "other_params", + }, + }, + ] + + client = OpenAIWrapper(config_list=config_list) + + with pytest.raises(RuntimeError): + client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + + +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_not_all_clients_registered_raises_error(): + class CustomModel: def __init__(self, config: Dict): pass @@ -129,22 +155,54 @@ def get_usage(response) -> Dict: config_list = [ { - "model": "local_model_name_but_wrong_name", - "api_type": "CustomClient", + "model": "local_model_name", + "api_type": "custom", + "device": "cpu", + "params": { + "max_length": 1000, + "other_params": "other_params", + }, + }, + { + "model": "local_model_name_2", + "api_type": "custom", + "device": "cpu", + "params": { + "max_length": 1000, + "other_params": "other_params", + }, }, ] + client = OpenAIWrapper(config_list=config_list) - with pytest.raises(ValueError): - client.register_client(model="local_model_name", client_cls=CustomClient) + client.register_model_client(model="local_model_name", model_client_cls=CustomModel) + + with pytest.raises(RuntimeError): + client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + + +@pytest.mark.skipif(skip, reason="openai>=1 not installed") +def test_registering_same_client_twice_raises_error(): + class CustomModel: + def __init__(self, config: Dict): + pass + + def create(self, params): + return None + + def cost(self, response) -> float: + return 0 + @staticmethod + def get_usage(response) -> Dict: + return {} -def test_custom_client_not_registered_raises_error(): config_list = [ { "model": "local_model_name", + "api_type": "custom", "device": "cpu", - "api_type": "CustomClient", "params": { "max_length": 1000, "other_params": "other_params", @@ -154,5 +212,7 @@ def test_custom_client_not_registered_raises_error(): client = OpenAIWrapper(config_list=config_list) - with pytest.raises(RuntimeError): - client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) + client.register_model_client(model="local_model_name", model_client_cls=CustomModel) + + with pytest.raises(ValueError): + client.register_model_client(model="local_model_name", model_client_cls=CustomModel) From 4bd1ba14d57e3c33d7e461bc3ee954753e91d5ce Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 25 Jan 2024 05:27:28 -0500 Subject: [PATCH 14/30] update notebook --- notebook/agentchat_custom_model.ipynb | 56 +++++++-------------------- 1 file changed, 14 insertions(+), 42 deletions(-) diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 513dbc518d81..f3bc2374ca1f 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -52,7 +52,6 @@ "outputs": [], "source": [ "import autogen\n", - "from autogen.oai import Client\n", "from autogen import AssistantAgent, UserProxyAgent\n", "from transformers import AutoTokenizer, GenerationConfig, AutoModelForCausalLM\n", "from types import SimpleNamespace" @@ -130,9 +129,9 @@ "# custom client with custom model loader\n", "\n", "\n", - "class CustomClient:\n", + "class CustomModelClient:\n", " def __init__(self, config, **kwargs):\n", - " print(f\"CustomClient config: {config}\")\n", + " print(f\"CustomModelClient config: {config}\")\n", " self.device = config.get(\"device\", \"cpu\")\n", " self.model = AutoModelForCausalLM.from_pretrained(config[\"model\"]).to(self.device)\n", " self.model_name = config[\"model\"]\n", @@ -244,12 +243,12 @@ "\n", "You can add any paramteres that are needed for the custom model loading in the same configuration list, as long as `model` is specified\n", "\n", - "It is important to add the `api_type` with the name of the new client class\n", + "It is important to add the `api_type` and set it to `\"custom\"`.\n", "\n", "```json\n", "{\n", " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", - " \"api_type\": \"CustomClient\",\n", + " \"api_type\": \"custom\",\n", " \"device\": \"cuda\",\n", " \"n\": 1,\n", " \"params\": {\n", @@ -267,7 +266,7 @@ "source": [ "config_list_custom = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", - " filter_dict={\"api_type\": [\"CustomClient\"]},\n", + " filter_dict={\"api_type\": [\"custom\"]},\n", ")" ] }, @@ -311,7 +310,7 @@ "outputs": [], "source": [ "config = config_list_custom[0]\n", - "assistant.register_client(model=config[\"model\"], client_cls=CustomClient)" + "assistant.register_model_client(model=config[\"model\"], model_client_cls=CustomModelClient)" ] }, { @@ -341,9 +340,9 @@ "# custom client with custom model loader\n", "\n", "\n", - "class CustomClientWithArguments(CustomClient):\n", + "class CustomModelClientWithArguments(CustomModelClient):\n", " def __init__(self, config, loaded_model, tokenizer, **kwargs):\n", - " print(f\"CustomClientWithArguments config: {config}\")\n", + " print(f\"CustomModelClientWithArguments config: {config}\")\n", "\n", " self.model_name = config[\"model\"]\n", " self.model = loaded_model\n", @@ -356,37 +355,6 @@ " print(f\"Loaded model {config['model']} to {self.device}\")" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Update the config list to include the CustomClientWithArguments\n", - "\n", - "```json\n", - "{\n", - " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", - " \"api_type\": \"CustomClientWithArguments\",\n", - " \"device\": \"cuda\",\n", - " \"n\": 1,\n", - " \"params\": {\n", - " \"max_length\": 1000,\n", - " }\n", - "},\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "config_list_custom = autogen.config_list_from_json(\n", - " \"OAI_CONFIG_LIST\",\n", - " filter_dict={\"api_type\": [\"CustomClientWithArguments\"]},\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -417,8 +385,11 @@ "metadata": {}, "outputs": [], "source": [ - "assistant.register_client(\n", - " model=config[\"model\"], client_cls=CustomClientWithArguments, loaded_model=loaded_model, tokenizer=tokenizer\n", + "assistant.register_model_client(\n", + " model=config[\"model\"],\n", + " model_client_cls=CustomModelClientWithArguments,\n", + " loaded_model=loaded_model,\n", + " tokenizer=tokenizer,\n", ")" ] }, @@ -428,6 +399,7 @@ "metadata": {}, "outputs": [], "source": [ + "# throw here if user forgot to register the client\n", "user_proxy.initiate_chat(assistant, message=\"Plot a chart of NVDA and TESLA stock price change YTD.\")" ] } From 69c1ab58d27632029e0d228f310df2a11963fe58 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 25 Jan 2024 07:54:42 -0500 Subject: [PATCH 15/30] notebook cleanup and added blog --- notebook/agentchat_custom_model.ipynb | 29 ++-- .../blog/2024-01-26-Custom-Models/index.mdx | 156 ++++++++++++++++++ 2 files changed, 173 insertions(+), 12 deletions(-) create mode 100644 website/blog/2024-01-26-Custom-Models/index.mdx diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index f3bc2374ca1f..5d3e1d94af2a 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -19,7 +19,9 @@ "source": [ "# Agent Chat with custom model loading\n", "\n", - "In this notebook, we demonstrate how a custom model can be defined and loaded, and what interface it needs to comply to.\n", + "In this notebook, we demonstrate how a custom model can be defined and loaded, and what protocol it needs to comply to.\n", + "\n", + "**NOTE: Depending on what model you use, you may need to play with the default prompts of the Agent's**\n", "\n", "## Requirements\n", "\n", @@ -61,9 +63,11 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Create and configure the custom client / model loader\n", + "## Create and configure the custom model\n", + "\n", + "A custom model class can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in client.py and shown below.\n", "\n", - "A custom client can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in client.py\n", + "The response protocol is currently using the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adpotion of this feature easier.\n", "\n", "```python\n", "\n", @@ -98,15 +102,15 @@ " model: str\n", "\n", " def create(self, params) -> ClientResponseProtocol:\n", - " ... # pragma: no cover\n", + " ...\n", "\n", " def cost(self, response: ClientResponseProtocol) -> float:\n", - " ... # pragma: no cover\n", + " ...\n", "\n", " @staticmethod\n", " def get_usage(response: ClientResponseProtocol) -> Dict:\n", " \"\"\"Return usage summary of the response using RESPONSE_USAGE_KEYS.\"\"\"\n", - " ... # pragma: no cover\n", + " ...\n", "\n", "```" ] @@ -117,7 +121,9 @@ "source": [ "## Example of simple custom client\n", "\n", - "Following the huggingface example for using [Mistral's Open-Orca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca)" + "Following the huggingface example for using [Mistral's Open-Orca](https://huggingface.co/Open-Orca/Mistral-7B-OpenOrca)\n", + "\n", + "For the response object, python's `SimpleNamespace` is used to create a simple object that can be used to store the response data, but any object that follows the `ClientResponseProtocol` can be used.\n" ] }, { @@ -146,8 +152,8 @@ " print(f\"Loaded model {config['model']} to {self.device}\")\n", "\n", " def create(self, params):\n", - " if params.get(\"stream\", False) and \"messages\" in params and \"functions\" not in params:\n", - " raise NotImplementedError(\"Local models do not support streaming or functions\")\n", + " if params.get(\"stream\", False) and \"messages\" in params:\n", + " raise NotImplementedError(\"Local models do not support streaming.\")\n", " else:\n", " num_of_responses = params.get(\"n\", 1)\n", "\n", @@ -319,7 +325,7 @@ "metadata": {}, "outputs": [], "source": [ - "user_proxy.initiate_chat(assistant, message=\"Plot a chart of NVDA and TESLA stock price change YTD.\")" + "user_proxy.initiate_chat(assistant, message=\"Write python code to print Hello World!\")" ] }, { @@ -399,8 +405,7 @@ "metadata": {}, "outputs": [], "source": [ - "# throw here if user forgot to register the client\n", - "user_proxy.initiate_chat(assistant, message=\"Plot a chart of NVDA and TESLA stock price change YTD.\")" + "user_proxy.initiate_chat(assistant, message=\"Write python code to print Hello World!\")" ] } ], diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx new file mode 100644 index 000000000000..3958c42d6d56 --- /dev/null +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -0,0 +1,156 @@ +--- +title: "AutoGen with Custom Models: Empowering Users to Use Their Own Inference Mechanism" +authors: + - olgavrou +tags: [AutoGen] +--- + +## TL;DR + +AutoGen now supports custom models! This feature empowers users to define and load their own models, allowing for a more flexible and personalized inference mechanism. By adhering to a specific protocol, you can integrate your custom model for use with AutoGen and respond to prompts any way needed by using any model/API call/hardcoded response you want. + +**NOTE: Depending on what model you use, you may need to play with the default prompts of the Agent's** + +## Quickstart + +An interactive and easy way to get started is by following the notebook [here]() which loads a local model from HuggingFace into AutoGen and uses it for inference, and making changes to the class provided. + +### Step 1: Create the custom model client class + +To get started with using custom models in AutoGen, you need to create a model client class that adheres to the `Client` protocol defined in `client.py`. The `Client` class should implement these methods: + +- `create()`: Returns a response object that implements the `ClientResponseProtocol` (more details in the Protocol section). +- `cost()`: Returns the cost of the response. +- `get_usage()`: Returns a dictionary with keys from `RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]`. + +E.g. of a bare bones dummy custom class: + +```python +class CustomModelClient: + def __init__(self, config, **kwargs): + print(f"CustomModelClient config: {config}") + + def create(self, params): + num_of_responses = params.get("n", 1) + + # can create my own data response class + # here using SimpleNamespace for simplicity + # as long as it adheres to the ClientResponseProtocol + + response = SimpleNamespace() + response.choices = [] + response.model = "model_name" # should match the OAI_CONFIG_LIST registration + + for _ in range(num_of_responses): + text = "this is a dummy text response" + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = text + choice.message.function_call = None + response.choices.append(choice) + return response + + def cost(self, response) -> float: + response.cost = 0 + return 0 + + @staticmethod + def get_usage(response): + return {} +``` + +### Step 2: Add the configuration to the OAI_CONFIG_LIST + +The two fields that are necessary in the config are `"model" [str]` and `"api_type":"custom"`. Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.: + +```json +{ + "model": "Open-Orca/Mistral-7B-OpenOrca", + "api_type": "custom", + "device": "cuda", + "n": 1, + "params": { + "max_length": 1000, + } +} +``` + +### Step 3: Register the new custom model to the agent that will use it + +If a configuration with the field `"api_type":"custom"` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized: + +```python +my_agent.register_model_client(model="Open-Orca/Mistral-7B-OpenOrca", model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor]) +``` + +`model` matches the one specified in the `OAI_CONFIG_LIST` and `CustomModelClient` is the class that adheres to the `Client` protocol (more details on the protocol below). + +If the new model client is in the config list but not registered by the time the chat is initialized, then an error will be raised. + +## Protocol details + +A custom model class can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in `client.py` and shown below. + +The response protocol is currently using the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adpotion of this feature easier. + +```python + +class Client(Protocol): + """ + A client class must implement the following methods: + - create must return a response object that implements the ClientResponseProtocol + - cost must return the cost of the response + - get_usage must return a dict with the following keys: + - prompt_tokens + - completion_tokens + - total_tokens + - cost + - model + + This class is used to create a client that can be used by OpenAIWrapper. + It mimics the OpenAI class, but allows for custom clients to be used. + """ + + RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] + + class ClientResponseProtocol(Protocol): + class Choice(Protocol): + class Message(Protocol): + content: str | None + function_call: str | None + + choices: List[Choice] + config_id: int + cost: float + pass_filter: bool + model: str + + def create(self, params) -> ClientResponseProtocol: + ... + + def cost(self, response: ClientResponseProtocol) -> float: + ... + + @staticmethod + def get_usage(response: ClientResponseProtocol) -> Dict: + """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" + ... + +``` + +## Troubleshooting steps + +If something doesn't work then run through the checklist: + +- Make sure you have followed the client protocol and client response protocol when creating the custom model class + - `create()` method: `ClientResponseProtocol` must be followed when returning an inference response during `create` call. + - `cost()`method: returns an integer, and if you don't care about cost tracking you can just return `0`. + - `get_usage()`: returns a dictionary, and if you don't care about usage tracking you can just return an empty dictionary `{}`. +- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"model" [str]` and `"api_type":"custom"` fields. +- Make sure you have registered the client using the corresponding config entry and your new class `agent.register_model_client(model="", model_client_cls=, [other optional args])` +- Make sure you have registered only unique pairs of (`model`, `model_client_cls`) and don't try to register the same pair twice. +- Any other troubleshooting might need to be done in the custom code itself. + +## Conclusion + +With the ability to use custom models, AutoGen now offers even more flexibility and power for your AI applications. Whether you've trained your own model or want to use a specific pre-trained model, AutoGen can accommodate your needs. Happy coding! From 2127effff027727c4e81c74f5094e98bf7d9b6eb Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 25 Jan 2024 10:12:22 -0500 Subject: [PATCH 16/30] add link --- notebook/agentchat_custom_model.ipynb | 4 +--- website/blog/2024-01-26-Custom-Models/index.mdx | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 5d3e1d94af2a..8d1f5b9e93b8 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -70,7 +70,6 @@ "The response protocol is currently using the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adpotion of this feature easier.\n", "\n", "```python\n", - "\n", "class Client(Protocol):\n", " \"\"\"\n", " A client class must implement the following methods:\n", @@ -111,8 +110,7 @@ " def get_usage(response: ClientResponseProtocol) -> Dict:\n", " \"\"\"Return usage summary of the response using RESPONSE_USAGE_KEYS.\"\"\"\n", " ...\n", - "\n", - "```" + "```\n" ] }, { diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx index 3958c42d6d56..a95e71972774 100644 --- a/website/blog/2024-01-26-Custom-Models/index.mdx +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -13,7 +13,7 @@ AutoGen now supports custom models! This feature empowers users to define and lo ## Quickstart -An interactive and easy way to get started is by following the notebook [here]() which loads a local model from HuggingFace into AutoGen and uses it for inference, and making changes to the class provided. +An interactive and easy way to get started is by following the notebook [here](https://github.com/microsoft/autogen/blob/main/notebook/agentchat_custom_model.ipynb) which loads a local model from HuggingFace into AutoGen and uses it for inference, and making changes to the class provided. ### Step 1: Create the custom model client class From 8396e304bfa9b4043730feac0a074156b1a8da2e Mon Sep 17 00:00:00 2001 From: olgavrou Date: Thu, 25 Jan 2024 23:35:38 -0500 Subject: [PATCH 17/30] adding message retrieval to client protocol for more flexible response --- autogen/oai/client.py | 75 +++++++++++-------- notebook/agentchat_custom_model.ipynb | 29 +++++-- test/oai/test_custom_client.py | 39 ++++++---- .../blog/2024-01-26-Custom-Models/index.mdx | 24 ++++-- 4 files changed, 106 insertions(+), 61 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 7ccef88524dd..dcb7b8c26804 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -66,7 +66,8 @@ class Client(Protocol): - model This class is used to create a client that can be used by OpenAIWrapper. - It mimics the OpenAI class, but allows for custom clients to be used. + The response returned from create must adhere to the ClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. """ RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] @@ -75,17 +76,24 @@ class ClientResponseProtocol(Protocol): class Choice(Protocol): class Message(Protocol): content: str | None - function_call: str | None choices: List[Choice] - config_id: int - cost: float - pass_filter: bool model: str def create(self, params) -> ClientResponseProtocol: ... # pragma: no cover + def message_retrieval( + self, response: ClientResponseProtocol + ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... # pragma: no cover + def cost(self, response: ClientResponseProtocol) -> float: ... # pragma: no cover @@ -101,6 +109,27 @@ class OpenAIClient: def __init__(self, client): self._oai_client = client + def message_retrieval( + self, response: Union[ChatCompletion, Completion] + ) -> Union[List[str], List[ChatCompletionMessage]]: + """Retrieve the messages from the response.""" + choices = response.choices + if isinstance(response, Completion): + return [choice.text for choice in choices] # type: ignore [union-attr] + + if TOOL_ENABLED: + return [ # type: ignore [return-value] + choice.message # type: ignore [union-attr] + if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] + else choice.message.content # type: ignore [union-attr] + for choice in choices + ] + else: + return [ # type: ignore [return-value] + choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] + for choice in choices + ] + def create(self, params: Dict[str, Any]) -> ChatCompletion: """Create a completion for a given config using openai's client. @@ -454,9 +483,9 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: ] return params - def create(self, **config: Any) -> ChatCompletion: - """Make a completion for a given config using openai's clients. - Besides the kwargs allowed in openai's client, we allow the following additional kwargs. + def create(self, **config: Any) -> Client.ClientResponseProtocol: + """Make a completion for a given config using available clients. + Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. The config in each client will be overridden by the config. Args: @@ -533,7 +562,7 @@ def yes_or_no_filter(context, response): with cache_client as cache: # Try to get the response from cache key = get_key(params) - response: ChatCompletion = cache.get(key, None) + response: Client.ClientResponseProtocol = cache.get(key, None) if response is not None: try: @@ -563,6 +592,7 @@ def yes_or_no_filter(context, response): if i == last: raise else: + response.message_retrieval_function = client.message_retrieval # add cost calculation before caching no matter filter is passed or not response.cost = client.cost(response) actual_usage = client.get_usage(response) @@ -783,8 +813,8 @@ def clear_usage_summary(self) -> None: @classmethod def extract_text_or_completion_object( - cls, response: Union[ChatCompletion, Completion] - ) -> Union[List[str], List[ChatCompletionMessage]]: + cls, response: Client.ClientResponseProtocol + ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]: """Extract the text or ChatCompletion objects from a completion or chat response. Args: @@ -793,28 +823,7 @@ def extract_text_or_completion_object( Returns: A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present. """ - choices = response.choices - if isinstance(response, Completion): - return [choice.text for choice in choices] # type: ignore [union-attr] - - if not isinstance(response, ChatCompletion) and not isinstance(response, Completion): - return [ - choice.message if choice.message.function_call is not None else choice.message.content - for choice in choices - ] - - if TOOL_ENABLED: - return [ # type: ignore [return-value] - choice.message # type: ignore [union-attr] - if choice.message.function_call is not None or choice.message.tool_calls is not None # type: ignore [union-attr] - else choice.message.content # type: ignore [union-attr] - for choice in choices - ] - else: - return [ # type: ignore [return-value] - choice.message if choice.message.function_call is not None else choice.message.content # type: ignore [union-attr] - for choice in choices - ] + return response.message_retrieval_function(response) # TODO: logging diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 8d1f5b9e93b8..7b8c7157fd8f 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -67,7 +67,9 @@ "\n", "A custom model class can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in client.py and shown below.\n", "\n", - "The response protocol is currently using the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adpotion of this feature easier.\n", + "The response protocol has some minimum requirements, but can be extended to include any additional information that is needed.\n", + "Message retrieval therefore can be customized, but needs to return a list of strings or a list of `ClientResponseProtocol.Choice.Message` objects.\n", + "\n", "\n", "```python\n", "class Client(Protocol):\n", @@ -83,7 +85,8 @@ " - model\n", "\n", " This class is used to create a client that can be used by OpenAIWrapper.\n", - " It mimics the OpenAI class, but allows for custom clients to be used.\n", + " The response returned from create must adhere to the ClientResponseProtocol but can be extended however needed.\n", + " The message_retrieval method must be implemented to return a list of str or a list of messages from the response.\n", " \"\"\"\n", "\n", " RESPONSE_USAGE_KEYS = [\"prompt_tokens\", \"completion_tokens\", \"total_tokens\", \"cost\", \"model\"]\n", @@ -92,17 +95,24 @@ " class Choice(Protocol):\n", " class Message(Protocol):\n", " content: str | None\n", - " function_call: str | None\n", "\n", " choices: List[Choice]\n", - " config_id: int\n", - " cost: float\n", - " pass_filter: bool\n", " model: str\n", "\n", " def create(self, params) -> ClientResponseProtocol:\n", " ...\n", "\n", + " def message_retrieval(\n", + " self, response: ClientResponseProtocol\n", + " ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]:\n", + " \"\"\"\n", + " Retrieve and return a list of strings or a list of Choice.Message from the response.\n", + "\n", + " NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object,\n", + " since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used.\n", + " \"\"\"\n", + " ...\n", + "\n", " def cost(self, response: ClientResponseProtocol) -> float:\n", " ...\n", "\n", @@ -189,6 +199,11 @@ "\n", " return response\n", "\n", + " def message_retrieval(self, response):\n", + " \"\"\"Retrieve the messages from the response.\"\"\"\n", + " choices = response.choices\n", + " return [choice.message.content for choice in choices]\n", + "\n", " def cost(self, response) -> float:\n", " \"\"\"Calculate the cost of the response.\"\"\"\n", " response.cost = 0\n", @@ -290,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n", + "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom, \"cache_seed\": None})\n", "user_proxy = UserProxyAgent(\n", " \"user_proxy\",\n", " code_execution_config={\n", diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 88ecd8640c0f..be019d529329 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -35,21 +35,20 @@ def __init__(self, config: Dict, test_hook): self.test_hook["max_length"] = self.max_length def create(self, params): - if params.get("stream", False) and "messages" in params and "functions" not in params: - raise NotImplementedError("Custom Client does not support streaming or functions") - else: - from types import SimpleNamespace - - response = SimpleNamespace() - # need to follow Client.ClientResponseProtocol - response.choices = [] - choice = SimpleNamespace() - choice.message = SimpleNamespace() - choice.message.content = TEST_CUSTOM_RESPONSE - choice.message.function_call = None - response.choices.append(choice) - response.model = self.model - return response + from types import SimpleNamespace + + response = SimpleNamespace() + # need to follow Client.ClientResponseProtocol + response.choices = [] + choice = SimpleNamespace() + choice.message = SimpleNamespace() + choice.message.content = TEST_CUSTOM_RESPONSE + response.choices.append(choice) + response.model = self.model + return response + + def message_retrieval(self, response): + return [response.choices[0].message.content] def cost(self, response) -> float: """Calculate the cost of the response.""" @@ -79,7 +78,6 @@ def get_usage(response) -> Dict: response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE - assert response.choices[0].message.function_call is None assert response.cost == TEST_COST assert test_hook["called"] @@ -98,6 +96,9 @@ def __init__(self, config: Dict): def create(self, params): return None + def message_retrieval(self, response): + return [] + def cost(self, response) -> float: return 0 @@ -146,6 +147,9 @@ def __init__(self, config: Dict): def create(self, params): return None + def message_retrieval(self, response): + return [] + def cost(self, response) -> float: return 0 @@ -191,6 +195,9 @@ def __init__(self, config: Dict): def create(self, params): return None + def message_retrieval(self, response): + return [] + def cost(self, response) -> float: return 0 diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx index a95e71972774..987b8d31c3d1 100644 --- a/website/blog/2024-01-26-Custom-Models/index.mdx +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -20,6 +20,7 @@ An interactive and easy way to get started is by following the notebook [here](h To get started with using custom models in AutoGen, you need to create a model client class that adheres to the `Client` protocol defined in `client.py`. The `Client` class should implement these methods: - `create()`: Returns a response object that implements the `ClientResponseProtocol` (more details in the Protocol section). +- `message_retrieval()`: Processes the response object and returns a list of strings or a list of message objects (more details in the Protocol section). - `cost()`: Returns the cost of the response. - `get_usage()`: Returns a dictionary with keys from `RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]`. @@ -50,6 +51,10 @@ class CustomModelClient: response.choices.append(choice) return response + def message_retrieval(self, response): + choices = response.choices + return [choice.message.content for choice in choices] + def cost(self, response) -> float: response.cost = 0 return 0 @@ -108,7 +113,8 @@ class Client(Protocol): - model This class is used to create a client that can be used by OpenAIWrapper. - It mimics the OpenAI class, but allows for custom clients to be used. + The response returned from create must adhere to the ClientResponseProtocol but can be extended however needed. + The message_retrieval method must be implemented to return a list of str or a list of messages from the response. """ RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] @@ -117,17 +123,24 @@ class Client(Protocol): class Choice(Protocol): class Message(Protocol): content: str | None - function_call: str | None choices: List[Choice] - config_id: int - cost: float - pass_filter: bool model: str def create(self, params) -> ClientResponseProtocol: ... + def message_retrieval( + self, response: ClientResponseProtocol + ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]: + """ + Retrieve and return a list of strings or a list of Choice.Message from the response. + + NOTE: if a list of Choice.Message is returned, it currently needs to contain the fields of OpenAI's ChatCompletion Message object, + since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. + """ + ... + def cost(self, response: ClientResponseProtocol) -> float: ... @@ -144,6 +157,7 @@ If something doesn't work then run through the checklist: - Make sure you have followed the client protocol and client response protocol when creating the custom model class - `create()` method: `ClientResponseProtocol` must be followed when returning an inference response during `create` call. + - `message_retrieval()` method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. - `cost()`method: returns an integer, and if you don't care about cost tracking you can just return `0`. - `get_usage()`: returns a dictionary, and if you don't care about usage tracking you can just return an empty dictionary `{}`. - Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"model" [str]` and `"api_type":"custom"` fields. From fe0ffb50a0f34226f7dd51db83cad8c0b88132ee Mon Sep 17 00:00:00 2001 From: olgavrou Date: Mon, 29 Jan 2024 17:25:16 -0500 Subject: [PATCH 18/30] Update autogen/oai/client.py Co-authored-by: Eric Zhu --- autogen/oai/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 7ccef88524dd..0514c56e15fe 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -383,7 +383,7 @@ def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[st return activated def register_model_client(self, model: str, model_client_cls: Client, **kwargs): - """Register a model client client. + """Register a model client. Args: model: The model name, as specified in the config list From a8563a2ad3387af932808eca9a351c7686bd7e60 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 30 Jan 2024 17:44:46 +0000 Subject: [PATCH 19/30] don't add retrieval function to cache --- autogen/oai/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index d2dd5e052eb2..40ba7a029a4e 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -565,6 +565,7 @@ def yes_or_no_filter(context, response): response: Client.ClientResponseProtocol = cache.get(key, None) if response is not None: + response.message_retrieval_function = client.message_retrieval try: response.cost # type: ignore [attr-defined] except AttributeError: @@ -592,7 +593,6 @@ def yes_or_no_filter(context, response): if i == last: raise else: - response.message_retrieval_function = client.message_retrieval # add cost calculation before caching no matter filter is passed or not response.cost = client.cost(response) actual_usage = client.get_usage(response) @@ -603,6 +603,7 @@ def yes_or_no_filter(context, response): with cache_client as cache: cache.set(key, response) + response.message_retrieval_function = client.message_retrieval # check the filter pass_filter = filter_func is None or filter_func(context=context, response=response) if pass_filter or i == last: From ed06d75484a50f2baa4c990e213458ca904950f8 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Tue, 30 Jan 2024 19:53:44 +0000 Subject: [PATCH 20/30] added placeholder cllient class during initial client init, and rewrote registration --- autogen/agentchat/conversable_agent.py | 5 +- autogen/oai/client.py | 80 +++++++++---------- notebook/agentchat_custom_model.ipynb | 46 +++++++++-- test/oai/test_custom_client.py | 59 +++----------- .../blog/2024-01-26-Custom-Models/index.mdx | 16 ++-- 5 files changed, 95 insertions(+), 111 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index db1af9953a19..abd85c0ae99a 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1914,15 +1914,14 @@ def _decorator(func: F) -> F: return _decorator - def register_model_client(self, model: str, model_client_cls: Client, **kwargs): + def register_model_client(self, model_client_cls: Client, **kwargs): """Register a model. Args: - model: The model name, as specified in the config list model_client_cls: A custom client class that follows the Client interface **kwargs: The kwargs for the custom client class to be initialized with """ - self.client.register_model_client(model, model_client_cls, **kwargs) + self.client.register_model_client(model_client_cls, **kwargs) def register_hook(self, hookable_method: Callable, hook: Callable): """ diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 40ba7a029a4e..f58a2f8896c5 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -103,6 +103,11 @@ def get_usage(response: ClientResponseProtocol) -> Dict: ... # pragma: no cover +class PlaceHolderClient: + def __init__(self, config): + self.config = config + + class OpenAIClient: """Follows the Client protocol and wraps the OpenAI client.""" @@ -350,16 +355,13 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base if config_list: config_list = [config.copy() for config in config_list] # make a copy before modifying for config in config_list: - activated = self._register_openai_client(config, openai_config) # could modify the config + self._register_default_client(config, openai_config) # could modify the config self._config_list.append( - { - "config": {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}, - "activated": activated, - } + {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}} ) else: - activated = self._register_openai_client(extra_kwargs, openai_config) - self._config_list = [{"config": extra_kwargs, "activated": activated}] + self._register_default_client(extra_kwargs, openai_config) + self._config_list = [extra_kwargs] def _separate_openai_config(self, config: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Separate the config into openai_config and extra_kwargs.""" @@ -379,7 +381,7 @@ def _configure_azure_openai(self, config: Dict[str, Any], openai_config: Dict[st openai_config["azure_deployment"] = openai_config["azure_deployment"].replace(".", "") openai_config["azure_endpoint"] = openai_config.get("azure_endpoint", openai_config.pop("base_url", None)) - def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> bool: + def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[str, Any]) -> None: """Create a client with the given config to override openai_config, after removing extra kwargs. @@ -388,55 +390,49 @@ def _register_openai_client(self, config: Dict[str, Any], openai_config: Dict[st "gpt-35-turbo" and define model "gpt-3.5-turbo" in the config the function will remove the dot from the name and create a client that connects to "gpt-35-turbo" Azure deployment. """ - activated = False openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} api_type = config.get("api_type") if api_type is None: self._clients.append(OpenAIClient(OpenAI(**openai_config))) - activated = True else: if api_type.startswith("azure"): self._configure_azure_openai(config, openai_config) self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) - activated = True - elif api_type.startswith("custom"): + else: # else a config for a custom client is set # skipping until the register_model_client is called with the appropriate class - logging.info( - "Detected custom model client in config, skipping registration until register_model_client is called." - ) - else: - raise ValueError( - f"api_type {api_type} is not supported, please select one from ['azure', 'custom'], or remove and let it default to openai" + self._clients.append(PlaceHolderClient(config)) + logger.info( + f"Detected custom model client in config: {api_type}, skipping registration until register_model_client is called." ) - return activated - def register_model_client(self, model: str, model_client_cls: Client, **kwargs): + def register_model_client(self, model_client_cls: Client, **kwargs): """Register a model client. Args: - model: The model name, as specified in the config list model_client_cls: A custom client class that follows the Client interface **kwargs: The kwargs for the custom client class to be initialized with """ - for i, config in enumerate(self._config_list): - config_model = config.get("config", {}).get("model") - activated = config.get("activated") - - if config_model == model: - if activated and model_client_cls.__name__ == self._clients[i].__class__.__name__: - raise ValueError( - f'Model "{model}" with model client "{model_client_cls.__name__}" is already registered.' - ) - - client = model_client_cls(config.get("config", {}), **kwargs) - self._clients.append(client) - config["activated"] = True - break + existing_client_class = False + for i, client in enumerate(self._clients): + if isinstance(client, PlaceHolderClient): + placeholder_config = client.config + + if placeholder_config in self._config_list: + if placeholder_config.get("api_type") == model_client_cls.__name__: + self._clients[i] = model_client_cls(placeholder_config, **kwargs) + return + elif isinstance(client, model_client_cls): + existing_client_class = True + + if existing_client_class: + logger.warn( + f"Model client {model_client_cls.__name__} is already registered. Add more entires in the config_list to use multiple model clients." + ) else: raise ValueError( - f'Model "{model}" is being registered but was not found in the config_list. ' - 'Please make sure to include an entry in the config_list with "model": "{model}"' + f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. ' + f'Please make sure to include an entry in the config_list with "api_type": "{model_client_cls.__name__}"' ) @classmethod @@ -524,16 +520,14 @@ def yes_or_no_filter(context, response): "No model client is active. Please populate the config list or register any custom model clients." ) # Check if all configs in config list are activated - non_activated_confs = [ - config_["config"]["model"] for config_ in self._config_list if not config_.get("activated") - ] - if non_activated_confs: + non_activated = [client.config["api_type"] for client in self._clients if isinstance(client, PlaceHolderClient)] + if non_activated: raise RuntimeError( - f"Model client(s) {non_activated_confs} are not activated. Please register the custom model clients or filter them out form the config list." + f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out form the config list." ) for i, client in enumerate(self._clients): # merge the input config with the i-th config in the config list - full_config = {**config, **self._config_list[i]["config"]} + full_config = {**config, **self._config_list[i]} # separate the config into create_config and extra_kwargs create_config, extra_kwargs = self._separate_create_config(full_config) api_type = extra_kwargs.get("api_type") diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 7b8c7157fd8f..82a6ba9bc5bf 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -260,14 +260,14 @@ "source": [ "## Set the config for the custom model\n", "\n", - "You can add any paramteres that are needed for the custom model loading in the same configuration list, as long as `model` is specified\n", + "You can add any paramteres that are needed for the custom model loading in the same configuration list.\n", "\n", - "It is important to add the `api_type` and set it to `\"custom\"`.\n", + "It is important to add the `api_type` and set it to a string that corresponds to the class name: `\"CustomModelClient\"`.\n", "\n", "```json\n", "{\n", " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", - " \"api_type\": \"custom\",\n", + " \"api_type\": \"CustomModelClient\",\n", " \"device\": \"cuda\",\n", " \"n\": 1,\n", " \"params\": {\n", @@ -285,7 +285,7 @@ "source": [ "config_list_custom = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", - " filter_dict={\"api_type\": [\"custom\"]},\n", + " filter_dict={\"api_type\": [\"CustomModelClient\"]},\n", ")" ] }, @@ -305,7 +305,7 @@ "metadata": {}, "outputs": [], "source": [ - "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom, \"cache_seed\": None})\n", + "assistant = AssistantAgent(\"assistant\", llm_config={\"config_list\": config_list_custom})\n", "user_proxy = UserProxyAgent(\n", " \"user_proxy\",\n", " code_execution_config={\n", @@ -329,7 +329,7 @@ "outputs": [], "source": [ "config = config_list_custom[0]\n", - "assistant.register_model_client(model=config[\"model\"], model_client_cls=CustomModelClient)" + "assistant.register_model_client(model_client_cls=CustomModelClient)" ] }, { @@ -389,6 +389,37 @@ "tokenizer.pad_token_id = tokenizer.eos_token_id" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Add the config of the new custom model\n", + "\n", + "```json\n", + "{\n", + " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", + " \"api_type\": \"CustomModelClientWithArguments\",\n", + " \"device\": \"cuda\",\n", + " \"n\": 1,\n", + " \"params\": {\n", + " \"max_length\": 1000,\n", + " }\n", + "},\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "config_list_custom = autogen.config_list_from_json(\n", + " \"OAI_CONFIG_LIST\",\n", + " filter_dict={\"api_type\": [\"CustomModelClientWithArguments\"]},\n", + ")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -405,7 +436,6 @@ "outputs": [], "source": [ "assistant.register_model_client(\n", - " model=config[\"model\"],\n", " model_client_cls=CustomModelClientWithArguments,\n", " loaded_model=loaded_model,\n", " tokenizer=tokenizer,\n", @@ -438,7 +468,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.10" }, "vscode": { "interpreter": { diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index be019d529329..ff188b072c41 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -62,7 +62,7 @@ def get_usage(response) -> Dict: config_list = [ { "model": TEST_LOCAL_MODEL_NAME, - "api_type": "custom", + "api_type": "CustomModel", "device": TEST_DEVICE, "params": { "max_length": TEST_MAX_LENGTH, @@ -74,7 +74,7 @@ def get_usage(response) -> Dict: test_hook = {"called": False} client = OpenAIWrapper(config_list=config_list) - client.register_model_client(model=TEST_LOCAL_MODEL_NAME, model_client_cls=CustomModel, test_hook=test_hook) + client.register_model_client(model_client_cls=CustomModel, test_hook=test_hook) response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) assert response.choices[0].message.content == TEST_CUSTOM_RESPONSE @@ -88,7 +88,7 @@ def get_usage(response) -> Dict: @pytest.mark.skipif(skip, reason="openai>=1 not installed") -def test_registering_with_wrong_model_name_raises_error(): +def test_registering_with_wrong_class_name_raises_error(): class CustomModel: def __init__(self, config: Dict): pass @@ -108,14 +108,14 @@ def get_usage(response) -> Dict: config_list = [ { - "model": "local_model_name_but_wrong_name", - "api_type": "custom", + "model": "local_model_name", + "api_type": "CustomModelWrongName", }, ] client = OpenAIWrapper(config_list=config_list) with pytest.raises(ValueError): - client.register_model_client(model="local_model_name", model_client_cls=CustomModel) + client.register_model_client(model_client_cls=CustomModel) @pytest.mark.skipif(skip, reason="openai>=1 not installed") @@ -123,7 +123,7 @@ def test_no_client_registered_raises_error(): config_list = [ { "model": "local_model_name", - "api_type": "custom", + "api_type": "CustomModel", "device": "cpu", "params": { "max_length": 1000, @@ -160,7 +160,7 @@ def get_usage(response) -> Dict: config_list = [ { "model": "local_model_name", - "api_type": "custom", + "api_type": "CustomModel", "device": "cpu", "params": { "max_length": 1000, @@ -169,7 +169,7 @@ def get_usage(response) -> Dict: }, { "model": "local_model_name_2", - "api_type": "custom", + "api_type": "CustomModel", "device": "cpu", "params": { "max_length": 1000, @@ -180,46 +180,7 @@ def get_usage(response) -> Dict: client = OpenAIWrapper(config_list=config_list) - client.register_model_client(model="local_model_name", model_client_cls=CustomModel) + client.register_model_client(model_client_cls=CustomModel) with pytest.raises(RuntimeError): client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) - - -@pytest.mark.skipif(skip, reason="openai>=1 not installed") -def test_registering_same_client_twice_raises_error(): - class CustomModel: - def __init__(self, config: Dict): - pass - - def create(self, params): - return None - - def message_retrieval(self, response): - return [] - - def cost(self, response) -> float: - return 0 - - @staticmethod - def get_usage(response) -> Dict: - return {} - - config_list = [ - { - "model": "local_model_name", - "api_type": "custom", - "device": "cpu", - "params": { - "max_length": 1000, - "other_params": "other_params", - }, - }, - ] - - client = OpenAIWrapper(config_list=config_list) - - client.register_model_client(model="local_model_name", model_client_cls=CustomModel) - - with pytest.raises(ValueError): - client.register_model_client(model="local_model_name", model_client_cls=CustomModel) diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx index 987b8d31c3d1..14c5cc9bd895 100644 --- a/website/blog/2024-01-26-Custom-Models/index.mdx +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -66,12 +66,12 @@ class CustomModelClient: ### Step 2: Add the configuration to the OAI_CONFIG_LIST -The two fields that are necessary in the config are `"model" [str]` and `"api_type":"custom"`. Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.: +The field that is necessary is setting `api_type` to the name of the new class (as a string) `"api_type":"CustomModelClient"`. Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.: ```json { "model": "Open-Orca/Mistral-7B-OpenOrca", - "api_type": "custom", + "api_type": "CustomModelClient", "device": "cuda", "n": 1, "params": { @@ -82,13 +82,13 @@ The two fields that are necessary in the config are `"model" [str]` and `"api_ty ### Step 3: Register the new custom model to the agent that will use it -If a configuration with the field `"api_type":"custom"` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized: +If a configuration with the field `"api_type":""` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized: ```python -my_agent.register_model_client(model="Open-Orca/Mistral-7B-OpenOrca", model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor]) +my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor]) ``` -`model` matches the one specified in the `OAI_CONFIG_LIST` and `CustomModelClient` is the class that adheres to the `Client` protocol (more details on the protocol below). +`model_client_cls=CustomModelClient` arg matches the one specified in the `OAI_CONFIG_LIST` and `CustomModelClient` is the class that adheres to the `Client` protocol (more details on the protocol below). If the new model client is in the config list but not registered by the time the chat is initialized, then an error will be raised. @@ -160,9 +160,9 @@ If something doesn't work then run through the checklist: - `message_retrieval()` method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. - `cost()`method: returns an integer, and if you don't care about cost tracking you can just return `0`. - `get_usage()`: returns a dictionary, and if you don't care about usage tracking you can just return an empty dictionary `{}`. -- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"model" [str]` and `"api_type":"custom"` fields. -- Make sure you have registered the client using the corresponding config entry and your new class `agent.register_model_client(model="", model_client_cls=, [other optional args])` -- Make sure you have registered only unique pairs of (`model`, `model_client_cls`) and don't try to register the same pair twice. +- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"api_type":""` field. +- Make sure you have registered the client using the corresponding config entry and your new class `agent.register_model_client(model_client_cls=, [other optional args])` +- Make sure that all of the custom models defined in the `OAI_CONFIG_LIST` have been registered. - Any other troubleshooting might need to be done in the custom code itself. ## Conclusion From a737045508547fc5c4afb43469caa672c051aed6 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 26 Jan 2024 03:17:59 -0500 Subject: [PATCH 21/30] fix spelling --- autogen/oai/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index f58a2f8896c5..3183244f3bd4 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -427,7 +427,7 @@ def register_model_client(self, model_client_cls: Client, **kwargs): if existing_client_class: logger.warn( - f"Model client {model_client_cls.__name__} is already registered. Add more entires in the config_list to use multiple model clients." + f"Model client {model_client_cls.__name__} is already registered. Add more entries in the config_list to use multiple model clients." ) else: raise ValueError( From 1aa74df2a815ebb02ebf4b9b67333f43f52a77f7 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 26 Jan 2024 08:04:53 -0500 Subject: [PATCH 22/30] fix failing openai test --- autogen/oai/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 3183244f3bd4..93e202ee6015 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -177,7 +177,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: if function_call_chunk: # Handle function call if function_call_chunk: - full_function_call, completion_tokens = self._update_function_call_from_chunk( + full_function_call, completion_tokens = OpenAIWrapper._update_function_call_from_chunk( function_call_chunk, full_function_call, completion_tokens ) if not content: @@ -195,7 +195,7 @@ def create(self, params: Dict[str, Any]) -> ChatCompletion: # in case ix is not sequential full_tool_calls = full_tool_calls + [None] * (ix - len(full_tool_calls) + 1) - full_tool_calls[ix], completion_tokens = self._update_tool_calls_from_chunk( + full_tool_calls[ix], completion_tokens = OpenAIWrapper._update_tool_calls_from_chunk( tool_calls_chunk, full_tool_calls[ix], completion_tokens ) if not content: From 4e6d65c84bf48ca0969864fb20f312dbb962acda Mon Sep 17 00:00:00 2001 From: olgavrou Date: Wed, 31 Jan 2024 10:15:29 -0500 Subject: [PATCH 23/30] Update autogen/agentchat/conversable_agent.py Co-authored-by: Chi Wang --- autogen/agentchat/conversable_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 771a7625485d..f28c6b33c06c 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -1913,7 +1913,7 @@ def _decorator(func: F) -> F: return _decorator def register_model_client(self, model_client_cls: Client, **kwargs): - """Register a model. + """Register a model client. Args: model_client_cls: A custom client class that follows the Client interface From 67ad1d33fb948b5cf0cd42a744ba91812e86fce8 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 26 Jan 2024 10:28:23 -0500 Subject: [PATCH 24/30] api_type req made model_client_cls requirement --- autogen/oai/client.py | 34 +++++++++++-------- notebook/agentchat_custom_model.ipynb | 12 +++---- test/oai/test_custom_client.py | 10 +++--- .../blog/2024-01-26-Custom-Models/index.mdx | 8 ++--- 4 files changed, 35 insertions(+), 29 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 93e202ee6015..162b25d2c575 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -392,19 +392,23 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s """ openai_config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}} api_type = config.get("api_type") - if api_type is None: - self._clients.append(OpenAIClient(OpenAI(**openai_config))) + model_client_cls_name = config.get("model_client_cls") + if model_client_cls_name is not None: + # a config for a custom client is set + # adding placeholder until the register_model_client is called with the appropriate class + self._clients.append(PlaceHolderClient(config)) + logger.info( + f"Detected custom model client in config: {model_client_cls_name}, model client can not be used until register_model_client is called." + ) else: - if api_type.startswith("azure"): - self._configure_azure_openai(config, openai_config) - self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) + if api_type is None: + self._clients.append(OpenAIClient(OpenAI(**openai_config))) else: - # else a config for a custom client is set - # skipping until the register_model_client is called with the appropriate class - self._clients.append(PlaceHolderClient(config)) - logger.info( - f"Detected custom model client in config: {api_type}, skipping registration until register_model_client is called." - ) + if api_type.startswith("azure"): + self._configure_azure_openai(config, openai_config) + self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) + else: + raise ValueError(f"api_type {api_type} is not supported.") def register_model_client(self, model_client_cls: Client, **kwargs): """Register a model client. @@ -419,7 +423,7 @@ def register_model_client(self, model_client_cls: Client, **kwargs): placeholder_config = client.config if placeholder_config in self._config_list: - if placeholder_config.get("api_type") == model_client_cls.__name__: + if placeholder_config.get("model_client_cls") == model_client_cls.__name__: self._clients[i] = model_client_cls(placeholder_config, **kwargs) return elif isinstance(client, model_client_cls): @@ -432,7 +436,7 @@ def register_model_client(self, model_client_cls: Client, **kwargs): else: raise ValueError( f'Model client "{model_client_cls.__name__}" is being registered but was not found in the config_list. ' - f'Please make sure to include an entry in the config_list with "api_type": "{model_client_cls.__name__}"' + f'Please make sure to include an entry in the config_list with "model_client_cls": "{model_client_cls.__name__}"' ) @classmethod @@ -520,7 +524,9 @@ def yes_or_no_filter(context, response): "No model client is active. Please populate the config list or register any custom model clients." ) # Check if all configs in config list are activated - non_activated = [client.config["api_type"] for client in self._clients if isinstance(client, PlaceHolderClient)] + non_activated = [ + client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient) + ] if non_activated: raise RuntimeError( f"Model client(s) {non_activated} are not activated. Please register the custom model clients using `register_model_client` or filter them out form the config list." diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 82a6ba9bc5bf..50e471167437 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -262,12 +262,12 @@ "\n", "You can add any paramteres that are needed for the custom model loading in the same configuration list.\n", "\n", - "It is important to add the `api_type` and set it to a string that corresponds to the class name: `\"CustomModelClient\"`.\n", + "It is important to add the `model_client_cls` field and set it to a string that corresponds to the class name: `\"CustomModelClient\"`.\n", "\n", "```json\n", "{\n", " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", - " \"api_type\": \"CustomModelClient\",\n", + " \"model_client_cls\": \"CustomModelClient\",\n", " \"device\": \"cuda\",\n", " \"n\": 1,\n", " \"params\": {\n", @@ -285,7 +285,7 @@ "source": [ "config_list_custom = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", - " filter_dict={\"api_type\": [\"CustomModelClient\"]},\n", + " filter_dict={\"model_client_cls\": [\"CustomModelClient\"]},\n", ")" ] }, @@ -398,7 +398,7 @@ "```json\n", "{\n", " \"model\": \"Open-Orca/Mistral-7B-OpenOrca\",\n", - " \"api_type\": \"CustomModelClientWithArguments\",\n", + " \"model_client_cls\": \"CustomModelClientWithArguments\",\n", " \"device\": \"cuda\",\n", " \"n\": 1,\n", " \"params\": {\n", @@ -416,7 +416,7 @@ "source": [ "config_list_custom = autogen.config_list_from_json(\n", " \"OAI_CONFIG_LIST\",\n", - " filter_dict={\"api_type\": [\"CustomModelClientWithArguments\"]},\n", + " filter_dict={\"model_client_cls\": [\"CustomModelClientWithArguments\"]},\n", ")" ] }, @@ -468,7 +468,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.9.5" }, "vscode": { "interpreter": { diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index ff188b072c41..676b674ebbd3 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -62,7 +62,7 @@ def get_usage(response) -> Dict: config_list = [ { "model": TEST_LOCAL_MODEL_NAME, - "api_type": "CustomModel", + "model_client_cls": "CustomModel", "device": TEST_DEVICE, "params": { "max_length": TEST_MAX_LENGTH, @@ -109,7 +109,7 @@ def get_usage(response) -> Dict: config_list = [ { "model": "local_model_name", - "api_type": "CustomModelWrongName", + "model_client_cls": "CustomModelWrongName", }, ] client = OpenAIWrapper(config_list=config_list) @@ -123,7 +123,7 @@ def test_no_client_registered_raises_error(): config_list = [ { "model": "local_model_name", - "api_type": "CustomModel", + "model_client_cls": "CustomModel", "device": "cpu", "params": { "max_length": 1000, @@ -160,7 +160,7 @@ def get_usage(response) -> Dict: config_list = [ { "model": "local_model_name", - "api_type": "CustomModel", + "model_client_cls": "CustomModel", "device": "cpu", "params": { "max_length": 1000, @@ -169,7 +169,7 @@ def get_usage(response) -> Dict: }, { "model": "local_model_name_2", - "api_type": "CustomModel", + "model_client_cls": "CustomModel", "device": "cpu", "params": { "max_length": 1000, diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx index 14c5cc9bd895..14d9a662f080 100644 --- a/website/blog/2024-01-26-Custom-Models/index.mdx +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -66,12 +66,12 @@ class CustomModelClient: ### Step 2: Add the configuration to the OAI_CONFIG_LIST -The field that is necessary is setting `api_type` to the name of the new class (as a string) `"api_type":"CustomModelClient"`. Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.: +The field that is necessary is setting `model_client_cls` to the name of the new class (as a string) `"model_client_cls":"CustomModelClient"`. Any other fields will be forwarded to the class constructor, so you have full control over what parameters to specify and how to use them. E.g.: ```json { "model": "Open-Orca/Mistral-7B-OpenOrca", - "api_type": "CustomModelClient", + "model_client_cls": "CustomModelClient", "device": "cuda", "n": 1, "params": { @@ -82,7 +82,7 @@ The field that is necessary is setting `api_type` to the name of the new class ( ### Step 3: Register the new custom model to the agent that will use it -If a configuration with the field `"api_type":""` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized: +If a configuration with the field `"model_client_cls":""` has been added to an Agent's config list, then the corresponding model with the desired class must be registered after the agent is created and before the conversation is initialized: ```python my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor]) @@ -160,7 +160,7 @@ If something doesn't work then run through the checklist: - `message_retrieval()` method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. - `cost()`method: returns an integer, and if you don't care about cost tracking you can just return `0`. - `get_usage()`: returns a dictionary, and if you don't care about usage tracking you can just return an empty dictionary `{}`. -- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"api_type":""` field. +- Make sure you have a corresponding entry in the `OAI_CONFIG_LIST` and that that entry has the `"model_client_cls":""` field. - Make sure you have registered the client using the corresponding config entry and your new class `agent.register_model_client(model_client_cls=, [other optional args])` - Make sure that all of the custom models defined in the `OAI_CONFIG_LIST` have been registered. - Any other troubleshooting might need to be done in the custom code itself. From 7888a7cc1a3f6f32e77503b6754284aca86a49c7 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 26 Jan 2024 12:15:36 -0500 Subject: [PATCH 25/30] remove raise error if client list is empty - client list will never be empty it will have placeholders --- autogen/oai/client.py | 4 ---- test/oai/test_custom_client.py | 20 -------------------- 2 files changed, 24 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 162b25d2c575..6988069bfdf6 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -519,10 +519,6 @@ def yes_or_no_filter(context, response): if ERROR: raise ERROR last = len(self._clients) - 1 - if len(self._clients) == 0: - raise RuntimeError( - "No model client is active. Please populate the config list or register any custom model clients." - ) # Check if all configs in config list are activated non_activated = [ client.config["model_client_cls"] for client in self._clients if isinstance(client, PlaceHolderClient) diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 676b674ebbd3..68bd2d5cf2ab 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -118,26 +118,6 @@ def get_usage(response) -> Dict: client.register_model_client(model_client_cls=CustomModel) -@pytest.mark.skipif(skip, reason="openai>=1 not installed") -def test_no_client_registered_raises_error(): - config_list = [ - { - "model": "local_model_name", - "model_client_cls": "CustomModel", - "device": "cpu", - "params": { - "max_length": 1000, - "other_params": "other_params", - }, - }, - ] - - client = OpenAIWrapper(config_list=config_list) - - with pytest.raises(RuntimeError): - client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None) - - @pytest.mark.skipif(skip, reason="openai>=1 not installed") def test_not_all_clients_registered_raises_error(): class CustomModel: From ac7eed1e9140b1700e233543758879ec6fcf9b9e Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 26 Jan 2024 13:02:46 -0500 Subject: [PATCH 26/30] rename Client -> ModelClient --- autogen/agentchat/conversable_agent.py | 4 +-- autogen/oai/__init__.py | 4 +-- notebook/agentchat_custom_model.ipynb | 23 +++++++------- test/oai/test_custom_client.py | 2 +- .../blog/2024-01-26-Custom-Models/index.mdx | 30 +++++++++---------- 5 files changed, 31 insertions(+), 32 deletions(-) diff --git a/autogen/agentchat/conversable_agent.py b/autogen/agentchat/conversable_agent.py index 6ffdc121d596..4f883ea658fa 100644 --- a/autogen/agentchat/conversable_agent.py +++ b/autogen/agentchat/conversable_agent.py @@ -8,7 +8,7 @@ from collections import defaultdict from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union -from .. import OpenAIWrapper, Client +from .. import OpenAIWrapper, ModelClient from ..cache.cache import Cache from ..code_utils import ( DEFAULT_MODEL, @@ -1937,7 +1937,7 @@ def _decorator(func: F) -> F: return _decorator - def register_model_client(self, model_client_cls: Client, **kwargs): + def register_model_client(self, model_client_cls: ModelClient, **kwargs): """Register a model client. Args: diff --git a/autogen/oai/__init__.py b/autogen/oai/__init__.py index 2b0df192fafe..1cf57f04456a 100644 --- a/autogen/oai/__init__.py +++ b/autogen/oai/__init__.py @@ -1,4 +1,4 @@ -from autogen.oai.client import OpenAIWrapper, Client +from autogen.oai.client import OpenAIWrapper, ModelClient from autogen.oai.completion import Completion, ChatCompletion from autogen.oai.openai_utils import ( get_config_list, @@ -12,7 +12,7 @@ __all__ = [ "OpenAIWrapper", - "Client", + "ModelClient", "Completion", "ChatCompletion", "get_config_list", diff --git a/notebook/agentchat_custom_model.ipynb b/notebook/agentchat_custom_model.ipynb index 50e471167437..b58b5d93a055 100644 --- a/notebook/agentchat_custom_model.ipynb +++ b/notebook/agentchat_custom_model.ipynb @@ -65,17 +65,17 @@ "source": [ "## Create and configure the custom model\n", "\n", - "A custom model class can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in client.py and shown below.\n", + "A custom model class can be created in many ways, but needs to adhere to the `ModelClient` protocol and response structure which is defined in client.py and shown below.\n", "\n", "The response protocol has some minimum requirements, but can be extended to include any additional information that is needed.\n", - "Message retrieval therefore can be customized, but needs to return a list of strings or a list of `ClientResponseProtocol.Choice.Message` objects.\n", + "Message retrieval therefore can be customized, but needs to return a list of strings or a list of `ModelClientResponseProtocol.Choice.Message` objects.\n", "\n", "\n", "```python\n", - "class Client(Protocol):\n", + "class ModelClient(Protocol):\n", " \"\"\"\n", " A client class must implement the following methods:\n", - " - create must return a response object that implements the ClientResponseProtocol\n", + " - create must return a response object that implements the ModelClientResponseProtocol\n", " - cost must return the cost of the response\n", " - get_usage must return a dict with the following keys:\n", " - prompt_tokens\n", @@ -85,13 +85,13 @@ " - model\n", "\n", " This class is used to create a client that can be used by OpenAIWrapper.\n", - " The response returned from create must adhere to the ClientResponseProtocol but can be extended however needed.\n", + " The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed.\n", " The message_retrieval method must be implemented to return a list of str or a list of messages from the response.\n", " \"\"\"\n", "\n", " RESPONSE_USAGE_KEYS = [\"prompt_tokens\", \"completion_tokens\", \"total_tokens\", \"cost\", \"model\"]\n", "\n", - " class ClientResponseProtocol(Protocol):\n", + " class ModelClientResponseProtocol(Protocol):\n", " class Choice(Protocol):\n", " class Message(Protocol):\n", " content: str | None\n", @@ -99,12 +99,12 @@ " choices: List[Choice]\n", " model: str\n", "\n", - " def create(self, params) -> ClientResponseProtocol:\n", + " def create(self, params) -> ModelClientResponseProtocol:\n", " ...\n", "\n", " def message_retrieval(\n", - " self, response: ClientResponseProtocol\n", - " ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]:\n", + " self, response: ModelClientResponseProtocol\n", + " ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]:\n", " \"\"\"\n", " Retrieve and return a list of strings or a list of Choice.Message from the response.\n", "\n", @@ -113,11 +113,11 @@ " \"\"\"\n", " ...\n", "\n", - " def cost(self, response: ClientResponseProtocol) -> float:\n", + " def cost(self, response: ModelClientResponseProtocol) -> float:\n", " ...\n", "\n", " @staticmethod\n", - " def get_usage(response: ClientResponseProtocol) -> Dict:\n", + " def get_usage(response: ModelClientResponseProtocol) -> Dict:\n", " \"\"\"Return usage summary of the response using RESPONSE_USAGE_KEYS.\"\"\"\n", " ...\n", "```\n" @@ -328,7 +328,6 @@ "metadata": {}, "outputs": [], "source": [ - "config = config_list_custom[0]\n", "assistant.register_model_client(model_client_cls=CustomModelClient)" ] }, diff --git a/test/oai/test_custom_client.py b/test/oai/test_custom_client.py index 68bd2d5cf2ab..8e536921795f 100644 --- a/test/oai/test_custom_client.py +++ b/test/oai/test_custom_client.py @@ -1,6 +1,6 @@ import pytest from autogen import OpenAIWrapper -from autogen.oai import Client +from autogen.oai import ModelClient from typing import Dict try: diff --git a/website/blog/2024-01-26-Custom-Models/index.mdx b/website/blog/2024-01-26-Custom-Models/index.mdx index 14d9a662f080..796b0c00d20e 100644 --- a/website/blog/2024-01-26-Custom-Models/index.mdx +++ b/website/blog/2024-01-26-Custom-Models/index.mdx @@ -17,9 +17,9 @@ An interactive and easy way to get started is by following the notebook [here](h ### Step 1: Create the custom model client class -To get started with using custom models in AutoGen, you need to create a model client class that adheres to the `Client` protocol defined in `client.py`. The `Client` class should implement these methods: +To get started with using custom models in AutoGen, you need to create a model client class that adheres to the `ModelClient` protocol defined in `client.py`. The new model client class should implement these methods: -- `create()`: Returns a response object that implements the `ClientResponseProtocol` (more details in the Protocol section). +- `create()`: Returns a response object that implements the `ModelClientResponseProtocol` (more details in the Protocol section). - `message_retrieval()`: Processes the response object and returns a list of strings or a list of message objects (more details in the Protocol section). - `cost()`: Returns the cost of the response. - `get_usage()`: Returns a dictionary with keys from `RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"]`. @@ -36,7 +36,7 @@ class CustomModelClient: # can create my own data response class # here using SimpleNamespace for simplicity - # as long as it adheres to the ClientResponseProtocol + # as long as it adheres to the ModelClientResponseProtocol response = SimpleNamespace() response.choices = [] @@ -88,22 +88,22 @@ If a configuration with the field `"model_client_cls":""` has been a my_agent.register_model_client(model_client_cls=CustomModelClient, [other args that will be forwarded to CustomModelClient constructor]) ``` -`model_client_cls=CustomModelClient` arg matches the one specified in the `OAI_CONFIG_LIST` and `CustomModelClient` is the class that adheres to the `Client` protocol (more details on the protocol below). +`model_client_cls=CustomModelClient` arg matches the one specified in the `OAI_CONFIG_LIST` and `CustomModelClient` is the class that adheres to the `ModelClient` protocol (more details on the protocol below). If the new model client is in the config list but not registered by the time the chat is initialized, then an error will be raised. ## Protocol details -A custom model class can be created in many ways, but needs to adhere to the `Client` protocol and response structure which is defined in `client.py` and shown below. +A custom model class can be created in many ways, but needs to adhere to the `ModelClient` protocol and response structure which is defined in `client.py` and shown below. The response protocol is currently using the minimum required fields from the autogen codebase that match the OpenAI response structure. Any response protocol that matches the OpenAI response structure will probably be more resilient to future changes, but we are starting off with minimum requirements to make adpotion of this feature easier. ```python -class Client(Protocol): +class ModelClient(Protocol): """ A client class must implement the following methods: - - create must return a response object that implements the ClientResponseProtocol + - create must return a response object that implements the ModelClientResponseProtocol - cost must return the cost of the response - get_usage must return a dict with the following keys: - prompt_tokens @@ -113,13 +113,13 @@ class Client(Protocol): - model This class is used to create a client that can be used by OpenAIWrapper. - The response returned from create must adhere to the ClientResponseProtocol but can be extended however needed. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. The message_retrieval method must be implemented to return a list of str or a list of messages from the response. """ RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] - class ClientResponseProtocol(Protocol): + class ModelClientResponseProtocol(Protocol): class Choice(Protocol): class Message(Protocol): content: str | None @@ -127,12 +127,12 @@ class Client(Protocol): choices: List[Choice] model: str - def create(self, params) -> ClientResponseProtocol: + def create(self, params) -> ModelClientResponseProtocol: ... def message_retrieval( - self, response: ClientResponseProtocol - ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]: + self, response: ModelClientResponseProtocol + ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -141,11 +141,11 @@ class Client(Protocol): """ ... - def cost(self, response: ClientResponseProtocol) -> float: + def cost(self, response: ModelClientResponseProtocol) -> float: ... @staticmethod - def get_usage(response: ClientResponseProtocol) -> Dict: + def get_usage(response: ModelClientResponseProtocol) -> Dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" ... @@ -156,7 +156,7 @@ class Client(Protocol): If something doesn't work then run through the checklist: - Make sure you have followed the client protocol and client response protocol when creating the custom model class - - `create()` method: `ClientResponseProtocol` must be followed when returning an inference response during `create` call. + - `create()` method: `ModelClientResponseProtocol` must be followed when returning an inference response during `create` call. - `message_retrieval()` method: returns a list of strings or a list of message objects. If a list of message objects is returned, they currently must contain the fields of OpenAI's ChatCompletion Message object, since that is expected for function or tool calling in the rest of the codebase at the moment, unless a custom agent is being used. - `cost()`method: returns an integer, and if you don't care about cost tracking you can just return `0`. - `get_usage()`: returns a dictionary, and if you don't care about usage tracking you can just return an empty dictionary `{}`. From 95d127b0bfb92f090e7ec9695bc541b3c131e7d8 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Fri, 26 Jan 2024 13:06:55 -0500 Subject: [PATCH 27/30] add forgotten file --- autogen/oai/client.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 6988069bfdf6..8801935207eb 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -53,10 +53,10 @@ LEGACY_CACHE_DIR = ".cache" -class Client(Protocol): +class ModelClient(Protocol): """ A client class must implement the following methods: - - create must return a response object that implements the ClientResponseProtocol + - create must return a response object that implements the ModelClientResponseProtocol - cost must return the cost of the response - get_usage must return a dict with the following keys: - prompt_tokens @@ -66,13 +66,13 @@ class Client(Protocol): - model This class is used to create a client that can be used by OpenAIWrapper. - The response returned from create must adhere to the ClientResponseProtocol but can be extended however needed. + The response returned from create must adhere to the ModelClientResponseProtocol but can be extended however needed. The message_retrieval method must be implemented to return a list of str or a list of messages from the response. """ RESPONSE_USAGE_KEYS = ["prompt_tokens", "completion_tokens", "total_tokens", "cost", "model"] - class ClientResponseProtocol(Protocol): + class ModelClientResponseProtocol(Protocol): class Choice(Protocol): class Message(Protocol): content: str | None @@ -80,12 +80,12 @@ class Message(Protocol): choices: List[Choice] model: str - def create(self, params) -> ClientResponseProtocol: + def create(self, params) -> ModelClientResponseProtocol: ... # pragma: no cover def message_retrieval( - self, response: ClientResponseProtocol - ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]: + self, response: ModelClientResponseProtocol + ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: """ Retrieve and return a list of strings or a list of Choice.Message from the response. @@ -94,11 +94,11 @@ def message_retrieval( """ ... # pragma: no cover - def cost(self, response: ClientResponseProtocol) -> float: + def cost(self, response: ModelClientResponseProtocol) -> float: ... # pragma: no cover @staticmethod - def get_usage(response: ClientResponseProtocol) -> Dict: + def get_usage(response: ModelClientResponseProtocol) -> Dict: """Return usage summary of the response using RESPONSE_USAGE_KEYS.""" ... # pragma: no cover @@ -349,7 +349,7 @@ def __init__(self, *, config_list: Optional[List[Dict[str, Any]]] = None, **base if type(config_list) is list and len(config_list) == 0: logger.warning("openai client was provided with an empty config_list, which may not be intended.") - self._clients: List[Client] = [] + self._clients: List[ModelClient] = [] self._config_list: List[Dict[str, Any]] = [] if config_list: @@ -410,11 +410,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s else: raise ValueError(f"api_type {api_type} is not supported.") - def register_model_client(self, model_client_cls: Client, **kwargs): + def register_model_client(self, model_client_cls: ModelClient, **kwargs): """Register a model client. Args: - model_client_cls: A custom client class that follows the Client interface + model_client_cls: A custom client class that follows the ModelClient interface **kwargs: The kwargs for the custom client class to be initialized with """ existing_client_class = False @@ -483,7 +483,7 @@ def _construct_create_params(self, create_config: Dict[str, Any], extra_kwargs: ] return params - def create(self, **config: Any) -> Client.ClientResponseProtocol: + def create(self, **config: Any) -> ModelClient.ModelClientResponseProtocol: """Make a completion for a given config using available clients. Besides the kwargs allowed in openai's [or other] client, we allow the following additional kwargs. The config in each client will be overridden by the config. @@ -558,7 +558,7 @@ def yes_or_no_filter(context, response): with cache_client as cache: # Try to get the response from cache key = get_key(params) - response: Client.ClientResponseProtocol = cache.get(key, None) + response: ModelClient.ModelClientResponseProtocol = cache.get(key, None) if response is not None: response.message_retrieval_function = client.message_retrieval @@ -723,7 +723,7 @@ def _update_tool_calls_from_chunk( def _update_usage(self, actual_usage, total_usage): def update_usage(usage_summary, response_usage): # go through RESPONSE_USAGE_KEYS and check that they are in response_usage and if not just return usage_summary - for key in Client.RESPONSE_USAGE_KEYS: + for key in ModelClient.RESPONSE_USAGE_KEYS: if key not in response_usage: return usage_summary @@ -810,8 +810,8 @@ def clear_usage_summary(self) -> None: @classmethod def extract_text_or_completion_object( - cls, response: Client.ClientResponseProtocol - ) -> Union[List[str], List[Client.ClientResponseProtocol.Choice.Message]]: + cls, response: ModelClient.ModelClientResponseProtocol + ) -> Union[List[str], List[ModelClient.ModelClientResponseProtocol.Choice.Message]]: """Extract the text or ChatCompletion objects from a completion or chat response. Args: From dc7605e466ead1d7879130d7c73d5058e70fb9a5 Mon Sep 17 00:00:00 2001 From: olgavrou Date: Sat, 27 Jan 2024 01:40:21 -0500 Subject: [PATCH 28/30] type hints, small fixes, docstr comment --- autogen/oai/client.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 8801935207eb..52f0ece1f79e 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -75,12 +75,12 @@ class ModelClient(Protocol): class ModelClientResponseProtocol(Protocol): class Choice(Protocol): class Message(Protocol): - content: str | None + content: Optional[str] choices: List[Choice] model: str - def create(self, params) -> ModelClientResponseProtocol: + def create(self, **params: Any) -> ModelClientResponseProtocol: ... # pragma: no cover def message_retrieval( @@ -111,7 +111,7 @@ def __init__(self, config): class OpenAIClient: """Follows the Client protocol and wraps the OpenAI client.""" - def __init__(self, client): + def __init__(self, client: Union[OpenAI, AzureOpenAI]): self._oai_client = client def message_retrieval( @@ -515,6 +515,9 @@ def yes_or_no_filter(context, response): - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false. - api_version (str | None): The api version. Default to None. E.g., "2023-08-01-preview". + Raises: + - RuntimeError: If all declared custom model clients are not registered + - APIError: If any model client create call raises an APIError """ if ERROR: raise ERROR From 26d45dee1c3849cc6c6db47548aa14df4729a94e Mon Sep 17 00:00:00 2001 From: olgavrou Date: Sat, 27 Jan 2024 02:09:51 -0500 Subject: [PATCH 29/30] fix test by fetching internal client --- autogen/agentchat/contrib/gpt_assistant_agent.py | 2 +- test/agentchat/contrib/test_gpt_assistant.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/autogen/agentchat/contrib/gpt_assistant_agent.py b/autogen/agentchat/contrib/gpt_assistant_agent.py index a8419f0d2ad3..b588b2b59f5a 100644 --- a/autogen/agentchat/contrib/gpt_assistant_agent.py +++ b/autogen/agentchat/contrib/gpt_assistant_agent.py @@ -56,7 +56,7 @@ def __init__( oai_wrapper = OpenAIWrapper(**llm_config) if len(oai_wrapper._clients) > 1: logger.warning("GPT Assistant only supports one OpenAI client. Using the first client in the list.") - self._openai_client = oai_wrapper._clients[0] + self._openai_client = oai_wrapper._clients[0]._oai_client openai_assistant_id = llm_config.get("assistant_id", None) if openai_assistant_id is None: # try to find assistant by name first diff --git a/test/agentchat/contrib/test_gpt_assistant.py b/test/agentchat/contrib/test_gpt_assistant.py index dbe51192ec47..ea190893f6ce 100644 --- a/test/agentchat/contrib/test_gpt_assistant.py +++ b/test/agentchat/contrib/test_gpt_assistant.py @@ -192,7 +192,7 @@ def test_get_assistant_files(): and assert that the retrieved instructions match the set instructions. """ current_file_path = os.path.abspath(__file__) - openai_client = OpenAIWrapper(config_list=config_list)._clients[0] + openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client file = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") name = "For test_get_assistant_files" @@ -238,7 +238,7 @@ def test_assistant_retrieval(): "description": "This is a test function 2", } - openai_client = OpenAIWrapper(config_list=config_list)._clients[0] + openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client current_file_path = os.path.abspath(__file__) file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") @@ -312,7 +312,7 @@ def test_assistant_mismatch_retrieval(): "description": "This is a test function 3", } - openai_client = OpenAIWrapper(config_list=config_list)._clients[0] + openai_client = OpenAIWrapper(config_list=config_list)._clients[0]._oai_client current_file_path = os.path.abspath(__file__) file_1 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") file_2 = openai_client.files.create(file=open(current_file_path, "rb"), purpose="assistants") From f9bca4ecb974f0040004557f22a244b45036eeda Mon Sep 17 00:00:00 2001 From: olgavrou Date: Sat, 27 Jan 2024 03:05:56 -0500 Subject: [PATCH 30/30] fix api type checking --- autogen/oai/client.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/autogen/oai/client.py b/autogen/oai/client.py index 52f0ece1f79e..8480b51e4d68 100644 --- a/autogen/oai/client.py +++ b/autogen/oai/client.py @@ -401,14 +401,11 @@ def _register_default_client(self, config: Dict[str, Any], openai_config: Dict[s f"Detected custom model client in config: {model_client_cls_name}, model client can not be used until register_model_client is called." ) else: - if api_type is None: - self._clients.append(OpenAIClient(OpenAI(**openai_config))) + if api_type is not None and api_type.startswith("azure"): + self._configure_azure_openai(config, openai_config) + self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) else: - if api_type.startswith("azure"): - self._configure_azure_openai(config, openai_config) - self._clients.append(OpenAIClient(AzureOpenAI(**openai_config))) - else: - raise ValueError(f"api_type {api_type} is not supported.") + self._clients.append(OpenAIClient(OpenAI(**openai_config))) def register_model_client(self, model_client_cls: ModelClient, **kwargs): """Register a model client.