Convert ChatCompletionMessage to Dict after completion #791


Merged — 26 commits, Dec 10, 2023

Commits:
687fdb8  update (yiranwu0, Nov 28, 2023)
3788c0a  update (yiranwu0, Nov 28, 2023)
f34aec5  Merge remote-tracking branch 'origin/main' into functotool (yiranwu0, Dec 2, 2023)
84a8bdb  update signature (yiranwu0, Dec 2, 2023)
1051e1a  Merge branch 'main' into functotool (yiranwu0, Dec 2, 2023)
96297df  Merge remote-tracking branch 'origin/main' into functotool (yiranwu0, Dec 4, 2023)
1d6a18b  update (yiranwu0, Dec 4, 2023)
0f0fd46  update (yiranwu0, Dec 4, 2023)
92173ca  Merge branch 'main' into functotool (yiranwu0, Dec 4, 2023)
1b4dba2  Merge branch 'main' into functotool (qingyun-wu, Dec 5, 2023)
f17da84  Merge branch 'main' into functotool (yiranwu0, Dec 5, 2023)
d39cb30  Merge branch 'main' into functotool (yiranwu0, Dec 7, 2023)
6dcf62f  Merge branch 'main' into functotool (sonichi, Dec 7, 2023)
7908d9c  fix test funccall groupchat (yiranwu0, Dec 7, 2023)
f02cd1c  reverse change (yiranwu0, Dec 7, 2023)
de5776d  Merge branch 'main' into functotool (yiranwu0, Dec 7, 2023)
560b1ed  update (yiranwu0, Dec 9, 2023)
c2127e5  Merge branch 'main' into functotool (yiranwu0, Dec 9, 2023)
445810a  update (yiranwu0, Dec 10, 2023)
3eaaffa  Merge branch 'functotool' of github.com:microsoft/autogen into functo… (yiranwu0, Dec 10, 2023)
29b460a  Merge branch 'main' into functotool (yiranwu0, Dec 10, 2023)
02480b2  update (yiranwu0, Dec 10, 2023)
c9da966  Merge branch 'functotool' of github.com:microsoft/autogen into functo… (yiranwu0, Dec 10, 2023)
db5b4e3  update (yiranwu0, Dec 10, 2023)
ffbaa87  Merge remote-tracking branch 'origin/main' into functotool (yiranwu0, Dec 10, 2023)
fe97bf2  update (yiranwu0, Dec 10, 2023)
2 changes: 1 addition & 1 deletion — .github/workflows/build.yml

@@ -49,7 +49,7 @@ jobs:
       - name: Coverage
         if: matrix.python-version == '3.10'
         run: |
-          pip install -e .[mathchat,test]
+          pip install -e .[test]
           pip uninstall -y openai
           coverage run -a -m pytest test --ignore=test/agentchat/contrib
           coverage xml
2 changes: 1 addition & 1 deletion — autogen/agentchat/contrib/compressible_agent.py

@@ -403,7 +403,7 @@ def compress_messages(
             print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
             return False, None

-        compressed_message = self.client.extract_text_or_function_call(response)[0]
+        compressed_message = self.client.extract_text_or_completion_object(response)[0]
         assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
         if self.compress_config["verbose"]:
             print(
7 changes: 6 additions & 1 deletion — autogen/agentchat/conversable_agent.py

@@ -631,7 +631,12 @@ def generate_oai_reply(
         response = client.create(
            context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
        )
-        return True, client.extract_text_or_function_call(response)[0]
+
+        # TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
+        extracted_response = client.extract_text_or_completion_object(response)[0]
+        if not isinstance(extracted_response, str):
+            extracted_response = extracted_response.model_dump(mode="dict")
+        return True, extracted_response

    async def a_generate_oai_reply(
        self,
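The new branch in `generate_oai_reply` passes a tool/function-call message through as an object and dumps it to a plain dict, so downstream message handling stays JSON-serializable. A minimal sketch of that pattern, using a dataclass stand-in for the real pydantic `ChatCompletionMessage` (all names here are illustrative, and `asdict` plays the role of `model_dump`):

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional, Union


@dataclass
class FakeMessage:
    # Stand-in for openai's ChatCompletionMessage (a pydantic model).
    role: str
    content: Optional[str] = None
    function_call: Optional[dict] = None


def to_plain(extracted: Union[str, FakeMessage]) -> Union[str, dict]:
    # Mirrors the PR's pattern: strings pass through untouched,
    # message objects are converted to plain dicts.
    if isinstance(extracted, str):
        return extracted
    return asdict(extracted)


msg = FakeMessage(role="assistant", function_call={"name": "add", "arguments": '{"a": 2, "b": 2}'})
print(json.dumps(to_plain(msg)))  # a plain dict serializes cleanly; the raw object would not
print(to_plain("4"))
```

The point of the conversion is the last line of the sketch: `json.dumps` works on the dict form, whereas the original message object is not directly JSON-serializable.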
30 changes: 23 additions & 7 deletions — autogen/oai/client.py

@@ -10,14 +10,18 @@
 from autogen.oai.openai_utils import get_key, oai_price1k
 from autogen.token_count_utils import count_token

+TOOL_ENABLED = False
 try:
     import openai
     from openai import OpenAI, APIError
     from openai.types.chat import ChatCompletion
+    from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
     from openai.types.completion import Completion
     from openai.types.completion_usage import CompletionUsage
     import diskcache

+    if openai.__version__ >= "1.1.0":
+        TOOL_ENABLED = True
     ERROR = None
 except ImportError:
     ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
@@ -205,7 +209,7 @@ def create(self, **config):
             ```python
             def yes_or_no_filter(context, response):
                 return context.get("yes_or_no_choice", False) is False or any(
-                    text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
+                    text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
                 )
             ```
@@ -442,21 +446,33 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
         return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000

     @classmethod
-    def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
-        """Extract the text or function calls from a completion or chat response.
+    def extract_text_or_completion_object(
+        cls, response: ChatCompletion | Completion
+    ) -> Union[List[str], List[ChatCompletionMessage]]:
+        """Extract the text or ChatCompletion objects from a completion or chat response.

         Args:
             response (ChatCompletion | Completion): The response from openai.

         Returns:
-            A list of text or function calls in the responses.
+            A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
         """
         choices = response.choices
         if isinstance(response, Completion):
             return [choice.text for choice in choices]
+
+        if TOOL_ENABLED:
+            return [
+                choice.message
+                if choice.message.function_call is not None or choice.message.tool_calls is not None
+                else choice.message.content
+                for choice in choices
+            ]
+        else:
+            return [
+                choice.message if choice.message.function_call is not None else choice.message.content
+                for choice in choices
+            ]


 # TODO: logging
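The reworked extractor keeps the whole message whenever a function call or tool call is present and falls back to the plain text otherwise. A simplified, self-contained sketch of that branch, using stub dataclasses rather than the real openai types:

```python
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class StubMessage:
    # Stand-in for ChatCompletionMessage.
    content: Optional[str] = None
    function_call: Optional[dict] = None
    tool_calls: Optional[list] = None


@dataclass
class StubChoice:
    # Stand-in for a chat completion Choice.
    message: StubMessage


def extract(choices: List[StubChoice], tool_enabled: bool = True) -> List[Union[str, StubMessage]]:
    # With tool support (openai>=1.1.0 in the PR), keep the message object
    # when either function_call or tool_calls is set; otherwise return the text.
    if tool_enabled:
        return [
            c.message
            if c.message.function_call is not None or c.message.tool_calls is not None
            else c.message.content
            for c in choices
        ]
    # Older clients: only function_call exists on the message.
    return [c.message if c.message.function_call is not None else c.message.content for c in choices]


out = extract([StubChoice(StubMessage(content="4")), StubChoice(StubMessage(tool_calls=[{"id": "call_1"}]))])
print(out[0])                  # plain text answer
print(type(out[1]).__name__)   # whole message object, because tool_calls is set
```

Callers therefore have to branch on the element type, which is exactly what the `isinstance(extracted_response, str)` check in `generate_oai_reply` does.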
2 changes: 1 addition & 1 deletion — test/agentchat/test_function_call.py

@@ -48,7 +48,7 @@ def test_eval_math_responses():
         functions=functions,
     )
     print(response)
-    responses = client.extract_text_or_function_call(response)
+    responses = client.extract_text_or_completion_object(response)
     print(responses[0])
     function_call = responses[0].function_call
     name, arguments = function_call.name, json.loads(function_call.arguments)
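The updated test consumes the returned message object directly: when a function call is present, the extractor yields the message, whose `function_call` carries a name and a JSON-encoded `arguments` string. A small stand-alone sketch of that consumption pattern (stub object, illustrative values):

```python
import json
from dataclasses import dataclass


@dataclass
class StubFunctionCall:
    # Stand-in for the openai function_call object on a message.
    name: str
    arguments: str  # the API encodes arguments as a JSON string


fc = StubFunctionCall(name="eval_math_responses", arguments='{"responses": ["4"]}')
# Same shape as the test: pull the name, decode the arguments.
name, arguments = fc.name, json.loads(fc.arguments)
print(name, arguments)
```

Decoding `arguments` with `json.loads` is the caller's responsibility; the extractor never parses it.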
50 changes: 47 additions & 3 deletions — test/oai/test_client.py

@@ -2,12 +2,18 @@
 from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
 from test_utils import OAI_CONFIG_LIST, KEY_LOC

+TOOL_ENABLED = False
 try:
     from openai import OpenAI
+    from openai.types.chat.chat_completion import ChatCompletionMessage
 except ImportError:
     skip = True
 else:
     skip = False
+    import openai
+
+    if openai.__version__ >= "1.1.0":
+        TOOL_ENABLED = True


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -24,7 +30,44 @@ def test_aoai_chat_completion():
     # response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
     response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))
+
+
+@pytest.mark.skipif(skip and not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
+def test_oai_tool_calling_extraction():
+    config_list = config_list_from_json(
+        env_or_file=OAI_CONFIG_LIST,
+        file_location=KEY_LOC,
+        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
+    )
+    client = OpenAIWrapper(config_list=config_list)
+    response = client.create(
+        messages=[
+            {
+                "role": "user",
+                "content": "What is the weather in San Francisco?",
+            },
+        ],
+        tools=[
+            {
+                "type": "function",
+                "function": {
+                    "name": "getCurrentWeather",
+                    "description": "Get the weather in location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
+                            "unit": {"type": "string", "enum": ["c", "f"]},
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ],
+    )
+    print(response)
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -36,7 +79,7 @@ def test_chat_completion():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(messages=[{"role": "user", "content": "1+1="}])
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -45,7 +88,7 @@ def test_completion():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -96,6 +139,7 @@ def test_usage_summary():

 if __name__ == "__main__":
     test_aoai_chat_completion()
+    test_oai_tool_calling_extraction()
     test_chat_completion()
     test_completion()
     test_cost()
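Both `client.py` and `test_client.py` gate tool support on `openai.__version__ >= "1.1.0"`, a lexicographic string comparison. That happens to hold for these particular values, but string comparison misorders versions in general (`"1.10.0" < "1.2.0"` as strings). A hedged sketch of a numeric alternative, assuming simple `major.minor.patch` strings (`openai_version` below is illustrative, not read from the library):

```python
def version_tuple(version: str) -> tuple:
    # "1.10.0" -> (1, 10, 0); pre-release suffixes are ignored for brevity.
    return tuple(int(part) for part in version.split(".")[:3] if part.isdigit())


# Lexicographic string comparison misorders versions:
assert ("1.10.0" >= "1.2.0") is False                      # wrong numerically
assert version_tuple("1.10.0") >= version_tuple("1.2.0")   # correct

# The PR's gate, rewritten numerically over an illustrative version string:
openai_version = "1.3.5"
TOOL_ENABLED = version_tuple(openai_version) >= (1, 1, 0)
print(TOOL_ENABLED)  # True
```

Where the `packaging` library is available, `packaging.version.parse` is the usual way to get this ordering, including pre-release handling.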
8 changes: 4 additions & 4 deletions — test/oai/test_client_stream.py

@@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -33,7 +33,7 @@ def test_chat_completion_stream():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -66,7 +66,7 @@ def test_chat_functions_stream():
         stream=True,
     )
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 @pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -75,7 +75,7 @@ def test_completion_stream():
     client = OpenAIWrapper(config_list=config_list)
     response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
     print(response)
-    print(client.extract_text_or_function_call(response))
+    print(client.extract_text_or_completion_object(response))


 if __name__ == "__main__":
2 changes: 1 addition & 1 deletion — website/docs/Installation.md

@@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`.
   from autogen import OpenAIWrapper
   client = OpenAIWrapper(config_list=config_list)
   response = client.create(messages=[{"role": "user", "content": "2+2="}])
-  print(client.extract_text_or_function_call(response))
+  print(client.extract_text_or_completion_object(response))
   ```
 - Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
   Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).
6 changes: 3 additions & 3 deletions — website/docs/Use-Cases/enhanced_inference.md

@@ -119,15 +119,15 @@ client = OpenAIWrapper()
 # ChatCompletion
 response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
 # extract the response text
-print(client.extract_text_or_function_call(response))
+print(client.extract_text_or_completion_object(response))
 # get cost of this completion
 print(response.cost)
 # Azure OpenAI endpoint
 client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azure")
 # Completion
 response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
 # extract the response text
-print(client.extract_text_or_function_call(response))
+print(client.extract_text_or_completion_object(response))

 ```
@@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requireme

 ```python
 def valid_json_filter(response, **_):
-    for text in OpenAIWrapper.extract_text_or_function_call(response):
+    for text in OpenAIWrapper.extract_text_or_completion_object(response):
         try:
             json.loads(text)
             return True
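After this change, `extract_text_or_completion_object` may yield message objects (for function/tool calls) as well as strings, so a filter that feeds every element straight into `json.loads` would raise a `TypeError` on object entries. A hedged sketch of a guard, operating on an already-extracted list so it is self-contained (illustrative only, not part of the PR):

```python
import json
from typing import Any, Iterable


def valid_json_filter_sketch(extracted: Iterable[Any]) -> bool:
    # Skip non-string entries (e.g. tool-call message objects) before parsing.
    for text in extracted:
        if not isinstance(text, str):
            continue
        try:
            json.loads(text)
            return True
        except json.JSONDecodeError:
            continue
    return False


print(valid_json_filter_sketch(['{"a": 1}']))            # True: valid JSON found
print(valid_json_filter_sketch([object(), "not json"]))  # False: nothing parses
```

The `isinstance` guard mirrors the check `generate_oai_reply` itself now performs before calling `model_dump`.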