From 6efa03529a478095facb6fbac0dd24dc52469caa Mon Sep 17 00:00:00 2001 From: Jean-Marc Le Roux Date: Wed, 4 Dec 2024 19:52:25 +0100 Subject: [PATCH] Add the `add_tool()`, `remove_tool()` and `remove_all_tools()` methods for `AssistantAgent` --- .../agents/_assistant_agent.py | 72 +++++++++++-------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 74a80ac7bbd3..cb712891bb21 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -231,24 +231,10 @@ def __init__( else: self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] - if tools is not None: - if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling.") - for tool in tools: - if isinstance(tool, Tool): - self._tools.append(tool) - elif callable(tool): - if hasattr(tool, "__doc__") and tool.__doc__ is not None: - description = tool.__doc__ - else: - description = "" - self._tools.append(FunctionTool(tool, description=description)) - else: - raise ValueError(f"Unsupported tool type: {type(tool)}") - # Check if tool names are unique. - tool_names = [tool.name for tool in self._tools] - if len(tool_names) != len(set(tool_names)): - raise ValueError(f"Tool names must be unique: {tool_names}") + self._model_context: List[LLMMessage] = [] + self._reflect_on_tool_use = reflect_on_tool_use + self._tool_call_summary_format = tool_call_summary_format + self._is_running = False # Handoff tools. self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, HandoffBase] = {} @@ -258,24 +244,54 @@ def __init__( for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) + if handoff.name in self._handoffs: + raise ValueError(f"Handoff name {handoff.name} already exists.") if isinstance(handoff, HandoffBase): self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: raise ValueError(f"Unsupported handoff type: {type(handoff)}") - # Check if handoff tool names are unique. - handoff_tool_names = [tool.name for tool in self._handoff_tools] - if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") + if tools is not None: + for tool in tools: + self.add_tool(tool) + + def add_tool(self, tool: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]) -> None: + new_tool = None + if self._model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") + if isinstance(tool, Tool): + new_tool = tool + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + new_tool = FunctionTool(tool, description=description) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + self._tools.append(new_tool) + # Check if tool names are unique. + if new_tool.name in self._tools: + raise ValueError(f"Tool names must be unique: {new_tool.name}") # Check if handoff tool names not in tool names. - if any(name in tool_names for name in handoff_tool_names): + handoff_tool_names = [handoff.name for handoff in self._handoffs.values()] + if new_tool.name in handoff_tool_names: raise ValueError( - f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" + f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; " + f"tool names: {new_tool.name}" ) - self._model_context: List[LLMMessage] = [] - self._reflect_on_tool_use = reflect_on_tool_use - self._tool_call_summary_format = tool_call_summary_format - self._is_running = False + + def remove_all_tools(self) -> None: + """Remove all tools.""" + self._tools = [] + + def remove_tool(self, tool_name: List[str]) -> None: + """Remove tools by name.""" + for tool in self._tools: + if tool.name == tool_name: + self._tools.remove(tool) + return + raise ValueError(f"Tool {tool_name} not found.") @property def produced_message_types(self) -> List[type[ChatMessage]]: