"""OpenAI (>=1) client wrapper for autogen.

Reconstructed, properly formatted module content of ``autogen/oai/_client.py``
as produced by applying both patches in this series, with defects fixed:
guarded openai type imports, filter applied to fresh (non-cached) responses,
``@classmethod`` on ``extract_text_or_function_call``, and docstring typos.
"""
import sys
import logging
import inspect
from typing import Callable, Dict, List, Optional, Union

from flaml.automl.logger import logger_formatter

from autogen.oai.openai_utils import get_key

try:
    from openai import (
        RateLimitError,
        APIError,
        BadRequestError,
        APIConnectionError,
        Timeout,
        AuthenticationError,
    )
    from openai import OpenAI

    # NOTE(fix): the response types were originally imported unconditionally at
    # the top of the module, which made the whole module fail to import when
    # openai is absent and defeated this ImportError fallback. Keep them inside
    # the guard so ERROR is the single signal for a missing dependency.
    from openai.types.chat import ChatCompletion
    from openai.types.completion import Completion
    import diskcache

    ERROR = None
except ImportError:
    ERROR = ImportError("please install openai>=1 and diskcache to use the autogen.oai subpackage.")
    OpenAI = object

logger = logging.getLogger(__name__)
if not logger.handlers:
    # Add the console handler.
    _ch = logging.StreamHandler(stream=sys.stdout)
    _ch.setFormatter(logger_formatter)
    logger.addHandler(_ch)


