Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update openai model support #1082

Merged
merged 6 commits into from
Jun 16, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flaml/autogen/agent/assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def receive(self, message, sender):
self._conversations[sender.name] = [{"content": self._system_message, "role": "system"}]
super().receive(message, sender)
responses = oai.ChatCompletion.create(messages=self._conversations[sender.name], **self._config)
# TODO: handle function_call
response = oai.ChatCompletion.extract_text(responses)[0]
self._send(response, sender)

Expand Down
41 changes: 32 additions & 9 deletions flaml/autogen/oai/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,16 @@ class Completion(openai_Completion):
# set of models that support chat completion
chat_models = {
"gpt-3.5-turbo",
"gpt-3.5-turbo-0301",
"gpt-3.5-turbo-0301", # deprecate in Sep
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k",
"gpt-35-turbo",
"gpt-4",
"gpt-4-32k",
"gpt-4-32k-0314",
"gpt-4-0314",
"gpt-4-32k-0314", # deprecate in Sep
"gpt-4-0314", # deprecate in Sep
"gpt-4-0613",
"gpt-4-32k-0613",
}

# price per 1k tokens
Expand All @@ -62,13 +66,17 @@ class Completion(openai_Completion):
"code-davinci-002": 0.1,
"text-davinci-002": 0.02,
"text-davinci-003": 0.02,
"gpt-3.5-turbo": 0.002,
"gpt-3.5-turbo-0301": 0.002,
"gpt-3.5-turbo": (0.0015, 0.002),
"gpt-3.5-turbo-0301": (0.0015, 0.002), # deprecate in Sep
"gpt-3.5-turbo-0613": (0.0015, 0.002),
"gpt-3.5-turbo-16k": (0.003, 0.004),
"gpt-35-turbo": 0.002,
"gpt-4": (0.03, 0.06),
"gpt-4-0314": (0.03, 0.06),
"gpt-4-32k": (0.06, 0.12),
"gpt-4-32k-0314": (0.06, 0.12),
"gpt-4-0314": (0.03, 0.06), # deprecate in Sep
"gpt-4-32k-0314": (0.06, 0.12), # deprecate in Sep
"gpt-4-0613": (0.03, 0.06),
"gpt-4-32k-0613": (0.06, 0.12),
}

default_search_space = {
Expand Down Expand Up @@ -386,7 +394,7 @@ def _eval(cls, config: dict, prune=True, eval_only=False):
result["cost"] = cost
return result
# evaluate the quality of the responses
responses = cls.extract_text(response)
responses = cls.extract_text_or_function_call(response)
usage = response["usage"]
n_input_tokens = usage["prompt_tokens"]
n_output_tokens = usage.get("completion_tokens", 0)
Expand Down Expand Up @@ -898,7 +906,7 @@ def eval_func(responses, **data):
response = cls.create(data_i, use_cache, **config)
cost += response["cost"]
# evaluate the quality of the responses
responses = cls.extract_text(response)
responses = cls.extract_text_or_function_call(response)
if eval_func is not None:
metrics = eval_func(responses, **data_i)
elif hasattr(cls, "_eval_func"):
Expand Down Expand Up @@ -991,6 +999,21 @@ def extract_text(cls, response: dict) -> List[str]:
return [choice["text"] for choice in choices]
return [choice["message"].get("content", "") for choice in choices]

@classmethod
def extract_text_or_function_call(cls, response: dict) -> List[str]:
"""Extract the text or function calls from a completion or chat response.

Args:
response (dict): The response from OpenAI API.

Returns:
A list of text or function calls in the responses.
"""
choices = response["choices"]
if "text" in choices[0]:
return [choice["text"] for choice in choices]
yiranwu0 marked this conversation as resolved.
Show resolved Hide resolved
return [choice["message"].get("content") or choice["message"].get("function_call", "") for choice in choices]

@classmethod
@property
def logged_history(cls) -> Dict:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@
"pytorch-forecasting>=0.9.0",
],
"benchmark": ["catboost>=0.26", "psutil==5.8.0", "xgboost==1.3.3", "pandas==1.1.4"],
"openai": ["openai==0.27.4", "diskcache"],
"autogen": ["openai==0.27.4", "diskcache", "docker"],
"openai": ["openai==0.27.8", "diskcache"],
"autogen": ["openai==0.27.8", "diskcache", "docker"],
"synapse": [
"joblibspark>=0.5.0",
"optuna==2.8.0",
Expand Down
4 changes: 2 additions & 2 deletions test/autogen/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ def test_gpt35(human_input_mode="NEVER", max_consecutive_auto_reply=5):
import openai
except ImportError:
return
config_list = oai.config_list_from_models(key_file_path=KEY_LOC, model_list=["gpt-3.5-turbo"])
config_list = oai.config_list_from_models(key_file_path=KEY_LOC, model_list=["gpt-3.5-turbo-0613"])
assistant = AssistantAgent(
"coding_agent",
request_timeout=600,
# request_timeout=600,
seed=40,
max_tokens=1024,
config_list=config_list,
Expand Down
4 changes: 2 additions & 2 deletions website/docs/Use-Cases/Auto-Generation.md
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,14 @@ Set `compact=False` in `start_logging()` to switch.
},
}
```
It can be seen that the individual API call history contain redundant information of the conversation. For a long conversation the degree of redundancy is high.
It can be seen that the individual API call history contains redundant information of the conversation. For a long conversation the degree of redundancy is high.
The compact history is more efficient and the individual API call history contains more details.

### Other Utilities

- a [`cost`](../reference/autogen/oai/completion#cost) function to calculate the cost of an API call.
- a [`test`](../reference/autogen/oai/completion#test) function to conveniently evaluate the configuration over test data.
- a [`extract_text`](../reference/autogen/oai/completion#extract_text) function to extract the text from a completion or chat response.
- an [`extract_text_or_function_call`](../reference/autogen/oai/completion#extract_text_or_function_call) function to extract the text or function call from a completion or chat response.


## Agents (Experimental)
Expand Down