Skip to content

Commit d31ecb3

Browse files
yiranwu0, qingyun-wu, and sonichi
authored
Convert ChatCompletionMessage to Dict after completion (microsoft#791)
* update
* update
* update signature
* update
* update
* fix test funccall groupchat
* reverse change
* update
* update
* update
* update
* update

---------

Co-authored-by: Qingyun Wu <[email protected]>
Co-authored-by: Chi Wang <[email protected]>
1 parent ec18925 commit d31ecb3

File tree

9 files changed

+87
-22
lines changed

9 files changed

+87
-22
lines changed

.github/workflows/build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ jobs:
4949
- name: Coverage
5050
if: matrix.python-version == '3.10'
5151
run: |
52-
pip install -e .[mathchat,test]
52+
pip install -e .[test]
5353
pip uninstall -y openai
5454
coverage run -a -m pytest test --ignore=test/agentchat/contrib
5555
coverage xml

autogen/agentchat/contrib/compressible_agent.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def compress_messages(
403403
print(colored(f"Failed to compress the content due to {e}", "red"), flush=True)
404404
return False, None
405405

406-
compressed_message = self.client.extract_text_or_function_call(response)[0]
406+
compressed_message = self.client.extract_text_or_completion_object(response)[0]
407407
assert isinstance(compressed_message, str), f"compressed_message should be a string: {compressed_message}"
408408
if self.compress_config["verbose"]:
409409
print(

autogen/agentchat/conversable_agent.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,12 @@ def generate_oai_reply(
631631
response = client.create(
632632
context=messages[-1].pop("context", None), messages=self._oai_system_message + messages
633633
)
634-
return True, client.extract_text_or_function_call(response)[0]
634+
635+
# TODO: line 301, line 271 is converting messages to dict. Can be removed after ChatCompletionMessage_to_dict is merged.
636+
extracted_response = client.extract_text_or_completion_object(response)[0]
637+
if not isinstance(extracted_response, str):
638+
extracted_response = extracted_response.model_dump(mode="dict")
639+
return True, extracted_response
635640

636641
async def a_generate_oai_reply(
637642
self,

autogen/oai/client.py

+23-7
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,18 @@
1010
from autogen.oai.openai_utils import get_key, oai_price1k
1111
from autogen.token_count_utils import count_token
1212

13+
TOOL_ENABLED = False
1314
try:
15+
import openai
1416
from openai import OpenAI, APIError
1517
from openai.types.chat import ChatCompletion
1618
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
1719
from openai.types.completion import Completion
1820
from openai.types.completion_usage import CompletionUsage
1921
import diskcache
2022

23+
if openai.__version__ >= "1.1.0":
24+
TOOL_ENABLED = True
2125
ERROR = None
2226
except ImportError:
2327
ERROR = ImportError("Please install openai>=1 and diskcache to use autogen.OpenAIWrapper.")
@@ -205,7 +209,7 @@ def create(self, **config):
205209
```python
206210
def yes_or_no_filter(context, response):
207211
return context.get("yes_or_no_choice", False) is False or any(
208-
text in ["Yes.", "No."] for text in client.extract_text_or_function_call(response)
212+
text in ["Yes.", "No."] for text in client.extract_text_or_completion_object(response)
209213
)
210214
```
211215
@@ -442,21 +446,33 @@ def cost(self, response: Union[ChatCompletion, Completion]) -> float:
442446
return tmp_price1K * (n_input_tokens + n_output_tokens) / 1000
443447

444448
@classmethod
445-
def extract_text_or_function_call(cls, response: ChatCompletion | Completion) -> List[str]:
446-
"""Extract the text or function calls from a completion or chat response.
449+
def extract_text_or_completion_object(
450+
cls, response: ChatCompletion | Completion
451+
) -> Union[List[str], List[ChatCompletionMessage]]:
452+
"""Extract the text or ChatCompletion objects from a completion or chat response.
447453
448454
Args:
449455
response (ChatCompletion | Completion): The response from openai.
450456
451457
Returns:
452-
A list of text or function calls in the responses.
458+
A list of text, or a list of ChatCompletion objects if function_call/tool_calls are present.
453459
"""
454460
choices = response.choices
455461
if isinstance(response, Completion):
456462
return [choice.text for choice in choices]
457-
return [
458-
choice.message if choice.message.function_call is not None else choice.message.content for choice in choices
459-
]
463+
464+
if TOOL_ENABLED:
465+
return [
466+
choice.message
467+
if choice.message.function_call is not None or choice.message.tool_calls is not None
468+
else choice.message.content
469+
for choice in choices
470+
]
471+
else:
472+
return [
473+
choice.message if choice.message.function_call is not None else choice.message.content
474+
for choice in choices
475+
]
460476

461477

462478
# TODO: logging

test/agentchat/test_function_call.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_eval_math_responses():
4848
functions=functions,
4949
)
5050
print(response)
51-
responses = client.extract_text_or_function_call(response)
51+
responses = client.extract_text_or_completion_object(response)
5252
print(responses[0])
5353
function_call = responses[0].function_call
5454
name, arguments = function_call.name, json.loads(function_call.arguments)

test/oai/test_client.py

+47-3
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,18 @@
22
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
33
from test_utils import OAI_CONFIG_LIST, KEY_LOC
44

5+
TOOL_ENABLED = False
56
try:
67
from openai import OpenAI
8+
from openai.types.chat.chat_completion import ChatCompletionMessage
79
except ImportError:
810
skip = True
911
else:
1012
skip = False
13+
import openai
14+
15+
if openai.__version__ >= "1.1.0":
16+
TOOL_ENABLED = True
1117

1218

1319
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -24,7 +30,44 @@ def test_aoai_chat_completion():
2430
# response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
2531
response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
2632
print(response)
27-
print(client.extract_text_or_function_call(response))
33+
print(client.extract_text_or_completion_object(response))
34+
35+
36+
@pytest.mark.skipif(skip and not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
37+
def test_oai_tool_calling_extraction():
38+
config_list = config_list_from_json(
39+
env_or_file=OAI_CONFIG_LIST,
40+
file_location=KEY_LOC,
41+
filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
42+
)
43+
client = OpenAIWrapper(config_list=config_list)
44+
response = client.create(
45+
messages=[
46+
{
47+
"role": "user",
48+
"content": "What is the weather in San Francisco?",
49+
},
50+
],
51+
tools=[
52+
{
53+
"type": "function",
54+
"function": {
55+
"name": "getCurrentWeather",
56+
"description": "Get the weather in location",
57+
"parameters": {
58+
"type": "object",
59+
"properties": {
60+
"location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
61+
"unit": {"type": "string", "enum": ["c", "f"]},
62+
},
63+
"required": ["location"],
64+
},
65+
},
66+
}
67+
],
68+
)
69+
print(response)
70+
print(client.extract_text_or_completion_object(response))
2871

2972

3073
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -36,7 +79,7 @@ def test_chat_completion():
3679
client = OpenAIWrapper(config_list=config_list)
3780
response = client.create(messages=[{"role": "user", "content": "1+1="}])
3881
print(response)
39-
print(client.extract_text_or_function_call(response))
82+
print(client.extract_text_or_completion_object(response))
4083

4184

4285
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -45,7 +88,7 @@ def test_completion():
4588
client = OpenAIWrapper(config_list=config_list)
4689
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
4790
print(response)
48-
print(client.extract_text_or_function_call(response))
91+
print(client.extract_text_or_completion_object(response))
4992

5093

5194
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -96,6 +139,7 @@ def test_usage_summary():
96139

97140
if __name__ == "__main__":
98141
test_aoai_chat_completion()
142+
test_oai_tool_calling_extraction()
99143
test_chat_completion()
100144
test_completion()
101145
test_cost()

test/oai/test_client_stream.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_aoai_chat_completion_stream():
2020
client = OpenAIWrapper(config_list=config_list)
2121
response = client.create(messages=[{"role": "user", "content": "2+2="}], stream=True)
2222
print(response)
23-
print(client.extract_text_or_function_call(response))
23+
print(client.extract_text_or_completion_object(response))
2424

2525

2626
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -33,7 +33,7 @@ def test_chat_completion_stream():
3333
client = OpenAIWrapper(config_list=config_list)
3434
response = client.create(messages=[{"role": "user", "content": "1+1="}], stream=True)
3535
print(response)
36-
print(client.extract_text_or_function_call(response))
36+
print(client.extract_text_or_completion_object(response))
3737

3838

3939
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -66,7 +66,7 @@ def test_chat_functions_stream():
6666
stream=True,
6767
)
6868
print(response)
69-
print(client.extract_text_or_function_call(response))
69+
print(client.extract_text_or_completion_object(response))
7070

7171

7272
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@@ -75,7 +75,7 @@ def test_completion_stream():
7575
client = OpenAIWrapper(config_list=config_list)
7676
response = client.create(prompt="1+1=", model="gpt-3.5-turbo-instruct", stream=True)
7777
print(response)
78-
print(client.extract_text_or_function_call(response))
78+
print(client.extract_text_or_completion_object(response))
7979

8080

8181
if __name__ == "__main__":

website/docs/Installation.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Therefore, some changes are required for users of `pyautogen<0.2`.
6161
from autogen import OpenAIWrapper
6262
client = OpenAIWrapper(config_list=config_list)
6363
response = client.create(messages=[{"role": "user", "content": "2+2="}])
64-
print(client.extract_text_or_function_call(response))
64+
print(client.extract_text_or_completion_object(response))
6565
```
6666
- Inference parameter tuning and inference logging features are currently unavailable in `OpenAIWrapper`. Logging will be added in a future release.
6767
Inference parameter tuning can be done via [`flaml.tune`](https://microsoft.github.io/FLAML/docs/Use-Cases/Tune-User-Defined-Function).

website/docs/Use-Cases/enhanced_inference.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,15 @@ client = OpenAIWrapper()
119119
# ChatCompletion
120120
response = client.create(messages=[{"role": "user", "content": "2+2="}], model="gpt-3.5-turbo")
121121
# extract the response text
122-
print(client.extract_text_or_function_call(response))
122+
print(client.extract_text_or_completion_object(response))
123123
# get cost of this completion
124124
print(response.cost)
125125
# Azure OpenAI endpoint
126126
client = OpenAIWrapper(api_key=..., base_url=..., api_version=..., api_type="azure")
127127
# Completion
128128
response = client.create(prompt="2+2=", model="gpt-3.5-turbo-instruct")
129129
# extract the response text
130-
print(client.extract_text_or_function_call(response))
130+
print(client.extract_text_or_completion_object(response))
131131

132132
```
133133

@@ -240,7 +240,7 @@ Another type of error is that the returned response does not satisfy a requireme
240240

241241
```python
242242
def valid_json_filter(response, **_):
243-
for text in OpenAIWrapper.extract_text_or_function_call(response):
243+
for text in OpenAIWrapper.extract_text_or_completion_object(response):
244244
try:
245245
json.loads(text)
246246
return True

0 commit comments

Comments (0)