feat: support openai structured output (#1198)
Co-authored-by: CaelumF <[email protected]>
Wendong-Fan and CaelumF authored Nov 20, 2024
1 parent f972161 commit 146af24
Showing 6 changed files with 133 additions and 3 deletions.
9 changes: 9 additions & 0 deletions camel/agents/chat_agent.py
@@ -476,6 +476,15 @@ def step(
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.
        """
        if (
            self.model_backend.model_config_dict.get("response_format")
            and response_format
        ):
            raise ValueError(
                "The `response_format` parameter cannot be set both in "
                "the model configuration and in the ChatAgent step."
            )

        if isinstance(input_message, str):
            input_message = BaseMessage.make_user_message(
                role_name='User', content=input_message
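The guard above makes the two ways of requesting structured output mutually exclusive. A minimal sketch of both call sites (`Reply` and the prompt are illustrative, not part of the commit; the per-step variant assumes `step` accepts a `response_format` argument, as the guard implies):

from pydantic import BaseModel

from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType


class Reply(BaseModel):
    answer: str


# Option 1: fix the format in the model configuration.
model = ModelFactory.create(
    model_platform=ModelPlatformType.OPENAI,
    model_type=ModelType.GPT_4O_MINI,
    model_config_dict=ChatGPTConfig(response_format=Reply).as_dict(),
)
agent = ChatAgent(model=model)
print(agent.step("Answer briefly: is water wet?").msgs[0].content)

# Option 2: leave the config alone and request the format per step.
# agent.step("Answer briefly: is water wet?", response_format=Reply)
# Passing `response_format` in both places raises the new ValueError.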
31 changes: 28 additions & 3 deletions camel/configs/openai_config.py
@@ -13,9 +13,9 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from __future__ import annotations

from typing import Optional, Sequence, Union
from typing import Any, Optional, Sequence, Type, Union

from pydantic import Field
from pydantic import BaseModel, Field

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven
@@ -104,11 +104,36 @@ class ChatGPTConfig(BaseConfig):
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[dict, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    logit_bias: dict = Field(default_factory=dict)
    user: str = ""
    tool_choice: Optional[Union[dict[str, str], str]] = None

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        config_dict["tools"] = NOT_GIVEN
        return config_dict


OPENAI_API_PARAMS = {param for param in ChatGPTConfig.model_fields.keys()}
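As a usage sketch (the `Weather` model is illustrative, not part of the commit), `as_dict()` now lets a Pydantic class flow through as `response_format`, while `tools` is reset to NOT_GIVEN, presumably because the agent wires tool schemas in separately:

from pydantic import BaseModel

from camel.configs import ChatGPTConfig


class Weather(BaseModel):
    city: str
    temperature_c: float


config = ChatGPTConfig(temperature=0.0, response_format=Weather).as_dict()
# The Pydantic class should pass through intact, ready for the parse
# endpoint; config["tools"] is NOT_GIVEN regardless of what was passed in.
print(config["response_format"])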
38 changes: 38 additions & 0 deletions camel/models/openai_model.py
@@ -24,6 +24,7 @@
    ChatCompletion,
    ChatCompletionChunk,
    ModelType,
    ParsedChatCompletion,
)
from camel.utils import (
    BaseTokenCounter,
@@ -127,13 +128,50 @@ def run(
self.model_config_dict["presence_penalty"] = 0.0
self.model_config_dict["frequency_penalty"] = 0.0

if self.model_config_dict.get("response_format"):
# stream is not supported in beta.chat.completions.parse
if "stream" in self.model_config_dict:
del self.model_config_dict["stream"]

response = self._client.beta.chat.completions.parse(
messages=messages,
model=self.model_type,
**self.model_config_dict,
)

return self._to_chat_completion(response)

response = self._client.chat.completions.create(
messages=messages,
model=self.model_type,
**self.model_config_dict,
)
return response

    def _to_chat_completion(
        self, response: "ParsedChatCompletion"
    ) -> ChatCompletion:
        # TODO: Handle n > 1 or warn consumers it's not supported
        choice = dict(
            index=response.choices[0].index,
            message={
                "role": response.choices[0].message.role,
                "content": response.choices[0].message.content,
                "tool_calls": response.choices[0].message.tool_calls,
            },
            finish_reason=response.choices[0].finish_reason,
        )

        obj = ChatCompletion.construct(
            id=response.id,
            choices=[choice],
            created=response.created,
            model=response.model,
            object="chat.completion",
            usage=response.usage,
        )
        return obj

    def check_model_config(self):
        r"""Check whether the model configuration contains any
        unexpected arguments to OpenAI API.
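For context, `run()` now routes to the OpenAI SDK's structured-output endpoint whenever `response_format` is set, then normalizes the result via `_to_chat_completion`. A minimal sketch of that endpoint in isolation (assumes a recent `openai` package with the beta parse API and OPENAI_API_KEY in the environment; `Answer` is illustrative):

from openai import OpenAI
from pydantic import BaseModel


class Answer(BaseModel):
    value: int


client = OpenAI()
completion = client.beta.chat.completions.parse(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "What is 2 + 2?"}],
    response_format=Answer,
)
# The beta endpoint returns a ParsedChatCompletion whose message carries a
# `parsed` Pydantic instance; `_to_chat_completion` above keeps only the
# plain JSON string in `message.content`.
print(completion.choices[0].message.parsed)

This also explains the `del self.model_config_dict["stream"]` step: the parse endpoint does not support streaming, so the key is dropped before the call.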
2 changes: 2 additions & 0 deletions camel/types/__init__.py
@@ -40,6 +40,7 @@
    Choice,
    CompletionUsage,
    NotGiven,
    ParsedChatCompletion,
)
from .unified_model_type import UnifiedModelType

@@ -71,4 +72,5 @@
    'UnifiedModelType',
    'NOT_GIVEN',
    'NotGiven',
    'ParsedChatCompletion',
]
2 changes: 2 additions & 0 deletions camel/types/openai_types.py
@@ -31,6 +31,7 @@
    ChatCompletionUserMessageParam,
)
from openai.types.completion_usage import CompletionUsage
from openai.types.chat import ParsedChatCompletion
from openai._types import NOT_GIVEN, NotGiven

Choice = Choice
@@ -45,3 +46,4 @@
CompletionUsage = CompletionUsage
NOT_GIVEN = NOT_GIVEN
NotGiven = NotGiven
ParsedChatCompletion = ParsedChatCompletion
54 changes: 54 additions & 0 deletions examples/models/openai_structured_output_example.py
@@ -0,0 +1,54 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
from pydantic import BaseModel

from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType


class Student(BaseModel):
    name: str
    age: str


class StudentList(BaseModel):
    studentList: list[Student]


openai_model = ModelFactory.create(
    model_platform=ModelPlatformType.OPENAI,
    model_type=ModelType.GPT_4O_MINI,
    model_config_dict=ChatGPTConfig(
        temperature=0.0, response_format=StudentList
    ).as_dict(),
)

# Set agent
camel_agent = ChatAgent(model=openai_model)

# Set user message
user_msg = """give me some student infos."""

# Get response information
response = camel_agent.step(user_msg)
print(response.msgs[0].content)
'''
===============================================================================
{"studentLlst":[{"name":"Alice Johnson","age":"20"},{"name":"Brian Smith",
"age":"22"},{"name":"Catherine Lee","age":"19"},{"name":"David Brown",
"age":"21"},{"name":"Eva White","age":"20"}]}
===============================================================================
'''
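Since the agent returns the structured output as a JSON string, a natural follow-up (not part of the commit) is to validate it back into the Pydantic model for typed access:

students = StudentList.model_validate_json(response.msgs[0].content)
for student in students.studentList:
    print(f"{student.name} is {student.age}")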
