From 5861bd92a6813a47a7ceaf391c4d5871d71a14ff Mon Sep 17 00:00:00 2001 From: Anirudh31415926535 Date: Thu, 29 Aug 2024 02:47:39 +0800 Subject: [PATCH] fix: tool calling cohere (#3355) * Add support for tool calling cohere * update tool calling code * make client name configurable with default * formatting nits * update docs --------- Co-authored-by: Mark Sze <66362098+marklysze@users.noreply.github.com> Co-authored-by: Li Jiang --- autogen/oai/cohere.py | 71 ++++++++++++------- .../non-openai-models/cloud-cohere.ipynb | 4 +- 2 files changed, 48 insertions(+), 27 deletions(-) diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 35b7ac97c4f3..3d38d86425fb 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -6,6 +6,7 @@ "api_type": "cohere", "model": "command-r-plus", "api_key": os.environ.get("COHERE_API_KEY") + "client_name": "autogen-cohere", # Optional parameter } ]} @@ -144,7 +145,7 @@ def parse_params(self, params: Dict[str, Any]) -> Dict[str, Any]: def create(self, params: Dict) -> ChatCompletion: messages = params.get("messages", []) - + client_name = params.get("client_name") or "autogen-cohere" # Parse parameters to the Cohere API's parameters cohere_params = self.parse_params(params) @@ -156,7 +157,7 @@ def create(self, params: Dict) -> ChatCompletion: cohere_params["preamble"] = preamble # We use chat model by default - client = Cohere(api_key=self.api_key) + client = Cohere(api_key=self.api_key, client_name=client_name) # Token counts will be returned prompt_tokens = 0 @@ -285,6 +286,23 @@ def create(self, params: Dict) -> ChatCompletion: return response_oai +def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_tool_calls) -> List[Dict[str, Any]]: + temp_tool_results = [] + + for tool_call in all_tool_calls: + if tool_call["id"] == tool_call_id: + + call = { + "name": tool_call["function"]["name"], + "parameters": json.loads( + tool_call["function"]["arguments"] if not tool_call["function"]["arguments"] 
== "" else "{}" + ), + } + output = [{"value": content_output}] + temp_tool_results.append(ToolResult(call=call, outputs=output)) + return temp_tool_results + + def oai_messages_to_cohere_messages( messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any] ) -> tuple[list[dict[str, Any]], str, str]: @@ -352,7 +370,8 @@ def oai_messages_to_cohere_messages( # 'content' field renamed to 'message' # tools go into tools parameter # tool_results go into tool_results parameter - for message in messages: + messages_length = len(messages) + for index, message in enumerate(messages): if "role" in message and message["role"] == "system": # System message @@ -369,34 +388,34 @@ def oai_messages_to_cohere_messages( new_message = { "role": "CHATBOT", "message": message["content"], - # Not including tools in this message, may need to. Testing required. + "tool_calls": [ + { + "name": tool_call_.get("function", {}).get("name"), + "parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"), + } + for tool_call_ in message["tool_calls"] + ], } cohere_messages.append(new_message) elif "role" in message and message["role"] == "tool": - if "tool_call_id" in message: - # Convert the tool call to a result + if not (tool_call_id := message.get("tool_call_id")): + continue + + # Convert the tool call to a result + content_output = message["content"] + tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls) + if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")): + # If the tool call is the last message or the next message is a user/tool message, this is a recent tool call. + # So, we pass it into tool_results. 
+ tool_results.extend(tool_results_chat_turn) + continue - tool_call_id = message["tool_call_id"] - content_output = message["content"] - - # Find the original tool - for tool_call in tool_calls: - if tool_call["id"] == tool_call_id: - - call = { - "name": tool_call["function"]["name"], - "parameters": json.loads( - tool_call["function"]["arguments"] - if not tool_call["function"]["arguments"] == "" - else "{}" - ), - } - output = [{"value": content_output}] - - tool_results.append(ToolResult(call=call, outputs=output)) + else: + # If it's not the current tool call, we pass it as a tool message in the chat history. + new_message = {"role": "TOOL", "tool_results": tool_results_chat_turn} + cohere_messages.append(new_message) - break elif "content" in message and isinstance(message["content"], str): # Standard text message new_message = { @@ -416,7 +435,7 @@ def oai_messages_to_cohere_messages( # If we're adding tool_results, like we are, the last message can't be a USER message # So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay) - if cohere_messages[-1]["role"] == "USER": + if cohere_messages[-1]["role"].lower() == "user": cohere_messages.append({"role": "CHATBOT", "message": "Please continue."}) # We return a blank message when we have tool results diff --git a/website/docs/topics/non-openai-models/cloud-cohere.ipynb b/website/docs/topics/non-openai-models/cloud-cohere.ipynb index b678810a7699..73dcc54a75ed 100644 --- a/website/docs/topics/non-openai-models/cloud-cohere.ipynb +++ b/website/docs/topics/non-openai-models/cloud-cohere.ipynb @@ -100,6 +100,7 @@ "- seed (null, integer)\n", "- frequency_penalty (number 0..1)\n", "- presence_penalty (number 0..1)\n", + "- client_name (null, string)\n", "\n", "Example:\n", "```python\n", @@ -108,6 +109,7 @@ " \"model\": \"command-r\",\n", " \"api_key\": \"your Cohere API Key goes here\",\n", " \"api_type\": \"cohere\",\n", + " \"client_name\": \"autogen-cohere\",\n", " \"temperature\": 0.5,\n", " \"p\": 0.2,\n", " \"k\": 100,\n", @@ -526,7 +528,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.5" } }, "nbformat": 4,