diff --git a/autogen/oai/cohere.py b/autogen/oai/cohere.py index 3d38d86425fb..e9a89c9cabd8 100644 --- a/autogen/oai/cohere.py +++ b/autogen/oai/cohere.py @@ -148,7 +148,6 @@ def create(self, params: Dict) -> ChatCompletion: client_name = params.get("client_name") or "autogen-cohere" # Parse parameters to the Cohere API's parameters cohere_params = self.parse_params(params) - # Convert AutoGen messages to Cohere messages cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params) @@ -169,6 +168,7 @@ def create(self, params: Dict) -> ChatCompletion: cohere_finish = "" max_retries = 5 + for attempt in range(max_retries): ans = None try: @@ -176,6 +176,7 @@ def create(self, params: Dict) -> ChatCompletion: response = client.chat_stream(**cohere_params) else: response = client.chat(**cohere_params) + except CohereRateLimitError as e: raise RuntimeError(f"Cohere exception occurred: {e}") else: @@ -303,6 +304,15 @@ def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_t return temp_tool_results +def is_recent_tool_call(messages: list[Dict[str, Any]], tool_call_index: int): + messages_length = len(messages) + if tool_call_index == messages_length - 1: + return True + elif messages[tool_call_index + 1].get("role", "").lower() not in ("chatbot"): + return True + return False + + def oai_messages_to_cohere_messages( messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any] ) -> tuple[list[dict[str, Any]], str, str]: @@ -322,7 +332,7 @@ def oai_messages_to_cohere_messages( cohere_messages = [] preamble = "" - + cohere_tool_names = set() # Tools if "tools" in params: cohere_tools = [] @@ -353,6 +363,7 @@ def oai_messages_to_cohere_messages( "description": tool["function"]["description"], "parameter_definitions": parameters, } + cohere_tool_names.add(tool["function"]["name"] or "") cohere_tools.append(cohere_tool) @@ -370,31 +381,48 @@ def oai_messages_to_cohere_messages( # 'content' field renamed to 'message' # tools go into tools parameter # tool_results go into tool_results parameter - messages_length = len(messages) for index, message in enumerate(messages): + if not message["content"]: + continue + if "role" in message and message["role"] == "system": # System message if preamble == "": preamble = message["content"] else: preamble = preamble + "\n" + message["content"] - elif "tool_calls" in message: + + elif message.get("tool_calls"): # Suggested tool calls, build up the list before we put it into the tool_results - for tool_call in message["tool_calls"]: + message_tool_calls = [] + for tool_call in message["tool_calls"] or []: + if (not tool_call.get("function", {}).get("name")) or tool_call.get("function", {}).get( + "name" + ) not in cohere_tool_names: + new_message = { + "role": "CHATBOT", + "message": message.get("name") + ":" + message["content"] + str(message["tool_calls"]), + } + cohere_messages.append(new_message) + continue + tool_calls.append(tool_call) + message_tool_calls.append( + { + "name": tool_call.get("function", {}).get("name"), + "parameters": json.loads(tool_call.get("function", {}).get("arguments") or "null"), + } + ) + + if not message_tool_calls: + continue # We also add the suggested tool call as a message new_message = { "role": "CHATBOT", - "message": message["content"], - "tool_calls": [ - { - "name": tool_call_.get("function", {}).get("name"), - "parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"), - } - for tool_call_ in message["tool_calls"] - ], + "message": message.get("name") + ":" + message["content"], + "tool_calls": message_tool_calls, } cohere_messages.append(new_message) @@ -402,10 +430,19 @@ def oai_messages_to_cohere_messages( if not (tool_call_id := message.get("tool_call_id")): continue - # Convert the tool call to a result content_output = message["content"] + if tool_call_id not in [tool_call["id"] for tool_call in tool_calls]: + + new_message = { + "role": "CHATBOT", + "message": content_output, + } + cohere_messages.append(new_message) + continue + + # Convert the tool call to a result tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls) - if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")): + if is_recent_tool_call(messages, index): # If the tool call is the last message or the next message is a user/tool message, this is a recent tool call. # So, we pass it into tool_results. tool_results.extend(tool_results_chat_turn) @@ -420,7 +457,7 @@ def oai_messages_to_cohere_messages( # Standard text message new_message = { "role": "USER" if message["role"] == "user" else "CHATBOT", - "message": message["content"], + "message": message.get("name") + ":" + message.get("content"), } cohere_messages.append(new_message) @@ -436,7 +473,7 @@ def oai_messages_to_cohere_messages( # So, we add a CHATBOT 'continue' message, if so. # Changed key from "content" to "message" (jaygdesai/autogen_Jay) if cohere_messages[-1]["role"].lower() == "user": - cohere_messages.append({"role": "CHATBOT", "message": "Please continue."}) + cohere_messages.append({"role": "CHATBOT", "message": "Please go ahead and follow the instructions!"}) # We return a blank message when we have tool results # TODO: Check what happens if tool_results aren't the latest message @@ -449,7 +486,7 @@ def oai_messages_to_cohere_messages( if cohere_messages[-1]["role"] == "USER": return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"] else: - return cohere_messages, preamble, "Please continue." + return cohere_messages, preamble, "Please go ahead and follow the instructions!" def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float: