Skip to content

Commit

Permalink
[backend] adding tools for agents, filter tools by agent_id (cohere-a…
Browse files Browse the repository at this point in the history
…i#204)

* changes

* lint

* filter tools by agent

* lint

* update tests

* lint and squash migrations

* fix alembic migration err

* lint
  • Loading branch information
scott-cohere authored and ClaytonSmith committed Jun 13, 2024
1 parent 1bd58fe commit 53f587a
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 10 deletions.
49 changes: 49 additions & 0 deletions src/backend/alembic/versions/922e874930bf_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""empty message
Revision ID: 922e874930bf
Revises: 28763d200b29
Create Date: 2024-06-12 21:19:12.204875
"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision: str = "922e874930bf"
down_revision: Union[str, None] = "28763d200b29"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"agents",
sa.Column(
"tools",
postgresql.ARRAY(
sa.Enum(
"Wiki_Retriever_LangChain",
"Search_File",
"Read_File",
"Python_Interpreter",
"Calculator",
"Tavily_Internet_Search",
name="toolname",
native_enum=False,
)
),
nullable=False,
),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("agents", "tools")
# ### end Alembic commands ###
8 changes: 6 additions & 2 deletions src/backend/database_models/agent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from enum import StrEnum

from sqlalchemy import Enum, Float, Integer, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.orm import Mapped, mapped_column

from backend.config.tools import ToolName
from backend.database_models.base import Base


Expand All @@ -28,7 +30,9 @@ class Agent(Base):
description: Mapped[str] = mapped_column(Text, default="", nullable=False)
preamble: Mapped[str] = mapped_column(Text, default="", nullable=False)
temperature: Mapped[float] = mapped_column(Float, default=0.3, nullable=False)
# tool: Mapped[List["Tool"]] = relationship()
tools: Mapped[list[ToolName]] = mapped_column(
ARRAY(Enum(ToolName, native_enum=False)), default=[], nullable=False
)

# TODO @scott-cohere: eventually switch to Fkey when new deployment tables are implemented
# TODO @scott-cohere: deployments have different names for models, need to implement mapping later
Expand Down
Empty file.
2 changes: 1 addition & 1 deletion src/backend/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request):
user_id=user_id,
model=agent.model,
deployment=agent.deployment,
# tools=request.json().get("tools"),
tools=agent.tools,
)

return agent_crud.create_agent(session, agent_data)
Expand Down
22 changes: 20 additions & 2 deletions src/backend/routers/tool.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,37 @@
from fastapi import APIRouter
from typing import Optional

from fastapi import APIRouter, HTTPException

from backend.config.routers import RouterName
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import agent as agent_crud
from backend.database_models.database import DBSessionDep
from backend.schemas.tool import ManagedTool

router = APIRouter(prefix="/v1/tools")
router.name = RouterName.TOOL


@router.get("", response_model=list[ManagedTool])
def list_tools() -> list[ManagedTool]:
def list_tools(session: DBSessionDep, agent_id: str | None = None) -> list[ManagedTool]:
"""
List all available tools.
Returns:
list[ManagedTool]: List of available tools.
"""
if agent_id:
agent_tools = []
agent = agent_crud.get_agent(session, agent_id)

if not agent:
raise HTTPException(
status_code=404,
detail=f"Agent with ID: {agent_id} not found.",
)

for tool in agent.tools:
agent_tools.append(AVAILABLE_TOOLS[tool])
return agent_tools

return AVAILABLE_TOOLS.values()
6 changes: 4 additions & 2 deletions src/backend/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pydantic import BaseModel

from backend.config.tools import ToolName
from backend.database_models.agent import AgentDeployment, AgentModel


Expand All @@ -20,7 +21,7 @@ class Agent(AgentBase):
description: Optional[str]
preamble: Optional[str]
temperature: float
# tools: List[Tool]
tools: list[ToolName]

model: AgentModel
deployment: AgentDeployment
Expand All @@ -38,6 +39,7 @@ class CreateAgent(BaseModel):
temperature: Optional[float] = None
model: AgentModel
deployment: AgentDeployment
tools: Optional[list[ToolName]] = None

class Config:
from_attributes = True
Expand All @@ -52,7 +54,7 @@ class UpdateAgent(BaseModel):
temperature: Optional[float] = None
model: Optional[AgentModel] = None
deployment: Optional[AgentDeployment] = None
# tools: Optional[List[Tool]] = None
tools: Optional[list[ToolName]] = None

class Config:
from_attributes = True
Expand Down
10 changes: 10 additions & 0 deletions src/backend/tests/crud/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from sqlalchemy.exc import IntegrityError

from backend.config.tools import ToolName
from backend.crud import agent as agent_crud
from backend.database_models.agent import Agent, AgentDeployment, AgentModel
from backend.schemas.agent import UpdateAgent
Expand All @@ -17,6 +18,7 @@ def test_create_agent(session, user):
description="test",
preamble="test",
temperature=0.5,
tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File],
model=AgentModel.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
)
Expand All @@ -28,6 +30,7 @@ def test_create_agent(session, user):
assert agent.description == "test"
assert agent.preamble == "test"
assert agent.temperature == 0.5
assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File]
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM

Expand All @@ -38,6 +41,7 @@ def test_create_agent(session, user):
assert agent.description == "test"
assert agent.preamble == "test"
assert agent.temperature == 0.5
assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File]
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM

Expand All @@ -57,6 +61,7 @@ def test_create_agent_empty_non_required_fields(session, user):
assert agent.description == ""
assert agent.preamble == ""
assert agent.temperature == 0.3
assert agent.tools == []
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM

Expand All @@ -67,6 +72,7 @@ def test_create_agent_empty_non_required_fields(session, user):
assert agent.description == ""
assert agent.preamble == ""
assert agent.temperature == 0.3
assert agent.tools == []
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM

