Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability to specify 'role' field for select speaker messages for Group Chats (addresses #1861) #2167

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 8 additions & 24 deletions autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
if "function_call" in message and message["function_call"]:
function_call = dict(message["function_call"])
func_print = (
f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****"
f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"***** Suggested function Call: {function_call.get('name', '(No function name found)')} *****"
f"***** Suggested function call: {function_call.get('name', '(No function name found)')} *****"

)
iostream.print(colored(func_print, "green"), flush=True)
iostream.print(
Expand All @@ -728,7 +728,7 @@ def _print_received_message(self, message: Union[Dict, str], sender: Agent):
iostream.print(colored("*" * len(func_print), "green"), flush=True)
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
id = tool_call.get("id", "No tool call id found")
id = tool_call.get("id", "(No id found)")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
id = tool_call.get("id", "(No id found)")
id = tool_call.get("id", "No tool_call_id found")

function_call = dict(tool_call.get("function", {}))
func_print = f"***** Suggested tool call ({id}): {function_call.get('name', '(No function name found)')} *****"
iostream.print(colored(func_print, "green"), flush=True)
Expand Down Expand Up @@ -1311,12 +1311,6 @@ def _generate_oai_reply_from_client(self, llm_client, messages, cache) -> Union[
)
for tool_call in extracted_response.get("tool_calls") or []:
tool_call["function"]["name"] = self._normalize_name(tool_call["function"]["name"])
# Remove id and type if they are not present.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were for Mistral AI compatibility. Are these not needed now?

# This is to make the tool call object compatible with Mistral API.
if tool_call.get("id") is None:
tool_call.pop("id")
if tool_call.get("type") is None:
tool_call.pop("type")
return extracted_response

async def a_generate_oai_reply(
Expand Down Expand Up @@ -1546,6 +1540,7 @@ def generate_tool_calls_reply(
message = messages[-1]
tool_returns = []
for tool_call in message.get("tool_calls", []):
id = tool_call["id"]
function_call = tool_call.get("function", {})
func = self._function_map.get(function_call.get("name", None), None)
if inspect.iscoroutinefunction(func):
Expand All @@ -1563,24 +1558,13 @@ def generate_tool_calls_reply(
loop.close()
else:
_, func_return = self.execute_function(function_call)
content = func_return.get("content", "")
if content is None:
content = ""
tool_call_id = tool_call.get("id", None)
if tool_call_id is not None:
tool_call_response = {
"tool_call_id": tool_call_id,
tool_returns.append(
{
"tool_call_id": id,
"role": "tool",
"content": content,
"content": func_return.get("content", ""),
}
else:
# Do not include tool_call_id if it is not present.
# This is to make the tool call object compatible with Mistral API.
tool_call_response = {
"role": "tool",
"content": content,
}
tool_returns.append(tool_call_response)
)
if tool_returns:
return True, {
"role": "tool",
Expand Down
9 changes: 7 additions & 2 deletions autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def custom_speaker_selection_func(
"clear history" phrase in user prompt. This is experimental feature.
See description of GroupChatManager.clear_agents_history function for more info.
- send_introductions: send a round of introductions at the start of the group chat, so agents know who they can speak to (default: False)
- role_for_select_speaker_messages: sets the role name used for the speaker-selection messages when in 'auto' mode, typically 'user' or 'system' (default: 'system').
"""

agents: List[Agent]
Expand All @@ -74,6 +75,7 @@ def custom_speaker_selection_func(
speaker_transitions_type: Literal["allowed", "disallowed", None] = None
enable_clear_history: Optional[bool] = False
send_introductions: bool = False
role_for_select_speaker_messages: Optional[str] = "system"

_VALID_SPEAKER_SELECTION_METHODS = ["auto", "manual", "random", "round_robin"]
_VALID_SPEAKER_TRANSITIONS_TYPE = ["allowed", "disallowed", None]
Expand Down Expand Up @@ -411,7 +413,7 @@ def _prepare_and_select_agents(
selected_agent = self.next_agent(last_speaker, graph_eligible_agents)
elif speaker_selection_method.lower() == "random":
selected_agent = self.random_select_speaker(graph_eligible_agents)
else:
else: # auto
selected_agent = None
select_speaker_messages = self.messages.copy()
# If last message is a tool call or function call, blank the call so the api doesn't throw
Expand All @@ -420,7 +422,10 @@ def _prepare_and_select_agents(
if select_speaker_messages[-1].get("tool_calls", False):
select_speaker_messages[-1] = dict(select_speaker_messages[-1], tool_calls=None)
select_speaker_messages = select_speaker_messages + [
{"role": "system", "content": self.select_speaker_prompt(graph_eligible_agents)}
{
"role": self.role_for_select_speaker_messages,
"content": self.select_speaker_prompt(graph_eligible_agents),
}
]
return selected_agent, graph_eligible_agents, select_speaker_messages

Expand Down
Loading
Loading