fix: update cohere tool calling multi agents (#3488)
* update cohere tool calling multi agents

* Add agent name prefix to chatbot message

---------

Co-authored-by: Jack Gerrits <[email protected]>
Anirudh31415926535 and jackgerrits authored Sep 25, 2024
1 parent 29fa4ce commit 7cfcf55
Showing 1 changed file with 55 additions and 18 deletions.
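
The second bullet above ("Add agent name prefix to chatbot message") means that every converted text message now embeds the sending agent's name in the message body, so multi-agent transcripts stay attributable after the OpenAI-to-Cohere conversion. A minimal sketch of that behaviour, assuming an invented helper name and sample message (the diff inlines the same expression wherever it builds a Cohere message):

from typing import Any, Dict


def prefixed_text(message: Dict[str, Any]) -> str:
    # Mirrors the expression used in the diff:
    #     message.get("name") + ":" + message["content"]
    # so the sending agent's name becomes part of the message text.
    return message.get("name") + ":" + message["content"]


# Invented sample message from a hypothetical "planner" agent.
sample = {"role": "assistant", "name": "planner", "content": "Draft the itinerary."}
print({"role": "CHATBOT", "message": prefixed_text(sample)})
# -> {'role': 'CHATBOT', 'message': 'planner:Draft the itinerary.'}

Note that the prefix uses message.get("name") without a fallback, so the conversion assumes every incoming message carries a name field.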
73 changes: 55 additions & 18 deletions autogen/oai/cohere.py
@@ -148,7 +148,6 @@ def create(self, params: Dict) -> ChatCompletion:
client_name = params.get("client_name") or "autogen-cohere"
# Parse parameters to the Cohere API's parameters
cohere_params = self.parse_params(params)

# Convert AutoGen messages to Cohere messages
cohere_messages, preamble, final_message = oai_messages_to_cohere_messages(messages, params, cohere_params)

@@ -169,13 +168,15 @@ def create(self, params: Dict) -> ChatCompletion:
cohere_finish = ""

max_retries = 5

for attempt in range(max_retries):
ans = None
try:
if streaming:
response = client.chat_stream(**cohere_params)
else:
response = client.chat(**cohere_params)

except CohereRateLimitError as e:
raise RuntimeError(f"Cohere exception occurred: {e}")
else:
@@ -303,6 +304,15 @@ def extract_to_cohere_tool_results(tool_call_id: str, content_output: str, all_t
return temp_tool_results


def is_recent_tool_call(messages: list[Dict[str, Any]], tool_call_index: int):
messages_length = len(messages)
if tool_call_index == messages_length - 1:
return True
elif messages[tool_call_index + 1].get("role", "").lower() not in ("chatbot"):
return True
return False


def oai_messages_to_cohere_messages(
messages: list[Dict[str, Any]], params: Dict[str, Any], cohere_params: Dict[str, Any]
) -> tuple[list[dict[str, Any]], str, str]:
@@ -322,7 +332,7 @@ def oai_messages_to_cohere_messages(

cohere_messages = []
preamble = ""

cohere_tool_names = set()
# Tools
if "tools" in params:
cohere_tools = []
@@ -353,6 +363,7 @@ def oai_messages_to_cohere_messages(
"description": tool["function"]["description"],
"parameter_definitions": parameters,
}
cohere_tool_names.add(tool["function"]["name"] or "")

cohere_tools.append(cohere_tool)

@@ -370,42 +381,68 @@ def oai_messages_to_cohere_messages(
# 'content' field renamed to 'message'
# tools go into tools parameter
# tool_results go into tool_results parameter
messages_length = len(messages)
for index, message in enumerate(messages):

if not message["content"]:
continue

if "role" in message and message["role"] == "system":
# System message
if preamble == "":
preamble = message["content"]
else:
preamble = preamble + "\n" + message["content"]
elif "tool_calls" in message:

elif message.get("tool_calls"):
# Suggested tool calls, build up the list before we put it into the tool_results
for tool_call in message["tool_calls"]:
message_tool_calls = []
for tool_call in message["tool_calls"] or []:
if (not tool_call.get("function", {}).get("name")) or tool_call.get("function", {}).get(
"name"
) not in cohere_tool_names:
new_message = {
"role": "CHATBOT",
"message": message.get("name") + ":" + message["content"] + str(message["tool_calls"]),
}
cohere_messages.append(new_message)
continue

tool_calls.append(tool_call)
message_tool_calls.append(
{
"name": tool_call.get("function", {}).get("name"),
"parameters": json.loads(tool_call.get("function", {}).get("arguments") or "null"),
}
)

if not message_tool_calls:
continue

# We also add the suggested tool call as a message
new_message = {
"role": "CHATBOT",
"message": message["content"],
"tool_calls": [
{
"name": tool_call_.get("function", {}).get("name"),
"parameters": json.loads(tool_call_.get("function", {}).get("arguments") or "null"),
}
for tool_call_ in message["tool_calls"]
],
"message": message.get("name") + ":" + message["content"],
"tool_calls": message_tool_calls,
}

cohere_messages.append(new_message)
elif "role" in message and message["role"] == "tool":
if not (tool_call_id := message.get("tool_call_id")):
continue

# Convert the tool call to a result
content_output = message["content"]
if tool_call_id not in [tool_call["id"] for tool_call in tool_calls]:

new_message = {
"role": "CHATBOT",
"message": content_output,
}
cohere_messages.append(new_message)
continue

# Convert the tool call to a result
tool_results_chat_turn = extract_to_cohere_tool_results(tool_call_id, content_output, tool_calls)
if (index == messages_length - 1) or (messages[index + 1].get("role", "").lower() in ("user", "tool")):
if is_recent_tool_call(messages, index):
# If the tool call is the last message or the next message is a user/tool message, this is a recent tool call.
# So, we pass it into tool_results.
tool_results.extend(tool_results_chat_turn)
@@ -420,7 +457,7 @@ def oai_messages_to_cohere_messages(
# Standard text message
new_message = {
"role": "USER" if message["role"] == "user" else "CHATBOT",
"message": message["content"],
"message": message.get("name") + ":" + message.get("content"),
}

cohere_messages.append(new_message)
@@ -436,7 +473,7 @@ def oai_messages_to_cohere_messages(
# So, we add a CHATBOT 'continue' message, if so.
# Changed key from "content" to "message" (jaygdesai/autogen_Jay)
if cohere_messages[-1]["role"].lower() == "user":
cohere_messages.append({"role": "CHATBOT", "message": "Please continue."})
cohere_messages.append({"role": "CHATBOT", "message": "Please go ahead and follow the instructions!"})

# We return a blank message when we have tool results
# TODO: Check what happens if tool_results aren't the latest message
@@ -449,7 +486,7 @@ def oai_messages_to_cohere_messages(
if cohere_messages[-1]["role"] == "USER":
return cohere_messages[0:-1], preamble, cohere_messages[-1]["message"]
else:
return cohere_messages, preamble, "Please continue."
return cohere_messages, preamble, "Please go ahead and follow the instructions!"


def calculate_cohere_cost(input_tokens: int, output_tokens: int, model: str) -> float:
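Two behavioural changes in the hunks above are easy to miss. First, suggested tool calls whose function name was never registered in cohere_tool_names are no longer forwarded as structured tool calls; the whole suggestion is flattened into an agent-prefixed CHATBOT message instead. A self-contained sketch of that branch, using invented inputs (cohere_tool_names and message) in place of the values produced inside oai_messages_to_cohere_messages():

import json
from typing import Any, Dict

# Invented inputs purely for illustration; in the diff these come from the
# oai_messages_to_cohere_messages() conversion loop.
cohere_tool_names = {"get_weather"}
message: Dict[str, Any] = {
    "role": "assistant",
    "name": "planner",
    "content": "Checking the forecast",
    "tool_calls": [{"id": "t1", "function": {"name": "lookup_flights", "arguments": "{}"}}],
}

cohere_messages = []
message_tool_calls = []
for tool_call in message["tool_calls"] or []:
    name = tool_call.get("function", {}).get("name")
    if not name or name not in cohere_tool_names:
        # Unregistered tool: flatten the suggestion into a plain, agent-prefixed
        # CHATBOT message that embeds the raw tool-call payload as text.
        cohere_messages.append(
            {
                "role": "CHATBOT",
                "message": message.get("name") + ":" + message["content"] + str(message["tool_calls"]),
            }
        )
        continue
    # Registered tool: keep it as a structured Cohere tool call.
    message_tool_calls.append(
        {"name": name, "parameters": json.loads(tool_call.get("function", {}).get("arguments") or "null")}
    )

if message_tool_calls:
    cohere_messages.append(
        {
            "role": "CHATBOT",
            "message": message.get("name") + ":" + message["content"],
            "tool_calls": message_tool_calls,
        }
    )

print(cohere_messages)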

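Second, the inline recency check was replaced by the new is_recent_tool_call helper: tool results are only passed through the tool_results parameter when the tool call is the last message or the message after it is not a chatbot turn; older results are folded back into the chat history. The helper below is copied from the diff; the two-message history is invented purely to exercise both branches:

from typing import Any, Dict


def is_recent_tool_call(messages: list[Dict[str, Any]], tool_call_index: int):
    # Copied from the diff: a tool call is "recent" when it is the last message,
    # or when the message that follows it is not a chatbot turn.
    messages_length = len(messages)
    if tool_call_index == messages_length - 1:
        return True
    elif messages[tool_call_index + 1].get("role", "").lower() not in ("chatbot"):
        return True
    return False


# Invented two-message history, purely to exercise both branches.
history = [
    {"role": "tool", "tool_call_id": "t1", "content": "Sunny"},
    {"role": "chatbot", "content": "It is sunny today."},
]

print(is_recent_tool_call(history, 0))  # False: the next message is a chatbot turn
print(is_recent_tool_call(history, 1))  # True: it is the last message

One detail worth knowing when reading the helper: ("chatbot") is a plain string rather than a one-element tuple, so the elif performs a substring check; any follower role that is not a substring of "chatbot" is treated as a non-chatbot turn.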