Expand Down Expand Up @@ -127,6 +133,7 @@ def test_create_agent_duplicate_name_version(session, user):
description="test",
preamble="test",
temperature=0.5,
tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File],
model=AgentModel.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
)
Expand Down Expand Up @@ -179,6 +186,7 @@ def test_update_agent(session, user):
preamble="test",
temperature=0.5,
user_id=user.id,
tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File],
)

new_agent_data = UpdateAgent(
Expand All @@ -187,6 +195,7 @@ def test_update_agent(session, user):
version=2,
preamble="new_test",
temperature=0.6,
tools=[ToolName.Python_Interpreter, ToolName.Calculator],
)

agent = agent_crud.update_agent(session, agent, new_agent_data)
Expand All @@ -195,6 +204,7 @@ def test_update_agent(session, user):
assert agent.version == new_agent_data.version
assert agent.preamble == new_agent_data.preamble
assert agent.temperature == new_agent_data.temperature
assert agent.tools == [ToolName.Python_Interpreter, ToolName.Calculator]


def test_delete_agent(session, user):
Expand Down
16 changes: 16 additions & 0 deletions src/backend/tests/factories/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import factory

from backend.config.tools import ToolName
from backend.database_models.agent import Agent, AgentDeployment, AgentModel

from .base import BaseFactory
Expand All @@ -17,6 +18,21 @@ class Meta:
temperature = factory.Faker("pyfloat")
created_at = factory.Faker("date_time")
updated_at = factory.Faker("date_time")
tools = factory.List(
[
factory.Faker(
"random_element",
elements=[
ToolName.Wiki_Retriever_LangChain,
ToolName.Search_File,
ToolName.Read_File,
ToolName.Python_Interpreter,
ToolName.Calculator,
ToolName.Tavily_Internet_Search,
],
)
]
)
model = factory.Faker(
"random_element",
elements=(
Expand Down
20 changes: 20 additions & 0 deletions src/backend/tests/routers/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.tools import ToolName
from backend.database_models.agent import Agent, AgentDeployment, AgentModel
from backend.tests.factories import get_factory

Expand All @@ -14,6 +15,7 @@ def test_create_agent(session_client: TestClient, session: Session) -> None:
"temperature": 0.5,
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
"tools": [ToolName.Wiki_Retriever_LangChain],
}

response = session_client.post(
Expand All @@ -29,6 +31,7 @@ def test_create_agent(session_client: TestClient, session: Session) -> None:
assert response_agent["temperature"] == request_json["temperature"]
assert response_agent["model"] == request_json["model"]
assert response_agent["deployment"] == request_json["deployment"]
assert response_agent["tools"] == request_json["tools"]

agent = session.get(Agent, response_agent["id"])
assert agent is not None
Expand All @@ -39,6 +42,7 @@ def test_create_agent(session_client: TestClient, session: Session) -> None:
assert agent.temperature == request_json["temperature"]
assert agent.model == request_json["model"]
assert agent.deployment == request_json["deployment"]
assert agent.tools == request_json["tools"]


def test_create_agent_missing_name(
Expand Down Expand Up @@ -156,6 +160,22 @@ def test_create_agent_wrong_model_deployment_enums(
assert response.status_code == 422


def test_create_agent_wrong_tool_name_enums(
session_client: TestClient, session: Session
) -> None:
request_json = {
"name": "test agent",
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
"tools": ["not a real tool"],
}

response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": "123"}
)
assert response.status_code == 422


def test_list_agents_empty(session_client: TestClient, session: Session) -> None:
response = session_client.get("/v1/agents", headers={"User-Id": "123"})
assert response.status_code == 200
Expand Down
39 changes: 36 additions & 3 deletions src/backend/tests/routers/test_tool.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.tools import AVAILABLE_TOOLS
from backend.config.tools import AVAILABLE_TOOLS, ToolName
from backend.tests.factories import get_factory


def test_list_tools(client: TestClient) -> None:
response = client.get("/v1/tools")
def test_list_tools(session_client: TestClient, session: Session) -> None:
response = session_client.get("/v1/tools")
assert response.status_code == 200
for tool in response.json():
assert tool["name"] in AVAILABLE_TOOLS.keys()
Expand All @@ -18,3 +20,34 @@ def test_list_tools(client: TestClient) -> None:
assert tool["error_message"] == tool_definition.error_message
assert tool["category"] == tool_definition.category
assert tool["description"] == tool_definition.description


def test_list_tools_with_agent(session_client: TestClient, session: Session) -> None:
agent = get_factory("Agent", session).create(
name="test agent", tools=[ToolName.Wiki_Retriever_LangChain]
)

response = session_client.get("/v1/tools", params={"agent_id": agent.id})
assert response.status_code == 200
assert len(response.json()) == 1

tool = response.json()[0]
assert tool["name"] == ToolName.Wiki_Retriever_LangChain

# get tool that has the same name as the tool in the response
tool_definition = AVAILABLE_TOOLS[tool["name"]]

assert tool["kwargs"] == tool_definition.kwargs
assert tool["is_visible"] == tool_definition.is_visible
assert tool["is_available"] == tool_definition.is_available
assert tool["error_message"] == tool_definition.error_message
assert tool["category"] == tool_definition.category
assert tool["description"] == tool_definition.description


def test_list_tools_with_agent_that_doesnt_exist(
session_client: TestClient, session: Session
) -> None:
response = session_client.get("/v1/tools", params={"agent_id": "fake_id"})
assert response.status_code == 404
assert response.json() == {"detail": "Agent with ID: fake_id not found."}

0 comments on commit 53f587a

Please sign in to comment.