diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index de2087382254..099b4bdb6082 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -182,24 +182,7 @@ def __init__( else: self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] - if tools is not None: - if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling.") - for tool in tools: - if isinstance(tool, Tool): - self._tools.append(tool) - elif callable(tool): - if hasattr(tool, "__doc__") and tool.__doc__ is not None: - description = tool.__doc__ - else: - description = "" - self._tools.append(FunctionTool(tool, description=description)) - else: - raise ValueError(f"Unsupported tool type: {type(tool)}") - # Check if tool names are unique. - tool_names = [tool.name for tool in self._tools] - if len(tool_names) != len(set(tool_names)): - raise ValueError(f"Tool names must be unique: {tool_names}") + self._model_context: List[LLMMessage] = [] # Handoff tools. self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, HandoffBase] = {} @@ -214,6 +197,27 @@ def __init__( self._handoffs[handoff.name] = handoff else: raise ValueError(f"Unsupported handoff type: {type(handoff)}") + if tools is not None: + self.add_tools(tools) + + def add_tools(self, tools: List[Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]]) -> None: + if self._model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") + for tool in tools: + if isinstance(tool, Tool): + self._tools.append(tool) + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + self._tools.append(FunctionTool(tool, description=description)) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + # Check if tool names are unique. + tool_names = [tool.name for tool in self._tools] + if len(tool_names) != len(set(tool_names)): + raise ValueError(f"Tool names must be unique: {tool_names}") # Check if handoff tool names are unique. handoff_tool_names = [tool.name for tool in self._handoff_tools] if len(handoff_tool_names) != len(set(handoff_tool_names)): @@ -223,7 +227,26 @@ def __init__( raise ValueError( f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" ) - self._model_context: List[LLMMessage] = [] + + def remove_all_tools(self) -> None: + """Remove all tools.""" + self._tools = [] + + def remove_tools(self, tool_names: List[str]) -> None: + """Remove tools by name.""" + for name in tool_names: + for tool in self._tools: + if tool.name == name: + self._tools.remove(tool) + break + for tool in self._handoff_tools: + if tool.name == name: + self._handoff_tools.remove(tool) + break + for handoff in self._handoffs.values(): + if handoff.name == name: + self._handoffs.pop(handoff.name) + break @property def produced_message_types(self) -> List[type[ChatMessage]]: