Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: tool calling cohere #3355

Merged
71 changes: 45 additions & 26 deletions autogen/oai/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"api_type": "cohere",
"model": "command-r-plus",
"api_key": os.environ.get("COHERE_API_KEY")
"client_name": "autogen-cohere", # Optional parameter
}
]}

Expand Down Expand Up @@ -144,7 +145,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
def create(self, params: Dict) -> ChatCompletion:

messages = params.get("messages", [])

client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params)

Expand All @@ -156,7 +157,7 @@ def create(self, params: Dict) -> ChatCompletion:
cohere_params["preamble"] = preamble

# We use chat model by default
client = Cohere(api_key=self.api_key)
client = Cohere(api_key=self.api_key, client_name=client_name)

# Token counts will be returned
prompt_tokens = 0
Expand Down Expand Up @@ -285,6 +286,23 @@ def create(self, params: Dict) -> ChatCompletion:
return response_oai


def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]:
temp_tool_results = []

for tool_call in all_tool_calls:
if tool_call["id"] == tool_call_id:

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] == "" else "{}"
),
}
output = [{"value": content_output}]
temp_tool_results.append(ToolResult(call=call, outputs=output))
return temp_tool_results


def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
Expand Down Expand Up @@ -352,7 +370,8 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
for message in messages:
messages_length = len(messages)
for index, message in enumerate(messages):

if "role" in message and message["role"] == "system":
# System message
Expand All @@ -369,34 +388,34 @@ def oai_messages_to_cohere_messages(
new_message = {
"role": "CHATBOT",
"message": message["content"],
# Not including tools in this message, may need to. Testing required.
"tool_calls": [
{
"name": tool_call_.get("function", {}).get("name"),
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
}
for tool_call_ in message["tool_calls"]
],
}

cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if "tool_call_id" in message:
# Convert the tool call to a result
if not (tool_call_id := message.get("tool_call_id")):
continue

# Convert the tool call to a result
content_output = message["content"]
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
continue

tool_call_id = message["tool_call_id"]
content_output = message["content"]

# Find the original tool
for tool_call in tool_calls:
if tool_call["id"] == tool_call_id:

call = {
"name": tool_call["function"]["name"],
"parameters": json.loads(
tool_call["function"]["arguments"]
if not tool_call["function"]["arguments"] == ""
else "{}"
),
}
output = [{"value": content_output}]

tool_results.append(ToolResult(call=call, outputs=output))
else:
# If its not the current tool call, we pass it as a tool message in the chat history.
new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn}
cohere_messages.append(new_message)

break
elif "content" in message and isinstance(message["content"], str):
# Standard text message
new_message = {
Expand All @@ -416,7 +435,7 @@ def oai_messages_to_cohere_messages(
# If we're adding tool_results, like we are, the last message can't be a USER message
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"] == "USER":
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})

# We return a blank message when we have tool results
Expand Down
4 changes: 3 additions & 1 deletion website/docs/topics/non-openai-models/cloud-cohere.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
"- seed (null, integer)\n",
"- frequency_penalty (number 0..1)\n",
"- presence_penalty (number 0..1)\n",
"- client_name (null, string)\n",
"\n",
"Example:\n",
"```python\n",
Expand All @@ -108,6 +109,7 @@
" \"model\": \"command-r\",\n",
" \"api_key\": \"your Cohere API Key goes here\",\n",
" \"api_type\": \"cohere\",\n",
" \"client_name\": \"autogen-cohere\",\n",
" \"temperature\": 0.5,\n",
" \"p\": 0.2,\n",
" \"k\": 100,\n",
Expand Down Expand Up @@ -526,7 +528,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.12.5"
}
},
"nbformat": 4,
Expand Down
Loading