Skip to content

Commit

Permalink
feat: Sharegpt conversion (#1144)
Browse files Browse the repository at this point in the history
Signed-off-by: Caelum Forder <[email protected]>
Co-authored-by: Wendong <[email protected]>
  • Loading branch information
CaelumF and Wendong-Fan authored Nov 13, 2024
1 parent 111143a commit 333c9d9
Show file tree
Hide file tree
Showing 12 changed files with 845 additions and 22 deletions.
16 changes: 16 additions & 0 deletions camel/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,24 @@
ChatCompletionUserMessageParam,
)

from .conversion import (
HermesFunctionFormatter,
ShareGPTMessage,
)
from .conversion.models import (
ShareGPTConversation,
)
from .conversion.sharegpt.function_call_formatter import (
FunctionCallFormatter,
)

OpenAISystemMessage = ChatCompletionSystemMessageParam
OpenAIAssistantMessage = ChatCompletionAssistantMessageParam
OpenAIUserMessage = ChatCompletionUserMessageParam
OpenAIFunctionMessage = ChatCompletionFunctionMessageParam
OpenAIMessage = ChatCompletionMessageParam


from .base import BaseMessage # noqa: E402
from .func_message import FunctionCallingMessage # noqa: E402

Expand All @@ -34,6 +46,10 @@
'OpenAIUserMessage',
'OpenAIFunctionMessage',
'OpenAIMessage',
'FunctionCallFormatter',
'HermesFunctionFormatter',
'ShareGPTConversation',
'ShareGPTMessage',
'BaseMessage',
'FunctionCallingMessage',
]
104 changes: 104 additions & 0 deletions camel/messages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,22 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
import base64
import io
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
from PIL import Image

