-
Notifications
You must be signed in to change notification settings - Fork 692
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Integrate cohere models (#932)
Co-authored-by: NeilJohnson0930 <[email protected]>
- Loading branch information
1 parent
d4e2356
commit f972161
Showing
13 changed files
with
824 additions
and
144 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from __future__ import annotations | ||
|
||
from typing import List, Optional | ||
|
||
from camel.configs.base_config import BaseConfig | ||
|
||
|
||
class CohereConfig(BaseConfig): | ||
r"""Defines the parameters for generating chat completions using the | ||
Cohere API. | ||
Args: | ||
temperature (float, optional): Sampling temperature to use, between | ||
:obj:`0` and :obj:`2`. Higher values make the output more random, | ||
while lower values make it more focused and deterministic. | ||
(default: :obj:`0.3`) | ||
documents (list, optional): A list of relevant documents that the | ||
model can cite to generate a more accurate reply. Each document is | ||
either a string or document object with content and metadata. | ||
(default: :obj:`None`) | ||
max_tokens (int, optional): The maximum number of tokens the model | ||
will generate as part of the response. (default: :obj:`None`) | ||
stop_sequences (List(str), optional): A list of up to 5 strings that | ||
the model will use to stop generation. If the model generates a | ||
string that matches any of the strings in the list, it will stop | ||
generating tokens and return the generated text up to that point | ||
not including the stop sequence. (default: :obj:`None`) | ||
seed (int, optional): If specified, the backend will make a best | ||
effort to sample tokens deterministically, such that repeated | ||
requests with the same seed and parameters should return the same | ||
result. However, determinism cannot be totally guaranteed. | ||
(default: :obj:`None`) | ||
frequency_penalty (float, optional): Min value of `0.0`, max value of | ||
`1.0`. Used to reduce repetitiveness of generated tokens. The | ||
higher the value, the stronger a penalty is applied to previously | ||
present tokens, proportional to how many times they have already | ||
appeared in the prompt or prior generation. (default: :obj:`0.0`) | ||
presence_penalty (float, optional): Min value of `0.0`, max value of | ||
`1.0`. Used to reduce repetitiveness of generated tokens. Similar | ||
to `frequency_penalty`, except that this penalty is applied | ||
equally to all tokens that have already appeared, regardless of | ||
their exact frequencies. (default: :obj:`0.0`) | ||
k (int, optional): Ensures only the top k most likely tokens are | ||
considered for generation at each step. Min value of `0`, max | ||
value of `500`. (default: :obj:`0`) | ||
p (float, optional): Ensures that only the most likely tokens, with | ||
total probability mass of `p`, are considered for generation at | ||
each step. If both k and p are enabled, `p` acts after `k`. Min | ||
value of `0.01`, max value of `0.99`. (default: :obj:`0.75`) | ||
""" | ||
|
||
temperature: Optional[float] = 0.2 | ||
documents: Optional[list] = None | ||
max_tokens: Optional[int] = None | ||
stop_sequences: Optional[List[str]] = None | ||
seed: Optional[int] = None | ||
frequency_penalty: Optional[float] = 0.0 | ||
presence_penalty: Optional[float] = 0.0 | ||
k: Optional[int] = 0 | ||
p: Optional[float] = 0.75 | ||
|
||
|
||
COHERE_API_PARAMS = {param for param in CohereConfig().model_fields.keys()} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,282 @@ | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2023 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
import ast | ||
import json | ||
import logging | ||
import os | ||
import uuid | ||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union | ||
|
||
if TYPE_CHECKING: | ||
from cohere.types import ChatMessageV2, ChatResponse | ||
|
||
from camel.configs import COHERE_API_PARAMS, CohereConfig | ||
from camel.messages import OpenAIMessage | ||
from camel.models import BaseModelBackend | ||
from camel.types import ChatCompletion, ModelType | ||
from camel.utils import ( | ||
BaseTokenCounter, | ||
OpenAITokenCounter, | ||
api_keys_required, | ||
) | ||
|
||
try: | ||
if os.getenv("AGENTOPS_API_KEY") is not None: | ||
from agentops import LLMEvent, record | ||
else: | ||
raise ImportError | ||
except (ImportError, AttributeError): | ||
LLMEvent = None | ||
|
||
|
||
class CohereModel(BaseModelBackend): | ||
r"""Cohere API in a unified BaseModelBackend interface.""" | ||
|
||
def __init__( | ||
self, | ||
model_type: Union[ModelType, str], | ||
model_config_dict: Optional[Dict[str, Any]] = None, | ||
api_key: Optional[str] = None, | ||
url: Optional[str] = None, | ||
token_counter: Optional[BaseTokenCounter] = None, | ||
): | ||
import cohere | ||
|
||
if model_config_dict is None: | ||
model_config_dict = CohereConfig().as_dict() | ||
|
||
api_key = api_key or os.environ.get("COHERE_API_KEY") | ||
url = url or os.environ.get("COHERE_API_BASE_URL") | ||
super().__init__( | ||
model_type, model_config_dict, api_key, url, token_counter | ||
) | ||
self._client = cohere.ClientV2(api_key=self._api_key) | ||
|
||
def _to_openai_response(self, response: 'ChatResponse') -> ChatCompletion: | ||
if response.usage and response.usage.tokens: | ||
input_tokens = response.usage.tokens.input_tokens or 0 | ||
output_tokens = response.usage.tokens.output_tokens or 0 | ||
usage = { | ||
"prompt_tokens": input_tokens, | ||
"completion_tokens": output_tokens, | ||
"total_tokens": input_tokens + output_tokens, | ||
} | ||
else: | ||
usage = {} | ||
|
||
tool_calls = response.message.tool_calls | ||
choices = [] | ||
if tool_calls: | ||
for tool_call in tool_calls: | ||
openai_tool_calls = [ | ||
dict( | ||
id=tool_call.id, | ||
function={ | ||
"name": tool_call.function.name, | ||
"arguments": tool_call.function.arguments, | ||
} | ||
if tool_call.function | ||
else {}, | ||
type=tool_call.type, | ||
) | ||
] | ||
|
||
choice = dict( | ||
index=None, | ||
message={ | ||
"role": "assistant", | ||
"content": response.message.tool_plan, | ||
"tool_calls": openai_tool_calls, | ||
}, | ||
finish_reason=response.finish_reason | ||
if response.finish_reason | ||
else None, | ||
) | ||
choices.append(choice) | ||
|
||
else: | ||
openai_tool_calls = None | ||
|
||
choice = dict( | ||
index=None, | ||
message={ | ||
"role": "assistant", | ||
"content": response.message.content[0].text, # type: ignore[union-attr,index] | ||
"tool_calls": openai_tool_calls, | ||
}, | ||
finish_reason=response.finish_reason | ||
if response.finish_reason | ||
else None, | ||
) | ||
choices.append(choice) | ||
|
||
obj = ChatCompletion.construct( | ||
id=response.id, | ||
choices=choices, | ||
created=None, | ||
model=self.model_type, | ||
object="chat.completion", | ||
usage=usage, | ||
) | ||
return obj | ||
|
||
def _to_cohere_chatmessage( | ||
self, messages: List[OpenAIMessage] | ||
) -> List["ChatMessageV2"]: | ||
from cohere.types import ToolCallV2Function | ||
from cohere.types.chat_message_v2 import ( | ||
AssistantChatMessageV2, | ||
SystemChatMessageV2, | ||
ToolCallV2, | ||
ToolChatMessageV2, | ||
UserChatMessageV2, | ||
) | ||
|
||
tool_call_id = None | ||
new_messages = [] | ||
for msg in messages: | ||
role = msg.get("role") | ||
content = msg.get("content") | ||
function_call = msg.get("function_call") | ||
|
||
if role == "user": | ||
new_message = UserChatMessageV2(role="user", content=content) # type: ignore[arg-type] | ||
elif role in {"tool", "function"}: | ||
new_message = ToolChatMessageV2( | ||
role="tool", | ||
tool_call_id=tool_call_id, # type: ignore[arg-type] | ||
content=content, # type: ignore[assignment,arg-type] | ||
) | ||
elif role == "assistant": | ||
if not function_call: | ||
new_message = AssistantChatMessageV2( # type: ignore[assignment] | ||
role="assistant", | ||
content=content, # type: ignore[arg-type] | ||
) | ||
else: | ||
arguments = function_call.get("arguments") # type: ignore[attr-defined] | ||
arguments_dict = ast.literal_eval(arguments) | ||
arguments_json = json.dumps(arguments_dict) | ||
|
||
assis_tool_call_id = str(uuid.uuid4()) | ||
tool_call_id = assis_tool_call_id | ||
new_message = AssistantChatMessageV2( # type: ignore[assignment] | ||
role="assistant", | ||
tool_calls=[ | ||
ToolCallV2( | ||
id=assis_tool_call_id, | ||
type="function", | ||
function=ToolCallV2Function( | ||
name=function_call.get("name"), # type: ignore[attr-defined] | ||
arguments=arguments_json, # type: ignore[attr-defined] | ||
), | ||
) | ||
], | ||
content=content, # type: ignore[arg-type] | ||
) | ||
elif role == "system": | ||
new_message = SystemChatMessageV2( # type: ignore[assignment] | ||
role="system", | ||
content=content, # type: ignore[arg-type] | ||
) | ||
else: | ||
raise ValueError(f"Unsupported message role: {role}") | ||
|
||
new_messages.append(new_message) | ||
return new_messages # type: ignore[return-value] | ||
|
||
@property | ||
def token_counter(self) -> BaseTokenCounter: | ||
r"""Initialize the token counter for the model backend. | ||
Returns: | ||
BaseTokenCounter: The token counter following the model's | ||
tokenization style. | ||
""" | ||
if not self._token_counter: | ||
self._token_counter = OpenAITokenCounter( | ||
model=ModelType.GPT_4O_MINI | ||
) | ||
return self._token_counter | ||
|
||
@api_keys_required("COHERE_API_KEY") | ||
def run(self, messages: List[OpenAIMessage]) -> ChatCompletion: | ||
r"""Runs inference of Cohere chat completion. | ||
Args: | ||
messages (List[OpenAIMessage]): Message list with the chat history | ||
in OpenAI API format. | ||
Returns: | ||
ChatCompletion. | ||
""" | ||
from cohere.core.api_error import ApiError | ||
|
||
cohere_messages = self._to_cohere_chatmessage(messages) | ||
|
||
try: | ||
response = self._client.chat( | ||
messages=cohere_messages, | ||
model=self.model_type, | ||
**self.model_config_dict, | ||
) | ||
except ApiError as e: | ||
logging.error(f"Cohere API Error: {e.status_code}") | ||
logging.error(f"Error body: {e.body}") | ||
raise | ||
except Exception as e: | ||
logging.error(f"Unexpected error when calling Cohere API: {e!s}") | ||
raise | ||
|
||
openai_response = self._to_openai_response(response) | ||
|
||
# Add AgentOps LLM Event tracking | ||
if LLMEvent: | ||
llm_event = LLMEvent( | ||
thread_id=openai_response.id, | ||
prompt=" ".join( | ||
[message.get("content") for message in messages] # type: ignore[misc] | ||
), | ||
prompt_tokens=openai_response.usage.prompt_tokens, # type: ignore[union-attr] | ||
completion=openai_response.choices[0].message.content, | ||
completion_tokens=openai_response.usage.completion_tokens, # type: ignore[union-attr] | ||
model=self.model_type, | ||
) | ||
record(llm_event) | ||
|
||
return openai_response | ||
|
||
def check_model_config(self): | ||
r"""Check whether the model configuration contains any unexpected | ||
arguments to Cohere API. | ||
Raises: | ||
ValueError: If the model configuration dictionary contains any | ||
unexpected arguments to Cohere API. | ||
""" | ||
for param in self.model_config_dict: | ||
if param not in COHERE_API_PARAMS: | ||
raise ValueError( | ||
f"Unexpected argument `{param}` is " | ||
"input into Cohere model backend." | ||
) | ||
|
||
@property | ||
def stream(self) -> bool: | ||
r"""Returns whether the model is in stream mode, which sends partial | ||
results each time. Current it's not supported. | ||
Returns: | ||
bool: Whether the model is in stream mode. | ||
""" | ||
return False |
Oops, something went wrong.