class OpenAIWrapper:
    """A wrapper class for openai client."""

    # Root directory under which one diskcache.Cache per seed is created.
    cache_path_root: str = ".cache"
    # kwargs understood by `create` / per-entry config dicts that must NOT be
    # forwarded to the openai client.
    extra_kwargs = {"seed", "filter_func", "allow_format_str_template", "context", "api_type", "api_version"}
    # kwargs accepted by OpenAI.__init__, discovered by introspection so the
    # set automatically tracks the installed openai version.
    openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)

    def __init__(self, *, config_list: Optional[List[Dict]] = None, **base_config):
        """
        Args:
            config_list: a list of config dicts to override the base_config.
                They can contain additional kwargs as allowed in the [create](/docs/reference/oai/_client/#create) method. E.g.,

                ```python
                config_list=[
                    {
                        "model": "gpt-4",
                        "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
                        "api_type": "azure",
                        "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
                        "api_version": "2023-03-15-preview",
                    },
                    {
                        "model": "gpt-3.5-turbo",
                        "api_key": os.environ.get("OPENAI_API_KEY"),
                        "api_type": "open_ai",
                        "base_url": "https://api.openai.com/v1",
                    },
                    {
                        "model": "llama-7B",
                        "base_url": "http://127.0.0.1:8080",
                        "api_type": "open_ai",
                    }
                ]
                ```

            base_config: base config. It can contain both keyword arguments for openai client
                and additional kwargs.
        """
        openai_config, extra_kwargs = self._separate_openai_config(base_config)
        if isinstance(config_list, list) and not config_list:
            logger.warning("openai client was provided with an empty config_list, which may not be intended.")
        if config_list:
            # One client per config entry; each entry's non-openai kwargs
            # become that client's default extra kwargs.
            self._clients = [self._client(config, openai_config) for config in config_list]
            self._config_list = [
                {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
                for config in config_list
            ]
        else:
            self._clients = [OpenAI(**openai_config)]
            self._config_list = [extra_kwargs]

    def _separate_openai_config(self, config):
        """Separate the config into openai_config and extra_kwargs."""
        openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
        return openai_config, extra_kwargs

    def _separate_create_config(self, config):
        """Separate the config into create_config and extra_kwargs."""
        create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
        return create_config, extra_kwargs

    def _client(self, config, openai_config):
        """Create a client with the given config to override openai_config,
        after removing extra kwargs.
        """
        config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
        client = OpenAI(**config)
        return client

    @classmethod
    def instantiate(
        cls,
        template: Optional[Union[str, Callable]],
        context: Optional[Dict] = None,
        allow_format_str_template: Optional[bool] = False,
    ):
        """Instantiate a prompt/message template with the given context.

        A string template is `.format(**context)`-ed only when
        `allow_format_str_template` is true; a callable template is invoked
        with the context; `None` (or no context) passes through unchanged.
        """
        if not context or template is None:
            return template
        if isinstance(template, str):
            return template.format(**context) if allow_format_str_template else template
        return template(context)

    def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
        """Prime the create_config with additional_kwargs.

        Raises:
            ValueError: if neither or both of `prompt`/`messages` are given.
        """
        # Validate the config
        prompt = create_config.get("prompt")
        messages = create_config.get("messages")
        if (prompt is None) == (messages is None):
            raise ValueError("Either prompt or messages should be in create config but not both.")
        context = extra_kwargs.get("context")
        if context is None:
            # No need to instantiate if no context is provided.
            return create_config
        # Instantiate the prompt or messages
        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
        # Make a copy of the config
        params = create_config.copy()
        if prompt is not None:
            # Instantiate the prompt
            params["prompt"] = self.instantiate(prompt, context, allow_format_str_template)
        elif context:
            # Instantiate the messages; entries without content pass through.
            params["messages"] = [
                {
                    **m,
                    "content": self.instantiate(m["content"], context, allow_format_str_template),
                }
                if m.get("content")
                else m
                for m in messages
            ]
        return params

    def create(self, **config):
        """Make a completion for a given config using openai's clients.
        Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
        The config in each client will be overridden by the config.

        Args:
            - context (Dict | None): The context to instantiate the prompt or messages. Default to None.
                It needs to contain keys that are used by the prompt template or the filter function.
                E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`.
                The actual prompt will be:
                "Complete the following sentence: Today I feel".
                More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
            - `seed` (int | None) for the cache. Default to 41.
                An integer seed is useful when implementing "controlled randomness" for the completion.
                None for no caching.
            - filter_func (Callable | None): A function that takes in the context and the response
                and returns a boolean to indicate whether the response is valid. E.g.,

                ```python
                def yes_or_no_filter(context, response):
                    return context.get("yes_or_no_choice", False) is False or any(
                        text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
                    )
                ```

            - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false.
        """
        if ERROR:
            raise ERROR
        last = len(self._clients) - 1
        for i, client in enumerate(self._clients):
            # merge the input config with the i-th config in the config list
            full_config = {**config, **self._config_list[i]}
            # separate the config into create_config and extra_kwargs
            create_config, extra_kwargs = self._separate_create_config(full_config)
            # construct the create params
            params = self._construct_create_params(create_config, extra_kwargs)
            # get the seed, filter_func and context
            seed = extra_kwargs.get("seed", 41)
            filter_func = extra_kwargs.get("filter_func")
            context = extra_kwargs.get("context")
            with diskcache.Cache(f"{self.cache_path_root}/{seed}") as cache:
                if seed is not None:
                    # Try to get the response from cache
                    key = get_key(params)
                    response = cache.get(key, None)
                    if response is not None:
                        # check the filter
                        pass_filter = filter_func is None or filter_func(context=context, response=response)
                        if pass_filter or i == last:
                            # Return the response if it passes the filter or it is the last client
                            response.config_id = i
                            response.pass_filter = pass_filter
                            # TODO: add response.cost
                            return response
                        # Cached response failed the filter: fall through and
                        # query the API again for a fresh sample.
                completions = client.chat.completions if "messages" in params else client.completions
                try:
                    response = completions.create(**params)
                except APIConnectionError:
                    # This seems to be the only error raised by openai
                    logger.debug(f"config {i} failed", exc_info=True)
                    if i == last:
                        raise
                else:
                    if seed is not None:
                        # Cache the response
                        cache.set(key, response)
                    # NOTE(fix): originally fresh responses were returned
                    # without consulting filter_func, contradicting the
                    # documented contract; apply the same logic as for
                    # cached responses.
                    pass_filter = filter_func is None or filter_func(context=context, response=response)
                    if pass_filter or i == last:
                        response.config_id = i
                        response.pass_filter = pass_filter
                        return response

    @classmethod
    def extract_text_or_function_call(cls, response: "Union[ChatCompletion, Completion]") -> List[str]:
        """Extract the text or function calls from a completion or chat response.

        Args:
            response (ChatCompletion | Completion): The response from openai.

        Returns:
            A list of text or function calls in the responses.
        """
        choices = response.choices
        if isinstance(response, Completion):
            return [choice.text for choice in choices]
        # For chat responses, return the message object itself when a
        # function_call is present, otherwise the plain content string.
        return [
            choice.message if choice.message.function_call is not None else choice.message.content
            for choice in choices
        ]