Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tool calling cohere #3355

Merged
63 changes: 40 additions & 23 deletions autogen/oai/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def create(self, params: Dict) -> ChatCompletion:
cohere_params["preamble"] = preamble

# We use chat model by default
client = Cohere(api_key=self.api_key)
client = Cohere(api_key=self.api_key, client_name="autogen-cohere")
marklysze marked this conversation as resolved.
Show resolved Hide resolved

# Token counts will be returned
prompt_tokens = 0
Expand Down Expand Up @@ -284,6 +284,24 @@ def create(self, params: Dict) -> ChatCompletion:

return response_oai

def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
temp_tool_results = []

for tool_call in all_tool_calls:
if tool_call["id"] == tool_call_id:

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]
temp_tool_results.append(ToolResult(call=call, outputs=output))
return temp_tool_results


def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
Expand Down Expand Up @@ -352,7 +370,8 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
for message in messages:
messages_length = len(messages)
for index, message in enumerate(messages):

if "role" in message and message["role"] == "system":
# System message
Expand All @@ -369,34 +388,31 @@ def oai_messages_to_cohere_messages(
new_message = {
"role": "CHATBOT",
"message": message["content"],
"tool_calls": [{"name": tool_call_.get("function",{}).get("name"), "parameters": eval(tool_call_.get("function",{}).get("arguments"))} for tool_call_ in message["tool_calls"]],
Anirudh31415926535 marked this conversation as resolved.
Show resolved Hide resolved
# Not including tools in this message, may need to. Testing required.
}

cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message:
# Convert the tool call to a result

tool_call_id = message["tool_call_id"]
content_output = message["content"]
if not (tool_call_id := message.get("tool_call_id")):
continue

# Find the original tool
for tool_call in tool_calls:
if tool_call["id"] == tool_call_id:
# Convert the tool call to a result
content_output = message["content"]
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]

tool_results.append(ToolResult(call=call, outputs=output))
if (index == messages_length - 1) or (messages[index+1].get("role", "").lower() in ("user", "tool")):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
break

else:
# If its not the current tool call, we pass it as a tool message in the chat history.
new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
cohere_messages.append(new_message)


break
elif "content" in message and isinstance(message["content"], str):
# Standard text message
new_message = {
Expand All @@ -405,6 +421,7 @@ def oai_messages_to_cohere_messages(
}

cohere_messages.append(new_message)


# Append any Tool Results
if len(tool_results) != 0:
Expand All @@ -416,7 +433,7 @@ def oai_messages_to_cohere_messages(
# If we're adding tool_results, like we are, the last message can't be a USER message
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"] == "USER":
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})

# We return a blank message when we have tool results
Expand Down