From 53f587addca4f5a9b6fbe28f62d9857924e0bd06 Mon Sep 17 00:00:00 2001 From: Scott <146760070+scott-cohere@users.noreply.github.com> Date: Thu, 13 Jun 2024 12:59:29 -0400 Subject: [PATCH] [backend] adding tools for agents, filter tools by agent_id (#204) * changes * lint * filter tools by agent * lint * update tests * lint and squash migrations * fix alembic migration err * lint --- src/backend/alembic/versions/922e874930bf_.py | 49 +++++++++++++++++++ src/backend/database_models/agent.py | 8 ++- src/backend/database_models/tool.py | 0 src/backend/routers/agent.py | 2 +- src/backend/routers/tool.py | 22 ++++++++- src/backend/schemas/agent.py | 6 ++- src/backend/tests/crud/test_agent.py | 10 ++++ src/backend/tests/factories/agent.py | 16 ++++++ src/backend/tests/routers/test_agent.py | 20 ++++++++ src/backend/tests/routers/test_tool.py | 39 +++++++++++++-- 10 files changed, 162 insertions(+), 10 deletions(-) create mode 100644 src/backend/alembic/versions/922e874930bf_.py create mode 100644 src/backend/database_models/tool.py diff --git a/src/backend/alembic/versions/922e874930bf_.py b/src/backend/alembic/versions/922e874930bf_.py new file mode 100644 index 0000000000..971f5e2a85 --- /dev/null +++ b/src/backend/alembic/versions/922e874930bf_.py @@ -0,0 +1,49 @@ +"""empty message + +Revision ID: 922e874930bf +Revises: 28763d200b29 +Create Date: 2024-06-12 21:19:12.204875 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "922e874930bf" +down_revision: Union[str, None] = "28763d200b29" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "agents", + sa.Column( + "tools", + postgresql.ARRAY( + sa.Enum( + "Wiki_Retriever_LangChain", + "Search_File", + "Read_File", + "Python_Interpreter", + "Calculator", + "Tavily_Internet_Search", + name="toolname", + native_enum=False, + ) + ), + nullable=False, + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("agents", "tools") + # ### end Alembic commands ### diff --git a/src/backend/database_models/agent.py b/src/backend/database_models/agent.py index 5097c46ed4..e2561e0aba 100644 --- a/src/backend/database_models/agent.py +++ b/src/backend/database_models/agent.py @@ -1,8 +1,10 @@ from enum import StrEnum from sqlalchemy import Enum, Float, Integer, String, Text, UniqueConstraint -from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.dialects.postgresql import ARRAY +from sqlalchemy.orm import Mapped, mapped_column +from backend.config.tools import ToolName from backend.database_models.base import Base @@ -28,7 +30,9 @@ class Agent(Base): description: Mapped[str] = mapped_column(Text, default="", nullable=False) preamble: Mapped[str] = mapped_column(Text, default="", nullable=False) temperature: Mapped[float] = mapped_column(Float, default=0.3, nullable=False) - # tool: Mapped[List["Tool"]] = relationship() + tools: Mapped[list[ToolName]] = mapped_column( + ARRAY(Enum(ToolName, native_enum=False)), default=[], nullable=False + ) # TODO @scott-cohere: eventually switch to Fkey when new deployment tables are implemented # TODO @scott-cohere: deployments have different names for models, need to implement mapping later diff --git a/src/backend/database_models/tool.py b/src/backend/database_models/tool.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py index bed3c14936..ceae9aa655 100644 --- a/src/backend/routers/agent.py +++ b/src/backend/routers/agent.py @@ -26,7 +26,7 @@ def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request): user_id=user_id, model=agent.model, deployment=agent.deployment, - # tools=request.json().get("tools"), + tools=agent.tools, ) return agent_crud.create_agent(session, agent_data) diff --git a/src/backend/routers/tool.py b/src/backend/routers/tool.py index 74f0814d7d..ad9d40a97a 100644 --- a/src/backend/routers/tool.py +++ b/src/backend/routers/tool.py @@ -1,7 +1,11 @@ -from fastapi import APIRouter +from typing import Optional + +from fastapi import APIRouter, HTTPException from backend.config.routers import RouterName from backend.config.tools import AVAILABLE_TOOLS +from backend.crud import agent as agent_crud +from backend.database_models.database import DBSessionDep from backend.schemas.tool import ManagedTool router = APIRouter(prefix="/v1/tools") @@ -9,11 +13,25 @@ @router.get("", response_model=list[ManagedTool]) -def list_tools() -> list[ManagedTool]: +def list_tools(session: DBSessionDep, agent_id: str | None = None) -> list[ManagedTool]: """ List all available tools. Returns: list[ManagedTool]: List of available tools. """ + if agent_id: + agent_tools = [] + agent = agent_crud.get_agent(session, agent_id) + + if not agent: + raise HTTPException( + status_code=404, + detail=f"Agent with ID: {agent_id} not found.", + ) + + for tool in agent.tools: + agent_tools.append(AVAILABLE_TOOLS[tool]) + return agent_tools + return AVAILABLE_TOOLS.values() diff --git a/src/backend/schemas/agent.py b/src/backend/schemas/agent.py index a9b5050459..d7289f6d9e 100644 --- a/src/backend/schemas/agent.py +++ b/src/backend/schemas/agent.py @@ -3,6 +3,7 @@ from pydantic import BaseModel +from backend.config.tools import ToolName from backend.database_models.agent import AgentDeployment, AgentModel @@ -20,7 +21,7 @@ class Agent(AgentBase): description: Optional[str] preamble: Optional[str] temperature: float - # tools: List[Tool] + tools: list[ToolName] model: AgentModel deployment: AgentDeployment @@ -38,6 +39,7 @@ class CreateAgent(BaseModel): temperature: Optional[float] = None model: AgentModel deployment: AgentDeployment + tools: Optional[list[ToolName]] = None class Config: from_attributes = True @@ -52,7 +54,7 @@ class UpdateAgent(BaseModel): temperature: Optional[float] = None model: Optional[AgentModel] = None deployment: Optional[AgentDeployment] = None - # tools: Optional[List[Tool]] = None + tools: Optional[list[ToolName]] = None class Config: from_attributes = True diff --git a/src/backend/tests/crud/test_agent.py b/src/backend/tests/crud/test_agent.py index ce91729d29..989477c01f 100644 --- a/src/backend/tests/crud/test_agent.py +++ b/src/backend/tests/crud/test_agent.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy.exc import IntegrityError +from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.database_models.agent import Agent, AgentDeployment, AgentModel from backend.schemas.agent import UpdateAgent @@ -17,6 +18,7 @@ def test_create_agent(session, user): description="test", preamble="test", temperature=0.5, + tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], model=AgentModel.COMMAND_R_PLUS, deployment=AgentDeployment.COHERE_PLATFORM, ) @@ -28,6 +30,7 @@ def test_create_agent(session, user): assert agent.description == "test" assert agent.preamble == "test" assert agent.temperature == 0.5 + assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File] assert agent.model == AgentModel.COMMAND_R_PLUS assert agent.deployment == AgentDeployment.COHERE_PLATFORM @@ -38,6 +41,7 @@ def test_create_agent(session, user): assert agent.description == "test" assert agent.preamble == "test" assert agent.temperature == 0.5 + assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File] assert agent.model == AgentModel.COMMAND_R_PLUS assert agent.deployment == AgentDeployment.COHERE_PLATFORM @@ -57,6 +61,7 @@ def test_create_agent_empty_non_required_fields(session, user): assert agent.description == "" assert agent.preamble == "" assert agent.temperature == 0.3 + assert agent.tools == [] assert agent.model == AgentModel.COMMAND_R_PLUS assert agent.deployment == AgentDeployment.COHERE_PLATFORM @@ -67,6 +72,7 @@ def test_create_agent_empty_non_required_fields(session, user): assert agent.description == "" assert agent.preamble == "" assert agent.temperature == 0.3 + assert agent.tools == [] assert agent.model == AgentModel.COMMAND_R_PLUS assert agent.deployment == AgentDeployment.COHERE_PLATFORM @@ -127,6 +133,7 @@ def test_create_agent_duplicate_name_version(session, user): description="test", preamble="test", temperature=0.5, + tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], model=AgentModel.COMMAND_R_PLUS, deployment=AgentDeployment.COHERE_PLATFORM, ) @@ -179,6 +186,7 @@ def test_update_agent(session, user): preamble="test", temperature=0.5, user_id=user.id, + tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], ) new_agent_data = UpdateAgent( @@ -187,6 +195,7 @@ def test_update_agent(session, user): version=2, preamble="new_test", temperature=0.6, + tools=[ToolName.Python_Interpreter, ToolName.Calculator], ) agent = agent_crud.update_agent(session, agent, new_agent_data) @@ -195,6 +204,7 @@ def test_update_agent(session, user): assert agent.version == new_agent_data.version assert agent.preamble == new_agent_data.preamble assert agent.temperature == new_agent_data.temperature + assert agent.tools == [ToolName.Python_Interpreter, ToolName.Calculator] def test_delete_agent(session, user): diff --git a/src/backend/tests/factories/agent.py b/src/backend/tests/factories/agent.py index 6122471ae5..d86b91d2f1 100644 --- a/src/backend/tests/factories/agent.py +++ b/src/backend/tests/factories/agent.py @@ -1,5 +1,6 @@ import factory +from backend.config.tools import ToolName from backend.database_models.agent import Agent, AgentDeployment, AgentModel from .base import BaseFactory @@ -17,6 +18,21 @@ class Meta: temperature = factory.Faker("pyfloat") created_at = factory.Faker("date_time") updated_at = factory.Faker("date_time") + tools = factory.List( + [ + factory.Faker( + "random_element", + elements=[ + ToolName.Wiki_Retriever_LangChain, + ToolName.Search_File, + ToolName.Read_File, + ToolName.Python_Interpreter, + ToolName.Calculator, + ToolName.Tavily_Internet_Search, + ], + ) + ] + ) model = factory.Faker( "random_element", elements=( diff --git a/src/backend/tests/routers/test_agent.py b/src/backend/tests/routers/test_agent.py index 9bbdb3ca38..41ab7f16b3 100644 --- a/src/backend/tests/routers/test_agent.py +++ b/src/backend/tests/routers/test_agent.py @@ -1,6 +1,7 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session +from backend.config.tools import ToolName from backend.database_models.agent import Agent, AgentDeployment, AgentModel from backend.tests.factories import get_factory @@ -14,6 +15,7 @@ def test_create_agent(session_client: TestClient, session: Session) -> None: "temperature": 0.5, "model": AgentModel.COMMAND_R, "deployment": AgentDeployment.COHERE_PLATFORM, + "tools": [ToolName.Wiki_Retriever_LangChain], } response = session_client.post( @@ -29,6 +31,7 @@ def test_create_agent(session_client: TestClient, session: Session) -> None: assert response_agent["temperature"] == request_json["temperature"] assert response_agent["model"] == request_json["model"] assert response_agent["deployment"] == request_json["deployment"] + assert response_agent["tools"] == request_json["tools"] agent = session.get(Agent, response_agent["id"]) assert agent is not None @@ -39,6 +42,7 @@ def test_create_agent(session_client: TestClient, session: Session) -> None: assert agent.temperature == request_json["temperature"] assert agent.model == request_json["model"] assert agent.deployment == request_json["deployment"] + assert agent.tools == request_json["tools"] def test_create_agent_missing_name( @@ -156,6 +160,22 @@ def test_create_agent_wrong_model_deployment_enums( assert response.status_code == 422 +def test_create_agent_wrong_tool_name_enums( + session_client: TestClient, session: Session +) -> None: + request_json = { + "name": "test agent", + "model": AgentModel.COMMAND_R, + "deployment": AgentDeployment.COHERE_PLATFORM, + "tools": ["not a real tool"], + } + + response = session_client.post( + "/v1/agents", json=request_json, headers={"User-Id": "123"} + ) + assert response.status_code == 422 + + def test_list_agents_empty(session_client: TestClient, session: Session) -> None: response = session_client.get("/v1/agents", headers={"User-Id": "123"}) assert response.status_code == 200 diff --git a/src/backend/tests/routers/test_tool.py b/src/backend/tests/routers/test_tool.py index 86e6b4eef1..f08582d710 100644 --- a/src/backend/tests/routers/test_tool.py +++ b/src/backend/tests/routers/test_tool.py @@ -1,10 +1,12 @@ from fastapi.testclient import TestClient +from sqlalchemy.orm import Session -from backend.config.tools import AVAILABLE_TOOLS +from backend.config.tools import AVAILABLE_TOOLS, ToolName +from backend.tests.factories import get_factory -def test_list_tools(client: TestClient) -> None: - response = client.get("/v1/tools") +def test_list_tools(session_client: TestClient, session: Session) -> None: + response = session_client.get("/v1/tools") assert response.status_code == 200 for tool in response.json(): assert tool["name"] in AVAILABLE_TOOLS.keys() @@ -18,3 +20,34 @@ def test_list_tools(client: TestClient) -> None: assert tool["error_message"] == tool_definition.error_message assert tool["category"] == tool_definition.category assert tool["description"] == tool_definition.description + + +def test_list_tools_with_agent(session_client: TestClient, session: Session) -> None: + agent = get_factory("Agent", session).create( + name="test agent", tools=[ToolName.Wiki_Retriever_LangChain] + ) + + response = session_client.get("/v1/tools", params={"agent_id": agent.id}) + assert response.status_code == 200 + assert len(response.json()) == 1 + + tool = response.json()[0] + assert tool["name"] == ToolName.Wiki_Retriever_LangChain + + # get tool that has the same name as the tool in the response + tool_definition = AVAILABLE_TOOLS[tool["name"]] + + assert tool["kwargs"] == tool_definition.kwargs + assert tool["is_visible"] == tool_definition.is_visible + assert tool["is_available"] == tool_definition.is_available + assert tool["error_message"] == tool_definition.error_message + assert tool["category"] == tool_definition.category + assert tool["description"] == tool_definition.description + + +def test_list_tools_with_agent_that_doesnt_exist( + session_client: TestClient, session: Session +) -> None: + response = session_client.get("/v1/tools", params={"agent_id": "fake_id"}) + assert response.status_code == 404 + assert response.json() == {"detail": "Agent with ID: fake_id not found."}