Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

openai client #419

Merged
merged 2 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions autogen/oai/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from autogen.oai._client import OpenAIWrapper
from autogen.oai.completion import Completion, ChatCompletion
from autogen.oai.openai_utils import (
get_config_list,
Expand All @@ -9,6 +10,7 @@
)

__all__ = [
"OpenAIWrapper",
"Completion",
"ChatCompletion",
"get_config_list",
Expand Down
234 changes: 234 additions & 0 deletions autogen/oai/_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,234 @@
import sys
from typing import List, Optional, Dict, Callable
import logging
import inspect
from flaml.automl.logger import logger_formatter
from openai.types.chat import ChatCompletion
from openai.types.completion import Completion

from autogen.oai.openai_utils import get_key

try:
    # Exception classes re-exported by the openai SDK; imported eagerly so a
    # missing/old openai is detected here rather than at call time.
    # NOTE(review): `Timeout` may not exist in final openai>=1 releases
    # (reportedly renamed to APITimeoutError) — confirm against the pinned version.
    from openai import (
        RateLimitError,
        APIError,
        BadRequestError,
        APIConnectionError,
        Timeout,
        AuthenticationError,
    )
    from openai import OpenAI
    import diskcache

    # Sentinel: imports succeeded, the wrapper is usable.
    ERROR = None
except ImportError:
    # Defer the failure: store the error and raise it only when the oai
    # subpackage is actually used.
    ERROR = ImportError("please install openai>=1 and diskcache to use the autogen.oai subpackage.")
    # Keep the name defined so module-level references to OpenAI do not NameError.
    OpenAI = object
# Module logger; attach a stdout handler only if none is configured yet,
# so embedding applications keep control of logging.
logger = logging.getLogger(__name__)
if not logger.handlers:
    # Add the console handler.
    _ch = logging.StreamHandler(stream=sys.stdout)
    _ch.setFormatter(logger_formatter)
    logger.addHandler(_ch)


class OpenAIWrapper:
    """A wrapper class for the openai>=1 client.

    Holds one openai client per entry of ``config_list`` and tries them in
    order in :meth:`create`, with optional disk caching and response filtering.
    """

    # Root directory of the on-disk response cache (one sub-cache per seed).
    cache_path_root: str = ".cache"
    # kwargs consumed by this wrapper and never forwarded to openai's create().
    extra_kwargs = {"seed", "filter_func", "allow_format_str_template", "context", "api_type", "api_version"}
    # Keyword-only parameters accepted by the openai client constructor,
    # discovered via introspection so the set tracks the installed version.
    openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)

    def __init__(self, *, config_list: Optional[List[Dict]] = None, **base_config):
        """
        Args:
            config_list: a list of config dicts to override the base_config.
                They can contain additional kwargs as allowed in the [create](/docs/reference/oai/_client/#create) method. E.g.,

                ```python
                config_list=[
                    {
                        "model": "gpt-4",
                        "api_key": os.environ.get("AZURE_OPENAI_API_KEY"),
                        "api_type": "azure",
                        "base_url": os.environ.get("AZURE_OPENAI_API_BASE"),
                        "api_version": "2023-03-15-preview",
                    },
                    {
                        "model": "gpt-3.5-turbo",
                        "api_key": os.environ.get("OPENAI_API_KEY"),
                        "api_type": "open_ai",
                        "base_url": "https://api.openai.com/v1",
                    },
                    {
                        "model": "llama-7B",
                        "base_url": "http://127.0.0.1:8080",
                        "api_type": "open_ai",
                    }
                ]
                ```

            base_config: base config. It can contain both keyword arguments for openai client
                and additional kwargs.

        Raises:
            ImportError: if openai>=1 or diskcache is not installed.
        """
        if ERROR:
            # Fail fast at construction with a clear message instead of a
            # confusing failure when create() is first called.
            raise ERROR
        openai_config, extra_kwargs = self._separate_openai_config(base_config)
        if isinstance(config_list, list) and len(config_list) == 0:
            logger.warning("openai client was provided with an empty config_list, which may not be intended.")
        if config_list:
            # One client per config; per-config extra kwargs override the base ones.
            self._clients = [self._client(config, openai_config) for config in config_list]
            self._config_list = [
                {**extra_kwargs, **{k: v for k, v in config.items() if k not in self.openai_kwargs}}
                for config in config_list
            ]
        else:
            self._clients = [OpenAI(**openai_config)]
            self._config_list = [extra_kwargs]

    def _separate_openai_config(self, config):
        """Separate the config into openai_config and extra_kwargs."""
        openai_config = {k: v for k, v in config.items() if k in self.openai_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k not in self.openai_kwargs}
        return openai_config, extra_kwargs

    def _separate_create_config(self, config):
        """Separate the config into create_config and extra_kwargs."""
        create_config = {k: v for k, v in config.items() if k not in self.extra_kwargs}
        extra_kwargs = {k: v for k, v in config.items() if k in self.extra_kwargs}
        return create_config, extra_kwargs

    def _client(self, config, openai_config):
        """Create a client with the given config to override openai_config,
        after removing extra kwargs.
        """
        config = {**openai_config, **{k: v for k, v in config.items() if k in self.openai_kwargs}}
        client = OpenAI(**config)
        return client

    @classmethod
    def instantiate(
        cls,
        # Quoted union annotation: `X | Y` in a live annotation raises at def
        # time on Python < 3.10; the string form is never evaluated.
        template: "str | Callable | None",
        context: Optional[Dict] = None,
        allow_format_str_template: Optional[bool] = False,
    ):
        """Instantiate a prompt/message template with the given context.

        Args:
            template: a plain string, a format-string template (used when
                ``allow_format_str_template`` is True), or a callable taking
                the context and returning a string. None passes through.
            context: the fields to fill in; when falsy the template is
                returned unmodified.
            allow_format_str_template: whether to treat a str template as a
                format string.

        Returns:
            The instantiated template, or the template unchanged when there
            is no context or the template is None.
        """
        if not context or template is None:
            return template
        if isinstance(template, str):
            return template.format(**context) if allow_format_str_template else template
        return template(context)

    def _construct_create_params(self, create_config: Dict, extra_kwargs: Dict) -> Dict:
        """Prime the create_config with additional_kwargs.

        Raises:
            ValueError: if neither or both of "prompt" and "messages" are present.
        """
        # Validate the config
        prompt = create_config.get("prompt")
        messages = create_config.get("messages")
        if (prompt is None) == (messages is None):
            raise ValueError("Either prompt or messages should be in create config but not both.")
        context = extra_kwargs.get("context")
        if context is None:
            # No need to instantiate if no context is provided.
            return create_config
        # Instantiate the prompt or messages
        allow_format_str_template = extra_kwargs.get("allow_format_str_template", False)
        # Make a copy of the config so the caller's dict is not mutated.
        params = create_config.copy()
        if prompt is not None:
            # Instantiate the prompt
            params["prompt"] = self.instantiate(prompt, context, allow_format_str_template)
        elif context:
            # Instantiate each message's content; messages without content
            # (e.g. function results) are passed through untouched.
            params["messages"] = [
                {
                    **m,
                    "content": self.instantiate(m["content"], context, allow_format_str_template),
                }
                if m.get("content")
                else m
                for m in messages
            ]
        return params

    def create(self, **config):
        """Make a completion for a given config using openai's clients.
        Besides the kwargs allowed in openai's client, we allow the following additional kwargs.
        The config in each client will be overridden by the config.

        Args:
            - context (Dict | None): The context to instantiate the prompt or messages. Default to None.
                It needs to contain keys that are used by the prompt template or the filter function.
                E.g., `prompt="Complete the following sentence: {prefix}, context={"prefix": "Today I feel"}`.
                The actual prompt will be:
                "Complete the following sentence: Today I feel".
                More examples can be found at [templating](/docs/Use-Cases/enhanced_inference#templating).
            - `seed` (int | None) for the cache. Default to 41.
                An integer seed is useful when implementing "controlled randomness" for the completion.
                None for no caching.
            - filter_func (Callable | None): A function that takes in the context and the response
                and returns a boolean to indicate whether the response is valid. E.g.,

                ```python
                def yes_or_no_filter(context, response):
                    return context.get("yes_or_no_choice", False) is False or any(
                        text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
                    )
                ```

            - allow_format_str_template (bool | None): Whether to allow format string template in the config. Default to false.

        Raises:
            ImportError: if openai>=1 or diskcache is not installed.
            APIConnectionError: if every configured client fails to connect.
        """
        if ERROR:
            raise ERROR
        last = len(self._clients) - 1
        for i, client in enumerate(self._clients):
            # merge the input config with the i-th config in the config list
            full_config = {**config, **self._config_list[i]}
            # separate the config into create_config and extra_kwargs
            create_config, extra_kwargs = self._separate_create_config(full_config)
            # construct the create params
            params = self._construct_create_params(create_config, extra_kwargs)
            # get the seed, filter_func and context
            seed = extra_kwargs.get("seed", 41)
            filter_func = extra_kwargs.get("filter_func")
            context = extra_kwargs.get("context")
            with diskcache.Cache(f"{self.cache_path_root}/{seed}") as cache:
                if seed is not None:
                    # Try to get the response from cache
                    key = get_key(params)
                    response = cache.get(key, None)
                    if response is not None:
                        # check the filter
                        pass_filter = filter_func is None or filter_func(context=context, response=response)
                        if pass_filter or i == last:
                            # Return the response if it passes the filter or it is the last client
                            response.config_id = i
                            response.pass_filter = pass_filter
                            # TODO: add response.cost
                            return response
                # Legacy completion endpoint when "prompt" is used, chat otherwise.
                completions = client.chat.completions if "messages" in params else client.completions
                try:
                    response = completions.create(**params)
                except APIConnectionError:
                    # This seems to be the only error raised by openai
                    logger.debug("config %d failed", i, exc_info=True)
                    if i == last:
                        raise
                else:
                    if seed is not None:
                        # Cache the response
                        cache.set(key, response)
                    return response

    @classmethod
    def extract_text_or_function_call(cls, response: "ChatCompletion | Completion") -> List[str]:
        """Extract the text or function calls from a completion or chat response.

        Declared as a classmethod (the original took ``cls`` but lacked the
        decorator, so it silently bound the instance as ``cls``); instance
        calls such as ``client.extract_text_or_function_call(r)`` still work.

        Args:
            response (ChatCompletion | Completion): The response from openai.

        Returns:
            A list of text or function calls in the responses.
        """
        choices = response.choices
        if isinstance(response, Completion):
            return [choice.text for choice in choices]
        # For chat responses, return the whole message when it carries a
        # function call (so the caller can inspect it); otherwise the content.
        return [
            choice.message if choice.message.function_call is not None else choice.message.content
            for choice in choices
        ]
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
__version__ = version["__version__"]