from camel.messages import (
FunctionCallFormatter,
HermesFunctionFormatter,
OpenAIAssistantMessage,
OpenAIMessage,
OpenAISystemMessage,
OpenAIUserMessage,
)
from camel.messages.conversion import ShareGPTMessage
from camel.prompts import CodePrompt, TextPrompt
from camel.types import (
OpenAIBackendRole,
Expand Down Expand Up @@ -271,6 +275,106 @@ def extract_text_and_code_prompts(

return text_prompts, code_prompts

@classmethod
def from_sharegpt(
    cls,
    message: ShareGPTMessage,
    function_format: Optional[FunctionCallFormatter[Any, Any]] = None,
    role_mapping=None,
) -> "BaseMessage":
    r"""Convert ShareGPT message to BaseMessage or FunctionCallingMessage.

    Note tool calls and responses have an 'assistant' role in CAMEL.

    Args:
        message (ShareGPTMessage): ShareGPT message to convert.
        function_format (FunctionCallFormatter, optional): Function call
            formatter to use. (default: :obj:`HermesFunctionFormatter()`)
        role_mapping (Dict[str, List[str, RoleType]], optional): Role
            mapping to use. Defaults to a CAMEL specific mapping.

    Returns:
        BaseMessage: Converted message (a
            :obj:`FunctionCallingMessage` when the message carries a
            tool call or tool response).
    """
    # Imported locally to avoid a circular import with camel.messages.
    from camel.messages import FunctionCallingMessage

    if role_mapping is None:
        # Maps ShareGPT 'from' values to (role_name, RoleType) pairs.
        role_mapping = {
            "system": ["system", RoleType.USER],
            "human": ["user", RoleType.USER],
            "gpt": ["assistant", RoleType.ASSISTANT],
            "tool": ["assistant", RoleType.ASSISTANT],
        }
    role_name, role_type = role_mapping[message.from_]

    if function_format is None:
        function_format = HermesFunctionFormatter()

    # Check if this is a function-related message
    if message.from_ == "gpt":
        func_info = function_format.extract_tool_calls(message.value)
        if (
            func_info and len(func_info) == 1
        ):  # TODO: Handle multiple tool calls
            # Including cleaned content is useful to
            # remind consumers of non-considered content
            clean_content = re.sub(
                r"<tool_call>.*?</tool_call>",
                "",
                message.value,
                flags=re.DOTALL,
            ).strip()

            return FunctionCallingMessage(
                role_name=role_name,
                role_type=role_type,
                meta_dict=None,
                content=clean_content,
                # Plain attribute access instead of reaching into
                # __dict__, which bypasses pydantic's accessors.
                func_name=func_info[0].name,
                args=func_info[0].arguments,
            )
    elif message.from_ == "tool":
        func_r_info = function_format.extract_tool_response(message.value)
        if func_r_info:
            return FunctionCallingMessage(
                role_name=role_name,
                role_type=role_type,
                meta_dict=None,
                content="",
                func_name=func_r_info.name,
                result=func_r_info.content,
            )

    # Regular (non tool-related) message
    return cls(
        role_name=role_name,
        role_type=role_type,
        meta_dict=None,
        content=message.value,
    )

def to_sharegpt(
    self,
    function_format: Optional[FunctionCallFormatter] = None,
) -> ShareGPTMessage:
    r"""Convert BaseMessage to ShareGPT message.

    Args:
        function_format (FunctionCallFormatter): Function call formatter
            to use. Defaults to Hermes.
    """
    function_format = (
        HermesFunctionFormatter() if function_format is None else function_format
    )

    # Map CAMEL role type onto the ShareGPT 'from' field.
    sender = "gpt"
    if self.role_type == RoleType.USER:
        sender = "system" if self.role_name == "system" else "human"

    # Function conversion code in FunctionCallingMessage
    return ShareGPTMessage(from_=sender, value=self.content)  # type: ignore[call-arg]

def to_openai_message(
self,
role_at_backend: OpenAIBackendRole,
Expand Down
29 changes: 29 additions & 0 deletions camel/messages/conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

from .models import (
ShareGPTConversation,
ShareGPTMessage,
ToolCall,
ToolResponse,
)
from .sharegpt import HermesFunctionFormatter

__all__ = [
'ShareGPTMessage',
'ShareGPTConversation',
'HermesFunctionFormatter',
'ToolCall',
'ToolResponse',
]
178 changes: 178 additions & 0 deletions camel/messages/conversion/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========

import json
from typing import Any, Dict, List, Literal

from pydantic import (
BaseModel,
Field,
RootModel,
field_validator,
model_validator,
)


class ShareGPTMessage(BaseModel):
    r"""A single message in ShareGPT format with enhanced validation"""

    # ShareGPT's JSON uses "from" for the sender role; aliased here
    # because "from" is a Python keyword.
    from_: Literal["human", "gpt", "system", "tool"] = Field(
        alias="from", description="The role of the message sender"
    )
    # Message body; length capped to guard against runaway payloads.
    value: str = Field(
        min_length=0,
        max_length=100000,
        description="The content of the message",
    )

    # populate_by_name lets callers construct with either `from_` or the
    # "from" alias; extra="forbid" rejects unknown keys.
    model_config = {
        "populate_by_name": True,
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {"from": "human", "value": "What's the weather like today?"}
            ]
        },
    }


class ShareGPTConversation(RootModel):
    r"""A full conversation in ShareGPT format with validation"""

    root: List[ShareGPTMessage]

    @model_validator(mode='after')
    def validate_conversation_flow(self) -> 'ShareGPTConversation':
        r"""Validate the conversation follows logical message order.

        Raises:
            ValueError: If the conversation is empty, does not open with a
                system or human message, or contains a tool/assistant
                message in an invalid position.
        """
        messages = self.root

        if not messages:
            raise ValueError("Conversation cannot be empty")

        if messages[0].from_ not in ("system", "human"):
            raise ValueError(
                "Conversation must start with either system or human message"
            )

        # Validate message sequence
        for i in range(1, len(messages)):
            curr, prev = messages[i], messages[i - 1]

            # A tool response is only valid directly after an assistant
            # message that actually issued a tool call.
            if curr.from_ == "tool":
                if prev.from_ != "gpt" or "<tool_call>" not in prev.value:
                    raise ValueError(
                        f"Tool response at position {i} "
                        f"must follow a gpt message with a tool call"
                    )

            # Assistant turns must be replies to a human or a tool result.
            if curr.from_ == "gpt" and prev.from_ not in (
                "human",
                "tool",
            ):
                raise ValueError(
                    f"Assistant message at position {i} "
                    f"must follow a human or tool message"
                )

        return self

    def model_dump(self, **kwargs):
        # NOTE(review): deviates from the pydantic contract — returns the
        # raw list of ShareGPTMessage objects rather than plain dicts, and
        # ignores **kwargs. Kept as-is because callers may rely on it.
        return self.root

    def __iter__(self):
        # Iterate directly over the underlying message list.
        return iter(self.root)


class ToolCall(BaseModel):
    r"""Represents a single tool/function call with validation"""

    # Name of the tool/function being invoked.
    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool to call",
    )
    # Keyword arguments for the call; must be JSON-serializable.
    arguments: Dict[str, Any] = Field(
        description="The arguments to pass to the tool"
    )

    @field_validator('arguments')
    @classmethod
    def validate_arguments(cls, v: Dict[str, Any]) -> Dict[str, Any]:
        r"""Validate argument structure and content.

        Raises:
            ValueError: If the arguments cannot be serialized to JSON.
        """
        # Try to serialize arguments to ensure they're JSON-compatible
        try:
            json.dumps(v)
        except (TypeError, ValueError) as exc:
            # Chain the underlying serialization error for debuggability.
            raise ValueError("Arguments must be JSON-serializable") from exc

        return v

    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "arguments": {"city": "London", "units": "celsius"},
                }
            ]
        },
    }


class ToolResponse(BaseModel):
    r"""Represents a tool/function response with validation. This is a
    base class and default implementation for tool responses, for the purpose
    of converting between different formats.
    """

    # Name of the tool that produced this response.
    name: str = Field(
        min_length=1,
        max_length=256,
        description="The name of the tool that was called",
    )
    # Arbitrary JSON-serializable payload returned by the tool.
    content: Any = Field(
        description="The response content from the tool."
        " Must be JSON serializable literal or object"
    )

    @field_validator('content')
    @classmethod
    def validate_content(cls, v: Any) -> Any:
        r"""Validate response content structure.

        The annotation is ``Any`` to match the field type — the content may
        be any JSON-serializable literal or object, not only a dict.

        Raises:
            ValueError: If the content cannot be serialized to JSON.
        """
        # Ensure content is JSON-serializable
        try:
            json.dumps(v)
        except (TypeError, ValueError) as exc:
            # Chain the underlying serialization error for debuggability.
            raise ValueError(
                "Response content must be JSON-serializable"
            ) from exc

        return v

    model_config = {
        "extra": "forbid",
        "json_schema_extra": {
            "examples": [
                {
                    "name": "get_weather",
                    "content": {
                        "temperature": 20,
                        "conditions": "sunny",
                        "humidity": 65,
                    },
                }
            ]
        },
    }
20 changes: 20 additions & 0 deletions camel/messages/conversion/sharegpt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. ===========


from .hermes import HermesFunctionFormatter

__all__ = [
'HermesFunctionFormatter',
]
Loading

0 comments on commit 333c9d9

Please sign in to comment.