diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 74a80ac7bbd3..ebaadee764c2 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -231,24 +231,10 @@ def __init__( else: self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] - if tools is not None: - if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling.") - for tool in tools: - if isinstance(tool, Tool): - self._tools.append(tool) - elif callable(tool): - if hasattr(tool, "__doc__") and tool.__doc__ is not None: - description = tool.__doc__ - else: - description = "" - self._tools.append(FunctionTool(tool, description=description)) - else: - raise ValueError(f"Unsupported tool type: {type(tool)}") - # Check if tool names are unique. - tool_names = [tool.name for tool in self._tools] - if len(tool_names) != len(set(tool_names)): - raise ValueError(f"Tool names must be unique: {tool_names}") + self._model_context: List[LLMMessage] = [] + self._reflect_on_tool_use = reflect_on_tool_use + self._tool_call_summary_format = tool_call_summary_format + self._is_running = False # Handoff tools. self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, HandoffBase] = {} @@ -258,24 +244,54 @@ def __init__( for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) + if handoff.name in self._handoffs: + raise ValueError(f"Handoff name {handoff.name} already exists.") if isinstance(handoff, HandoffBase): self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: raise ValueError(f"Unsupported handoff type: {type(handoff)}") - # Check if handoff tool names are unique. - handoff_tool_names = [tool.name for tool in self._handoff_tools] - if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") + if tools is not None: + for tool in tools: + self.add_tool(tool) + + def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None: + new_tool = None + if self._model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") + if isinstance(tool, Tool): + new_tool = tool + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + new_tool = FunctionTool(tool, description=description) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + # Check if tool names are unique. + if any(tool.name == new_tool.name for tool in self._tools): + raise ValueError(f"Tool names must be unique: {new_tool.name}") # Check if handoff tool names not in tool names. - if any(name in tool_names for name in handoff_tool_names): + handoff_tool_names = [handoff.name for handoff in self._handoffs.values()] + if new_tool.name in handoff_tool_names: raise ValueError( - f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" + f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; " + f"tool names: {new_tool.name}" ) - self._model_context: List[LLMMessage] = [] - self._reflect_on_tool_use = reflect_on_tool_use - self._tool_call_summary_format = tool_call_summary_format - self._is_running = False + self._tools.append(new_tool) + + def remove_all_tools(self) -> None: + """Remove all tools.""" + self._tools.clear() + + def remove_tool(self, tool_name: str) -> None: + """Remove tools by name.""" + for tool in self._tools: + if tool.name == tool_name: + self._tools.remove(tool) + return + raise ValueError(f"Tool {tool_name} not found.") @property def produced_message_types(self) -> List[type[ChatMessage]]: diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index c132e3a4862c..1a560a1d62de 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -467,3 +467,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: else: assert message == result.messages[index] index += 1 + + +def test_tool_management(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test function to be used as a tool + def sample_tool() -> str: + return "sample result" + + # Test adding a tool + tool = FunctionTool(sample_tool, description="Sample tool") + agent.add_tool(tool) + assert len(agent._tools) == 1 + + # Test adding duplicate tool + with pytest.raises(ValueError, match="Tool names must be unique"): + agent.add_tool(tool) + + # Test tool collision with handoff + agent_with_handoff = AssistantAgent( + name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")] + ) + + conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool") + with pytest.raises(ValueError, match="Handoff names must be unique from tool names"): + agent_with_handoff.add_tool(conflicting_tool) + + # Test removing a tool + agent.remove_tool(tool.name) + assert len(agent._tools) == 0 + + # Test removing non-existent tool + with pytest.raises(ValueError, match="Tool non_existent_tool not found"): + agent.remove_tool("non_existent_tool") + + # Test removing all tools + agent.add_tool(tool) + assert len(agent._tools) == 1 + agent.remove_all_tools() + assert len(agent._tools) == 0 + + # Test idempotency of remove_all_tools + agent.remove_all_tools() + assert len(agent._tools) == 0 + + +def test_callable_tool_addition(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test adding a callable directly + def documented_tool() -> str: + """This is a documented tool""" + return "result" + + agent.add_tool(documented_tool) + assert len(agent._tools) == 1 + assert agent._tools[0].description == "This is a documented tool" + + # Test adding async callable + async def async_tool() -> str: + return "async result" + + agent.add_tool(async_tool) + assert len(agent._tools) == 2 + + +def test_invalid_tool_addition(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test adding invalid tool type + with pytest.raises(ValueError, match="Unsupported tool type"): + agent.add_tool("not a tool")