install_requires = [
"openai>=1",
"openai==1.0.0b3",
"diskcache",
"termcolor",
"flaml",
Expand Down
33 changes: 33 additions & 0 deletions test/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai

# Decide whether to skip the tests: the original only assigned `skip` in the
# except branch, so with openai installed the skipif decorators below raised
# NameError. The else branch makes `skip` defined on both paths.
try:
    from openai import OpenAI
except ImportError:
    skip = True
else:
    skip = False


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion():
    """Smoke test: chat completion through OpenAIWrapper with a JSON config list."""
    oai_config_list = config_list_from_json(
        env_or_file="OAI_CONFIG_LIST",
        file_location="notebook",
    )
    wrapper = OpenAIWrapper(config_list=oai_config_list)
    reply = wrapper.create(messages=[{"role": "user", "content": "1+1="}])
    print(reply)
    print(wrapper.extract_text_or_function_call(reply))


@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion():
    """Smoke test: legacy text completion through OpenAIWrapper."""
    wrapper = OpenAIWrapper(config_list=config_list_openai_aoai("notebook"))
    reply = wrapper.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
    print(reply)
    print(wrapper.extract_text_or_function_call(reply))


if __name__ == "__main__":
    # Allow running the smoke tests directly, without pytest.
    test_chat_completion()
    test_completion()
Loading