diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py index 9be3adcdc99..45d2a283fa3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_assistant_agent.py @@ -15,6 +15,7 @@ ) from autogen_core import CancellationToken, FunctionCall +from autogen_core.components.models import LLMMessage from autogen_core.model_context import ( ChatCompletionContext, UnboundedChatCompletionContext, @@ -255,24 +256,11 @@ def __init__( else: self._system_messages = [SystemMessage(content=system_message)] self._tools: List[Tool] = [] - if tools is not None: - if model_client.capabilities["function_calling"] is False: - raise ValueError("The model does not support function calling.") - for tool in tools: - if isinstance(tool, Tool): - self._tools.append(tool) - elif callable(tool): - if hasattr(tool, "__doc__") and tool.__doc__ is not None: - description = tool.__doc__ - else: - description = "" - self._tools.append(FunctionTool(tool, description=description)) - else: - raise ValueError(f"Unsupported tool type: {type(tool)}") - # Check if tool names are unique. - tool_names = [tool.name for tool in self._tools] - if len(tool_names) != len(set(tool_names)): - raise ValueError(f"Tool names must be unique: {tool_names}") + self._model_context: List[LLMMessage] = [] + self._reflect_on_tool_use = reflect_on_tool_use + self._tool_call_summary_format = tool_call_summary_format + self._is_running = False + # Handoff tools. self._handoff_tools: List[Tool] = [] self._handoffs: Dict[str, HandoffBase] = {} @@ -282,26 +270,191 @@ def __init__( for handoff in handoffs: if isinstance(handoff, str): handoff = HandoffBase(target=handoff) + if handoff.name in self._handoffs: + raise ValueError(f"Handoff name {handoff.name} already exists.") if isinstance(handoff, HandoffBase): self._handoff_tools.append(handoff.handoff_tool) self._handoffs[handoff.name] = handoff else: raise ValueError(f"Unsupported handoff type: {type(handoff)}") - # Check if handoff tool names are unique. - handoff_tool_names = [tool.name for tool in self._handoff_tools] - if len(handoff_tool_names) != len(set(handoff_tool_names)): - raise ValueError(f"Handoff names must be unique: {handoff_tool_names}") - # Check if handoff tool names not in tool names. - if any(name in tool_names for name in handoff_tool_names): - raise ValueError( - f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; tool names: {tool_names}" - ) + if tools is not None: + for tool in tools: + self.add_tool(tool) + if not model_context: self._model_context = UnboundedChatCompletionContext() self._reflect_on_tool_use = reflect_on_tool_use self._tool_call_summary_format = tool_call_summary_format self._is_running = False + def add_tool(self, tool: Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]) -> None: + """ + Adds a new tool to the assistant agent. + + The tool can be either an instance of the `Tool` class, or a callable function. If the tool is a callable + function, a :class:`~autogen_core.tools.FunctionTool` instance will be created with the function and its docstring as the description. + + The tool name must be unique among all the tools and handoffs added to the agent. If the model does not support + function calling, an error will be raised. + + Args: + tool (Tool | Callable[..., Any] | Callable[..., Awaitable[Any]]): The tool to add. + + Raises: + ValueError: If the tool name is not unique. + ValueError: If the tool name is already used by a handoff. + ValueError: If the tool has an unsupported type. + ValueError: If the model does not support function calling. + + Examples: + .. code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + agent.add_tool(get_current_time) + + await Console( + agent.on_messages_stream( + [TextMessage(content="What is the current time?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + """ + new_tool = None + if self._model_client.capabilities["function_calling"] is False: + raise ValueError("The model does not support function calling.") + if isinstance(tool, Tool): + new_tool = tool + elif callable(tool): + if hasattr(tool, "__doc__") and tool.__doc__ is not None: + description = tool.__doc__ + else: + description = "" + new_tool = FunctionTool(tool, description=description) + else: + raise ValueError(f"Unsupported tool type: {type(tool)}") + # Check if tool names are unique. + if any(tool.name == new_tool.name for tool in self._tools): + raise ValueError(f"Tool names must be unique: {new_tool.name}") + # Check if handoff tool names not in tool names. + handoff_tool_names = [handoff.name for handoff in self._handoffs.values()] + if new_tool.name in handoff_tool_names: + raise ValueError( + f"Handoff names must be unique from tool names. Handoff names: {handoff_tool_names}; " + f"tool names: {new_tool.name}" + ) + self._tools.append(new_tool) + + def remove_all_tools(self) -> None: + """ + Remove all tools. + + Examples: + .. code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + agent.add_tool(get_current_time) + agent.remove_all_tools() + + await Console( + agent.on_messages_stream( + [TextMessage(content="What is the current time?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + + """ + self._tools.clear() + + def remove_tool(self, tool_name: str) -> None: + """ + Remove a tool by name. + + Args: + tool_name (str): The name of the tool to remove. + + Raises: + ValueError: If the tool name is not found. + + Examples: + .. code-block:: python + + import asyncio + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.messages import TextMessage + from autogen_agentchat.ui import Console + from autogen_core import CancellationToken + + + async def get_current_time() -> str: + return "The current time is 12:00 PM." + + + async def main() -> None: + model_client = OpenAIChatCompletionClient( + model="gpt-4o", + # api_key = "your_openai_api_key" + ) + agent = AssistantAgent(name="assistant", model_client=model_client) + + agent.add_tool(get_current_time) + agent.remove_tool("get_current_time") + + await Console( + agent.on_messages_stream( + [TextMessage(content="What is the current time?", source="user")], CancellationToken() + ) + ) + + + asyncio.run(main()) + """ + for tool in self._tools: + if tool.name == tool_name: + self._tools.remove(tool) + return + raise ValueError(f"Tool {tool_name} not found.") + @property def produced_message_types(self) -> Tuple[type[ChatMessage], ...]: """The types of messages that the assistant agent produces.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 9065d513918..bd33702a6f4 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -468,3 +468,78 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: else: assert message == result.messages[index] index += 1 + + +def test_tool_management(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test function to be used as a tool + def sample_tool() -> str: + return "sample result" + + # Test adding a tool + tool = FunctionTool(sample_tool, description="Sample tool") + agent.add_tool(tool) + assert len(agent._tools) == 1 + + # Test adding duplicate tool + with pytest.raises(ValueError, match="Tool names must be unique"): + agent.add_tool(tool) + + # Test tool collision with handoff + agent_with_handoff = AssistantAgent( + name="test_assistant", model_client=model_client, handoffs=[Handoff(target="other_agent")] + ) + + conflicting_tool = FunctionTool(sample_tool, name="transfer_to_other_agent", description="Sample tool") + with pytest.raises(ValueError, match="Handoff names must be unique from tool names"): + agent_with_handoff.add_tool(conflicting_tool) + + # Test removing a tool + agent.remove_tool(tool.name) + assert len(agent._tools) == 0 + + # Test removing non-existent tool + with pytest.raises(ValueError, match="Tool non_existent_tool not found"): + agent.remove_tool("non_existent_tool") + + # Test removing all tools + agent.add_tool(tool) + assert len(agent._tools) == 1 + agent.remove_all_tools() + assert len(agent._tools) == 0 + + # Test idempotency of remove_all_tools + agent.remove_all_tools() + assert len(agent._tools) == 0 + + +def test_callable_tool_addition(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test adding a callable directly + def documented_tool() -> str: + """This is a documented tool""" + return "result" + + agent.add_tool(documented_tool) + assert len(agent._tools) == 1 + assert agent._tools[0].description == "This is a documented tool" + + # Test adding async callable + async def async_tool() -> str: + return "async result" + + agent.add_tool(async_tool) + assert len(agent._tools) == 2 + + +def test_invalid_tool_addition(): + model_client = OpenAIChatCompletionClient(model="gpt-4", api_key="") + agent = AssistantAgent(name="test_assistant", model_client=model_client) + + # Test adding invalid tool type + with pytest.raises(ValueError, match="Unsupported tool type"): + agent.add_tool("not a tool")