fix: update cohere tool calling multi agents #3488

Merged
73 changes: 55 additions & 18 deletions autogen/oai/cohere.py
@@ -148,7 +148,6 @@ def create(self, params: Dict) -> ChatCompletion:
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params)

# Convert AutoGen messages to Cohere messages
cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)

@@ -169,13 +168,15 @@ def create(self, params: Dict) -> ChatCompletion:
cohere_finish = ""

max_retries = 5

for attempt in range(max_retries):
ans = None
try:
if streaming:
response = client.chat_stream(**cohere_params)
else:
response = client.chat(**cohere_params)

except CohereRateLimitError as e:
raise RuntimeError(f"Cohere exception occurred: {e}")
else:
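
Reviewer note: a self-contained toy sketch of the call-dispatch shape in this hunk (stream vs. non-stream inside a bounded retry loop). The stub client and exception are stand-ins for the Cohere SDK objects used in this module; the retry/backoff branch sits outside the lines shown here.

```python
class FakeRateLimitError(Exception):
    """Stand-in for the SDK's rate-limit exception used in this module."""


class FakeClient:
    """Stand-in for the Cohere client; only mimics the two call shapes above."""

    def chat(self, **params):
        return {"text": "non-streaming reply", "params": params}

    def chat_stream(self, **params):
        yield {"event_type": "text-generation", "text": "streamed "}
        yield {"event_type": "text-generation", "text": "reply"}


def call_with_retries(client, cohere_params, streaming, max_retries=5):
    for _ in range(max_retries):
        try:
            if streaming:
                return client.chat_stream(**cohere_params)
            return client.chat(**cohere_params)
        except FakeRateLimitError as e:
            # As in the diff, rate limits are surfaced as a RuntimeError rather than retried.
            raise RuntimeError(f"Cohere exception occurred: {e}")


print(call_with_retries(FakeClient(), {"message": "hi"}, streaming=False))
```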
@@ -303,6 +304,15 @@ def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_t
return temp_tool_results


def is_recent_tool_call(messages: list[Dict[str, Any]], tool_call_index: int):
messages_length = len(messages)
if tool_call_index == messages_length - 1:
return True
elif messages[tool_call_index + 1].get("role", "").lower() not in ("chatbot"):
return True
return False


def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
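
Reviewer note: a minimal, runnable sketch of how the new `is_recent_tool_call` helper behaves. The helper is paraphrased here (an exact role comparison instead of the membership test above), and the sample conversation is hypothetical, not taken from the test suite.

```python
from typing import Any


def is_recent_tool_call(messages: list[dict[str, Any]], tool_call_index: int) -> bool:
    """True when the turn at tool_call_index is the latest one, i.e. it is the
    last message or the following message is not a CHATBOT reply."""
    if tool_call_index == len(messages) - 1:
        return True
    if messages[tool_call_index + 1].get("role", "").lower() != "chatbot":
        return True
    return False


# Hypothetical conversation: the tool result at index 2 is the last message, and
# the tool call at index 1 is followed by a non-chatbot turn, so both count as recent.
messages = [
    {"role": "user", "content": "What is the weather in Paris?"},
    {"role": "assistant", "tool_calls": [{"id": "1", "function": {"name": "get_weather"}}]},
    {"role": "tool", "tool_call_id": "1", "content": "18C and sunny"},
]
print(is_recent_tool_call(messages, 2))  # True: last message
print(is_recent_tool_call(messages, 1))  # True: next message is not a chatbot turn
```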
@@ -322,7 +332,7 @@ def oai_messages_to_cohere_messages(

cohere_messages = []
preamble = ""

cohere_tool_names = set()
# Tools
if "tools" in params:
cohere_tools = []
@@ -353,6 +363,7 @@ def oai_messages_to_cohere_messages(
"description": tool["function"]["description"],
"parameter_definitions": parameters,
}
cohere_tool_names.add(tool["function"]["name"] or "")

cohere_tools.append(cohere_tool)

@@ -370,42 +381,68 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
messages_length = len(messages)
for index, message in enumerate(messages):

if not message["content"]:
continue

if "role" in message and message["role"] == "system":
# System message
if preamble == "":
preamble = message["content"]
else:
preamble = preamble + "\n" + message["content"]
elif "tool_calls" in message:

elif message.get("tool_calls"):
# Suggested tool calls, build up the list before we put it into the tool_results
for tool_call in message["tool_calls"]:
message_tool_calls = []
for tool_call in message["tool_calls"] or []:
if (not tool_call.get("function", {}).get("name")) or tool_call.get("function", {}).get(
"name"
) not in cohere_tool_names:
new_message = {
"role": "CHATBOT",
"message": message.get("name") + ":" + message["content"] + str(message["tool_calls"]),
}
cohere_messages.append(new_message)
continue

tool_calls.append(tool_call)
message_tool_calls.append(
{
"name": tool_call.get("function", {}).get("name"),
"parameters": json.loads(tool_call.get("function", {}).get("arguments") or "null"),
}
)

if not message_tool_calls:
continue

# We also add the suggested tool call as a message
new_message = {
"role": "CHATBOT",
"message": message["content"],
"tool_calls": [
{
"name": tool_call_.get("function", {}).get("name"),
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
}
for tool_call_ in message["tool_calls"]
],
"message": message.get("name") + ":" + message["content"],
"tool_calls": message_tool_calls,
}

cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if not (tool_call_id := message.get("tool_call_id")):
continue

# Convert the tool call to a result
content_output = message["content"]
if tool_call_id not in [tool_call["id"] for tool_call in tool_calls]:

new_message = {
"role": "CHATBOT",
"message": content_output,
}
cohere_messages.append(new_message)
continue

# Convert the tool call to a result
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
if is_recent_tool_call(messages, index):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
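
Reviewer note: to make the new mapping easier to follow, here is a minimal sketch of how a single OpenAI-style assistant message with tool calls is reshaped into a Cohere CHATBOT turn under this change. The sample message, agent name, and tool name are hypothetical; unknown tools are simply dropped here, whereas the diff falls back to a plain CHATBOT text message for them.

```python
import json

# Hypothetical tool registry and assistant message, mirroring the shapes used in the diff.
cohere_tool_names = {"get_weather"}
message = {
    "role": "assistant",
    "name": "weather_agent",
    "content": "Let me check that.",
    "tool_calls": [
        {"id": "1", "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}},
        {"id": "2", "function": {"name": "unknown_tool", "arguments": "{}"}},  # not a Cohere tool
    ],
}

# Keep only the tool calls whose function name maps onto a registered Cohere tool.
message_tool_calls = [
    {
        "name": call["function"]["name"],
        "parameters": json.loads(call["function"].get("arguments") or "null"),
    }
    for call in message["tool_calls"]
    if call["function"].get("name") in cohere_tool_names
]

cohere_message = {
    "role": "CHATBOT",
    # The sender's name is prefixed so multi-agent turns stay distinguishable.
    "message": message["name"] + ":" + message["content"],
    "tool_calls": message_tool_calls,
}
print(cohere_message)
```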
@@ -420,7 +457,7 @@ def oai_messages_to_cohere_messages(
# Standard text message
new_message = {
"role": "USER" if message["role"] == "user" else "CHATBOT",
"message": message["content"],
"message": message.get("name") + ":" + message.get("content"),
}

cohere_messages.append(new_message)
@@ -436,7 +473,7 @@ def oai_messages_to_cohere_messages(
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
cohere_messages.append({"role": "CHATBOT", "message": "Please go ahead and follow the instructions!"})

# We return a blank message when we have tool results
# TODO: Check what happens if tool_results aren't the latest message
@@ -449,7 +486,7 @@ def oai_messages_to_cohere_messages(
if cohere_messages[-1]["role"] == "USER":
return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"]
else:
return cohere_messages, preamble, "Please continue."
return cohere_messages, preamble, "Please go ahead and follow the instructions!"


def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
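
Reviewer note: a condensed, self-contained sketch of the converter's tail, reconstructed from the comments and return statements visible in the last two hunks. The guard around the tool_results branch sits outside the shown lines and is an assumption; the function and variable names below are illustrative, not the module's.

```python
def select_final_turn(cohere_messages, preamble, tool_results):
    """Sketch: decide what becomes Cohere's `message` parameter vs. chat history."""
    filler = "Please go ahead and follow the instructions!"

    if tool_results:
        # Assumed branch: with pending tool results the history should not end on a
        # USER turn, so a CHATBOT filler is appended and a blank prompt is returned.
        if cohere_messages and cohere_messages[-1]["role"].lower() == "user":
            cohere_messages.append({"role": "CHATBOT", "message": filler})
        return cohere_messages, preamble, ""

    # No tool results: the trailing USER turn becomes the prompt; if the history
    # ends on a CHATBOT turn, the filler text is sent instead.
    if cohere_messages and cohere_messages[-1]["role"] == "USER":
        return cohere_messages[:-1], preamble, cohere_messages[-1]["message"]
    return cohere_messages, preamble, filler


# Hypothetical history ending on a USER turn, with no pending tool results.
history = [
    {"role": "CHATBOT", "message": "assistant:Hello"},
    {"role": "USER", "message": "user:Summarise the report"},
]
print(select_final_turn(history, preamble="Be terse.", tool_results=[]))
```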