Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tool calling cohere #3355

Merged
70 changes: 44 additions & 26 deletions autogen/oai/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"api_type": "cohere",
"model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY")
"client_name": "autogen-cohere", # Optional parameter
}
]}

Expand Down Expand Up @@ -34,6 +35,7 @@
from openai.types.chat import ChatCompletion, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion import ChatCompletionMessage, Choice
from openai.types.completion_usage import CompletionUsage
import json

from autogen.oai.client_utils import validate_parameter

Expand Down Expand Up @@ -66,6 +68,7 @@ def __init__(self, **kwargs):
"""
# Ensure we have the api_key upon instantiation
self.api_key = kwargs.get("api_key", None)
self.client_name = kwargs.get("client_name") or "autogen-cohere"
marklysze marked this conversation as resolved.
Show resolved Hide resolved
if not self.api_key:
self.api_key = os.getenv("COHERE_API_KEY")

Expand Down Expand Up @@ -156,7 +159,7 @@ def create(self, params: Dict) -> ChatCompletion:
cohere_params["preamble"] = preamble

# We use chat model by default
client = Cohere(api_key=self.api_key)
client = Cohere(api_key=self.api_key, client_name=self.client_name)

# Token counts will be returned
prompt_tokens = 0
Expand Down Expand Up @@ -284,6 +287,24 @@ def create(self, params: Dict) -> ChatCompletion:

return response_oai

def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
    """Build tool results for every recorded tool call whose id matches ``tool_call_id``.

    Args:
        tool_call_id: The id carried by the tool-role message being converted.
        content_output: The tool's output; stored as ``[{"value": content_output}]``.
        all_tool_calls: Iterable of OpenAI-style tool call dicts, each with an
            ``"id"`` and a ``"function"`` dict holding ``"name"`` and a
            JSON-encoded ``"arguments"`` string.

    Returns:
        A list with one ``ToolResult(call=..., outputs=...)`` per matching tool
        call; empty when nothing matches.
        NOTE(review): the elements are ``ToolResult`` objects, not plain dicts
        as the return annotation suggests - confirm and tighten the annotation.

    Raises:
        json.JSONDecodeError: If a matching call's ``arguments`` is not valid JSON.
    """
    temp_tool_results = []

    for tool_call in all_tool_calls:
        if tool_call["id"] != tool_call_id:
            continue

        function = tool_call["function"]
        # Argument-less calls may carry "" or None (or omit the key entirely);
        # json.loads would raise on any of those, so fall back to an empty object.
        raw_arguments = function.get("arguments") or "{}"

        call = {
            "name": function["name"],
            "parameters": json.loads(raw_arguments),
        }
        outputs = [{"value": content_output}]
        temp_tool_results.append(ToolResult(call=call, outputs=outputs))

    return temp_tool_results


def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
Expand Down Expand Up @@ -352,7 +373,8 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
for message in messages:
messages_length = len(messages)
for index, message in enumerate(messages):

if "role" in message and message["role"] == "system":
# System message
Expand All @@ -369,34 +391,29 @@ def oai_messages_to_cohere_messages(
new_message = {
"role": "CHATBOT",
"message": message["content"],
# Not including tools in this message, may need to. Testing required.
"tool_calls": [{"name": tool_call_.get("function",{}).get("name"), "parameters": json.loads(tool_call_.get("function",{}).get("arguments") or 'null')} for tool_call_ in message["tool_calls"]],
}

cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message:
# Convert the tool call to a result

tool_call_id = message["tool_call_id"]
content_output = message["content"]

# Find the original tool
for tool_call in tool_calls:
if tool_call["id"] == tool_call_id:

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]

tool_results.append(ToolResult(call=call, outputs=output))
if not (tool_call_id := message.get("tool_call_id")):
continue

# Convert the tool call to a result
content_output = message["content"]
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index+1].get("role", "").lower() in ("user", "tool")):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
continue

else:
# If it's not the current tool call, we pass it as a tool message in the chat history.
new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
cohere_messages.append(new_message)


break
elif "content" in message and isinstance(message["content"], str):
# Standard text message
new_message = {
Expand All @@ -405,6 +422,7 @@ def oai_messages_to_cohere_messages(
}

cohere_messages.append(new_message)


# Append any Tool Results
if len(tool_results) != 0:
Expand All @@ -416,7 +434,7 @@ def oai_messages_to_cohere_messages(
# If we're adding tool_results, like we are, the last message can't be a USER message
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"] == "USER":
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})

# We return a blank message when we have tool results
Expand Down