From 61cc85a02b2271571b5939b292cb9fe686e4cc81 Mon Sep 17 00:00:00 2001
From: Scott <146760070+scott-cohere@users.noreply.github.com>
Date: Mon, 17 Jun 2024 10:34:41 -0400
Subject: [PATCH 1/4] [backend] remove agent request enums, validate tools,
 model, deployment for agents (#214)

* fix
* moar fix
* Saving
* making deployment required again.. lol
* deployment schema error now??
* done
* lint. gen client
---
 src/backend/alembic/versions/3f207ae41477_.py  |  71 ++++++
 src/backend/database_models/agent.py           |  31 +--
 src/backend/routers/agent.py                   |  24 +-
 src/backend/schemas/agent.py                   |  21 +-
 src/backend/services/request_validators.py     |  89 ++++++++
 src/backend/tests/crud/test_agent.py           |  40 ++--
 src/backend/tests/factories/agent.py           |  23 +-
 src/backend/tests/routers/test_agent.py        | 209 +++++++++++++++---
 .../src/cohere-client/generated/index.ts       |   3 -
 .../cohere-client/generated/models/Agent.ts    |  13 +-
 .../generated/models/AgentDeployment.ts        |  10 -
 .../generated/models/AgentModel.ts             |  10 -
 .../generated/models/CreateAgent.ts            |  13 +-
 .../generated/models/ToolName.ts               |  12 -
 .../generated/models/UpdateAgent.ts            |  13 +-
 15 files changed, 403 insertions(+), 179 deletions(-)
 create mode 100644 src/backend/alembic/versions/3f207ae41477_.py
 delete mode 100644 src/interfaces/coral_web/src/cohere-client/generated/models/AgentDeployment.ts
 delete mode 100644 src/interfaces/coral_web/src/cohere-client/generated/models/AgentModel.ts
 delete mode 100644 src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts

diff --git a/src/backend/alembic/versions/3f207ae41477_.py b/src/backend/alembic/versions/3f207ae41477_.py
new file mode 100644
index 0000000000..741724bf89
--- /dev/null
+++ b/src/backend/alembic/versions/3f207ae41477_.py
@@ -0,0 +1,71 @@
+"""empty message
+
+Revision ID: 3f207ae41477
+Revises: 922e874930bf
+Create Date: 2024-06-15 23:02:22.350756
+
+"""
+
+from typing import Sequence, Union
+
+import sqlalchemy as sa
+from alembic import op
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision: str = "3f207ae41477"
+down_revision: Union[str, None] = "922e874930bf"
+branch_labels: Union[str, Sequence[str], None] = None
+depends_on: Union[str, Sequence[str], None] = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.alter_column(
+        "agents",
+        "tools",
+        existing_type=postgresql.ARRAY(sa.VARCHAR(length=24)),
+        type_=postgresql.ARRAY(sa.Text()),
+        existing_nullable=False,
+    )
+    op.alter_column(
+        "agents",
+        "model",
+        existing_type=sa.VARCHAR(length=14),
+        type_=sa.Text(),
+        existing_nullable=False,
+    )
+    op.alter_column(
+        "agents",
+        "deployment",
+        existing_type=sa.VARCHAR(length=15),
+        type_=sa.Text(),
+        existing_nullable=False,
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.alter_column(
+        "agents",
+        "deployment",
+        existing_type=sa.Text(),
+        type_=sa.VARCHAR(length=15),
+        existing_nullable=False,
+    )
+    op.alter_column(
+        "agents",
+        "model",
+        existing_type=sa.Text(),
+        type_=sa.VARCHAR(length=14),
+        existing_nullable=False,
+    )
+    op.alter_column(
+        "agents",
+        "tools",
+        existing_type=postgresql.ARRAY(sa.Text()),
+        type_=postgresql.ARRAY(sa.VARCHAR(length=24)),
+        existing_nullable=False,
+    )
+    # ### end Alembic commands ###
diff --git a/src/backend/database_models/agent.py b/src/backend/database_models/agent.py
index e842e861c8..4ead7789f3 100644
--- a/src/backend/database_models/agent.py
+++ b/src/backend/database_models/agent.py
@@ -1,27 +1,12 @@
 from enum import StrEnum
 
-from sqlalchemy import Enum, Float, Integer, String, Text, UniqueConstraint
+from sqlalchemy import Enum, Float, Integer, Text, UniqueConstraint
 from sqlalchemy.dialects.postgresql import ARRAY
 from sqlalchemy.orm import Mapped, mapped_column
 
-from backend.config.tools import ToolName
 from backend.database_models.base import Base
 
 
-class AgentDeployment(StrEnum):
-    COHERE_PLATFORM = "Cohere Platform"
-    SAGE_MAKER = "SageMaker"
-    AZURE = "Azure"
-    BEDROCK = "Bedrock"
-
-
-class AgentModel(StrEnum):
-    COMMAND_R = "command-r"
-    COMMAND_R_PLUS = "command-r-plus"
-    COMMAND_LIGHT = "command-light"
-    COMMAND = "command"
-
-
 class Agent(Base):
     __tablename__ = "agents"
@@ -30,19 +15,15 @@ class Agent(Base):
     description: Mapped[str] = mapped_column(Text, default="", nullable=False)
     preamble: Mapped[str] = mapped_column(Text, default="", nullable=False)
     temperature: Mapped[float] = mapped_column(Float, default=0.3, nullable=False)
-    tools: Mapped[list[ToolName]] = mapped_column(
-        ARRAY(Enum(ToolName, native_enum=False)), default=[], nullable=False
-    )
+    tools: Mapped[list[str]] = mapped_column(ARRAY(Text), default=[], nullable=False)
 
     # TODO @scott-cohere: eventually switch to Fkey when new deployment tables are implemented
     # TODO @scott-cohere: deployments have different names for models, need to implement mapping later
     # enum place holders
-    model: Mapped[AgentModel] = mapped_column(
-        Enum(AgentModel, native_enum=False), nullable=False
-    )
-    deployment: Mapped[AgentDeployment] = mapped_column(
-        Enum(AgentDeployment, native_enum=False),
-        default=AgentDeployment.COHERE_PLATFORM,
+    model: Mapped[str] = mapped_column(Text, nullable=False)
+    # This is not used for now, just default it to Cohere Platform
+    deployment: Mapped[str] = mapped_column(
+        Text,
         nullable=False,
     )
diff --git a/src/backend/routers/agent.py b/src/backend/routers/agent.py
index ceae9aa655..d2e2c6c253 100644
--- a/src/backend/routers/agent.py
+++ b/src/backend/routers/agent.py
@@ -6,7 +6,11 @@
 from backend.database_models.database import DBSessionDep
 from backend.schemas.agent import Agent, CreateAgent, DeleteAgent, UpdateAgent
 from backend.services.auth.utils import get_header_user_id
-from backend.services.request_validators import validate_user_header
+from backend.services.request_validators import (
+    validate_create_agent_request,
+    validate_update_agent_request,
+    validate_user_header,
+)
 
 router = APIRouter(
     prefix="/v1/agents",
@@ -14,7 +18,14 @@
 router.name = RouterName.AGENT
 
 
-@router.post("", response_model=Agent, dependencies=[Depends(validate_user_header)])
+@router.post(
+    "",
+    response_model=Agent,
+    dependencies=[
+        Depends(validate_user_header),
+        Depends(validate_create_agent_request),
+    ],
+)
 def create_agent(session: DBSessionDep, agent: CreateAgent, request: Request):
     user_id = get_header_user_id(request)
@@ -75,7 +86,14 @@ async def get_agent(agent_id: str, session: DBSessionDep, request: Request) -> A
     return agent
 
 
-@router.put("/{agent_id}", response_model=Agent)
+@router.put(
+    "/{agent_id}",
+    response_model=Agent,
+    dependencies=[
+        Depends(validate_user_header),
+        Depends(validate_update_agent_request),
+    ],
+)
 async def update_agent(
     agent_id: str,
     new_agent: UpdateAgent,
diff --git a/src/backend/schemas/agent.py b/src/backend/schemas/agent.py
index ad4827393f..6d100bf2a7 100644
--- a/src/backend/schemas/agent.py
+++ b/src/backend/schemas/agent.py
@@ -3,9 +3,6 @@
 from pydantic import BaseModel
 
-from backend.config.tools import ToolName
-from backend.database_models.agent import AgentDeployment, AgentModel
 
 
 class AgentBase(BaseModel):
     user_id: str
@@ -21,10 +18,10 @@ class Agent(AgentBase):
     description: Optional[str]
     preamble: Optional[str]
     temperature: float
-    tools: list[ToolName]
+    tools: list[str]
 
-    model: AgentModel
-    deployment: AgentDeployment
+    model: str
+    deployment: str
 
     class Config:
         from_attributes = True
@@ -37,9 +34,9 @@ class CreateAgent(BaseModel):
     description: Optional[str] = None
     preamble: Optional[str] = None
     temperature: Optional[float] = None
-    model: AgentModel
-    deployment: Optional[AgentDeployment] = None
-    tools: Optional[list[ToolName]] = None
+    model: str
+    deployment: str
+    tools: Optional[list[str]] = None
 
     class Config:
         from_attributes = True
@@ -52,9 +49,9 @@ class UpdateAgent(BaseModel):
     description: Optional[str] = None
     preamble: Optional[str] = None
     temperature: Optional[float] = None
-    model: Optional[AgentModel] = None
-    deployment: Optional[AgentDeployment] = None
-    tools: Optional[list[ToolName]] = None
+    model: Optional[str] = None
+    deployment: Optional[str] = None
+    tools: Optional[list[str]] = None
 
     class Config:
         from_attributes = True
diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py
index fa361b09fc..eb942118f0 100644
--- a/src/backend/services/request_validators.py
+++ b/src/backend/services/request_validators.py
@@ -110,3 +110,92 @@ async def validate_env_vars(request: Request):
             + ",".join(invalid_keys)
         ),
     )
+
+
+async def validate_create_agent_request(request: Request):
+    """
+    Validate that the create agent request has valid tools, deployments, and compatible models.
+
+    Args:
+        request (Request): The request to validate
+
+    Raises:
+        HTTPException: If the request does not have the appropriate values in the body
+    """
+    body = await request.json()
+
+    # Validate tools
+    tools = body.get("tools")
+    if tools:
+        for tool in tools:
+            if tool not in AVAILABLE_TOOLS:
+                raise HTTPException(status_code=400, detail=f"Tool {tool} not found.")
+
+    name = body.get("name")
+    model = body.get("model")
+    deployment = body.get("deployment")
+    if not name or not model or not deployment:
+        raise HTTPException(
+            status_code=400, detail="Name, model, and deployment are required."
+        )
+
+    # Validate deployment
+    if deployment not in AVAILABLE_MODEL_DEPLOYMENTS.keys():
+        raise HTTPException(
+            status_code=400,
+            detail=f"Deployment {deployment} not found or is not available.",
+        )
+
+    # Validate model
+    if model not in AVAILABLE_MODEL_DEPLOYMENTS[deployment].models:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Model {model} not found for deployment {deployment}.",
+        )
+
+
+async def validate_update_agent_request(request: Request):
+    """
+    Validate that the update agent request has valid tools, deployments, and compatible models.
+
+    Args:
+        request (Request): The request to validate
+
+    Raises:
+        HTTPException: If the request does not have the appropriate values in the body
+    """
+    body = await request.json()
+
+    # Validate tools
+    tools = body.get("tools")
+    if tools:
+        for tool in tools:
+            if tool not in AVAILABLE_TOOLS:
+                raise HTTPException(status_code=400, detail=f"Tool {tool} not found.")
+
+    model, deployment = body.get("model"), body.get("deployment")
+    # Model and deployment must be updated together to ensure compatibility
+    if not model and deployment:
+        raise HTTPException(
+            status_code=400,
+            detail="If updating an agent's deployment type, the model must also be provided.",
+        )
+    elif model and not deployment:
+        raise HTTPException(
+            status_code=400,
+            detail=f"If updating an agent's model, the deployment must also be provided.",
+        )
+    elif model and deployment:
+        # Validate deployment
+        if deployment not in AVAILABLE_MODEL_DEPLOYMENTS.keys():
+            raise HTTPException(
+                status_code=400,
+                detail=f"Deployment {deployment} not found or is not available.",
+            )
+
+        # Validate model
+        if model not in AVAILABLE_MODEL_DEPLOYMENTS[deployment].models:
+            raise HTTPException(
+                status_code=400,
+                detail=f"Model {model} not found for deployment {deployment}.",
+            )
diff --git a/src/backend/tests/crud/test_agent.py b/src/backend/tests/crud/test_agent.py
index 6fbeb5e330..992b30641c 100644
--- a/src/backend/tests/crud/test_agent.py
+++ b/src/backend/tests/crud/test_agent.py
@@ -3,9 +3,10 @@
 import pytest
 from sqlalchemy.exc import IntegrityError
 
+from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName
 from backend.config.tools import ToolName
 from backend.crud import agent as agent_crud
-from backend.database_models.agent import Agent, AgentDeployment, AgentModel
+from backend.database_models.agent import Agent
 from backend.schemas.agent import UpdateAgent
 from backend.tests.factories import get_factory
@@ -19,8 +20,8 @@ def test_create_agent(session, user):
         preamble="test",
         temperature=0.5,
         tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File],
-        model=AgentModel.COMMAND_R_PLUS,
-        deployment=AgentDeployment.COHERE_PLATFORM,
+        model="command-r-plus",
+        deployment=ModelDeploymentName.CoherePlatform,
     )
 
     agent = agent_crud.create_agent(session, agent_data)
@@ -31,8 +32,8 @@ def test_create_agent(session, user):
     assert agent.preamble == "test"
     assert agent.temperature == 0.5
     assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File]
-    assert agent.model == AgentModel.COMMAND_R_PLUS
-    assert agent.deployment == AgentDeployment.COHERE_PLATFORM
+    assert agent.model == "command-r-plus"
+    assert agent.deployment == ModelDeploymentName.CoherePlatform
 
     agent = agent_crud.get_agent(session, agent.id)
     assert agent.user_id == user.id
@@ -42,15 +43,16 @@ def test_create_agent(session, user):
     assert agent.preamble == "test"
     assert agent.temperature == 0.5
     assert agent.tools == [ToolName.Wiki_Retriever_LangChain, ToolName.Search_File]
-    assert agent.model == AgentModel.COMMAND_R_PLUS
-    assert agent.deployment == AgentDeployment.COHERE_PLATFORM
+    assert agent.model == "command-r-plus"
+    assert agent.deployment == ModelDeploymentName.CoherePlatform
 
 
 def test_create_agent_empty_non_required_fields(session, user):
     agent_data = Agent(
         user_id=user.id,
         name="test",
-        model=AgentModel.COMMAND_R_PLUS,
+        model="command-r-plus",
+        deployment=ModelDeploymentName.CoherePlatform,
     )
 
     agent = agent_crud.create_agent(session, agent_data)
@@ -61,8 +63,8 @@ def
test_create_agent_empty_non_required_fields(session, user): assert agent.preamble == "" assert agent.temperature == 0.3 assert agent.tools == [] - assert agent.model == AgentModel.COMMAND_R_PLUS - assert agent.deployment == AgentDeployment.COHERE_PLATFORM + assert agent.model == "command-r-plus" + assert agent.deployment == ModelDeploymentName.CoherePlatform agent = agent_crud.get_agent(session, agent.id) assert agent.user_id == user.id @@ -72,15 +74,15 @@ def test_create_agent_empty_non_required_fields(session, user): assert agent.preamble == "" assert agent.temperature == 0.3 assert agent.tools == [] - assert agent.model == AgentModel.COMMAND_R_PLUS - assert agent.deployment == AgentDeployment.COHERE_PLATFORM + assert agent.model == "command-r-plus" + assert agent.deployment == ModelDeploymentName.CoherePlatform def test_create_agent_missing_name(session, user): agent_data = Agent( user_id=user.id, - model=AgentModel.COMMAND_R_PLUS, - deployment=AgentDeployment.COHERE_PLATFORM, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, ) with pytest.raises(IntegrityError): @@ -91,7 +93,7 @@ def test_create_agent_missing_model(session, user): agent_data = Agent( user_id=user.id, name="test", - deployment=AgentDeployment.COHERE_PLATFORM, + deployment=ModelDeploymentName.CoherePlatform, ) with pytest.raises(IntegrityError): @@ -101,8 +103,8 @@ def test_create_agent_missing_model(session, user): def test_create_agent_missing_user_id(session): agent_data = Agent( name="test", - model=AgentModel.COMMAND_R_PLUS, - deployment=AgentDeployment.COHERE_PLATFORM, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, ) with pytest.raises(IntegrityError): @@ -122,8 +124,8 @@ def test_create_agent_duplicate_name_version(session, user): preamble="test", temperature=0.5, tools=[ToolName.Wiki_Retriever_LangChain, ToolName.Search_File], - model=AgentModel.COMMAND_R_PLUS, - deployment=AgentDeployment.COHERE_PLATFORM, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, ) with pytest.raises(IntegrityError): diff --git a/src/backend/tests/factories/agent.py b/src/backend/tests/factories/agent.py index d86b91d2f1..5c2ddc5d92 100644 --- a/src/backend/tests/factories/agent.py +++ b/src/backend/tests/factories/agent.py @@ -1,7 +1,8 @@ import factory +from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName from backend.config.tools import ToolName -from backend.database_models.agent import Agent, AgentDeployment, AgentModel +from backend.database_models.agent import Agent from .base import BaseFactory @@ -33,21 +34,5 @@ class Meta: ) ] ) - model = factory.Faker( - "random_element", - elements=( - AgentModel.COMMAND_R, - AgentModel.COMMAND_R_PLUS, - AgentModel.COMMAND_LIGHT, - AgentModel.COMMAND, - ), - ) - deployment = factory.Faker( - "random_element", - elements=( - AgentDeployment.COHERE_PLATFORM, - AgentDeployment.SAGE_MAKER, - AgentDeployment.AZURE, - AgentDeployment.BEDROCK, - ), - ) + model = "command-r-plus" + deployment = ModelDeploymentName.CoherePlatform diff --git a/src/backend/tests/routers/test_agent.py b/src/backend/tests/routers/test_agent.py index 20b7d53611..2f1126f547 100644 --- a/src/backend/tests/routers/test_agent.py +++ b/src/backend/tests/routers/test_agent.py @@ -1,8 +1,9 @@ from fastapi.testclient import TestClient from sqlalchemy.orm import Session +from backend.config.deployments import ALL_MODEL_DEPLOYMENTS, ModelDeploymentName from backend.config.tools import ToolName -from 
backend.database_models.agent import Agent, AgentDeployment, AgentModel +from backend.database_models.agent import Agent from backend.tests.factories import get_factory @@ -13,9 +14,9 @@ def test_create_agent(session_client: TestClient, session: Session) -> None: "description": "test description", "preamble": "test preamble", "temperature": 0.5, - "model": AgentModel.COMMAND_R, - "deployment": AgentDeployment.COHERE_PLATFORM, - "tools": [ToolName.Wiki_Retriever_LangChain], + "model": "command-r-plus", + "deployment": ModelDeploymentName.CoherePlatform, + "tools": [ToolName.Calculator], } response = session_client.post( @@ -52,13 +53,14 @@ def test_create_agent_missing_name( "description": "test description", "preamble": "test preamble", "temperature": 0.5, - "model": AgentModel.COMMAND_R, - "deployment": AgentDeployment.COHERE_PLATFORM, + "model": "command-r-plus", + "deployment": ModelDeploymentName.CoherePlatform, } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": "123"} ) - assert response.status_code == 422 + assert response.status_code == 400 + assert response.json() == {"detail": "Name, model, and deployment are required."} def test_create_agent_missing_model( @@ -69,12 +71,30 @@ def test_create_agent_missing_model( "description": "test description", "preamble": "test preamble", "temperature": 0.5, - "deployment": AgentDeployment.COHERE_PLATFORM, + "deployment": ModelDeploymentName.CoherePlatform, } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": "123"} ) - assert response.status_code == 422 + assert response.status_code == 400 + assert response.json() == {"detail": "Name, model, and deployment are required."} + + +def test_create_agent_missing_deployment( + session_client: TestClient, session: Session +) -> None: + request_json = { + "name": "test agent", + "description": "test description", + "preamble": "test preamble", + "temperature": 0.5, + "model": "command-r-plus", + } + response = session_client.post( + "/v1/agents", json=request_json, headers={"User-Id": "123"} + ) + assert response.status_code == 400 + assert response.json() == {"detail": "Name, model, and deployment are required."} def test_create_agent_missing_user_id_header( @@ -82,8 +102,8 @@ def test_create_agent_missing_user_id_header( ) -> None: request_json = { "name": "test agent", - "model": AgentModel.COMMAND_R, - "deployment": AgentDeployment.COHERE_PLATFORM, + "model": "command-r-plus", + "deployment": ModelDeploymentName.CoherePlatform, } response = session_client.post("/v1/agents", json=request_json) assert response.status_code == 401 @@ -94,11 +114,10 @@ def test_create_agent_missing_non_required_fields( ) -> None: request_json = { "name": "test agent", - "model": AgentModel.COMMAND_R, + "model": "command-r-plus", + "deployment": ModelDeploymentName.CoherePlatform, } - print(request_json) - response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": "123"} ) @@ -122,7 +141,7 @@ def test_create_agent_missing_non_required_fields( assert agent.model == request_json["model"] -def test_create_agent_wrong_model_deployment_enums( +def test_create_agent_invalid_deployment( session_client: TestClient, session: Session ) -> None: request_json = { @@ -131,30 +150,34 @@ def test_create_agent_wrong_model_deployment_enums( "description": "test description", "preamble": "test preamble", "temperature": 0.5, - "model": "not a real model", + "model": "command-r-plus", "deployment": "not a real deployment", } response = 
session_client.post( "/v1/agents", json=request_json, headers={"User-Id": "123"} ) - assert response.status_code == 422 + assert response.status_code == 400 + assert response.json() == { + "detail": "Deployment not a real deployment not found or is not available." + } -def test_create_agent_wrong_tool_name_enums( +def test_create_agent_invalid_tool( session_client: TestClient, session: Session ) -> None: request_json = { "name": "test agent", - "model": AgentModel.COMMAND_R, - "deployment": AgentDeployment.COHERE_PLATFORM, - "tools": ["not a real tool"], + "model": "command-r-plus", + "deployment": ModelDeploymentName.CoherePlatform, + "tools": [ToolName.Calculator, "not a real tool"], } response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": "123"} ) - assert response.status_code == 422 + assert response.status_code == 400 + assert response.json() == {"detail": "Tool not a real tool not found."} def test_list_agents_empty(session_client: TestClient, session: Session) -> None: @@ -217,8 +240,8 @@ def test_update_agent(session_client: TestClient, session: Session) -> None: description="test description", preamble="test preamble", temperature=0.5, - model=AgentModel.COMMAND_R, - deployment=AgentDeployment.COHERE_PLATFORM, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, ) request_json = { @@ -227,13 +250,16 @@ def test_update_agent(session_client: TestClient, session: Session) -> None: "description": "updated description", "preamble": "updated preamble", "temperature": 0.7, - "model": AgentModel.COMMAND_R_PLUS, - "deployment": AgentDeployment.SAGE_MAKER, + "model": "command-r", + "deployment": ModelDeploymentName.CoherePlatform, } response = session_client.put( f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"} ) + + print("DEBUGGG") + print(response.json()) assert response.status_code == 200 updated_agent = response.json() assert updated_agent["name"] == "updated name" @@ -241,8 +267,8 @@ def test_update_agent(session_client: TestClient, session: Session) -> None: assert updated_agent["description"] == "updated description" assert updated_agent["preamble"] == "updated preamble" assert updated_agent["temperature"] == 0.7 - assert updated_agent["model"] == AgentModel.COMMAND_R_PLUS - assert updated_agent["deployment"] == AgentDeployment.SAGE_MAKER + assert updated_agent["model"] == "command-r" + assert updated_agent["deployment"] == ModelDeploymentName.CoherePlatform def test_partial_update_agent(session_client: TestClient, session: Session) -> None: @@ -252,12 +278,14 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N description="test description", preamble="test preamble", temperature=0.5, - model=AgentModel.COMMAND_R, - deployment=AgentDeployment.COHERE_PLATFORM, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, + tools=[ToolName.Calculator], ) request_json = { "name": "updated name", + "tools": [ToolName.Search_File, ToolName.Read_File], } response = session_client.put( @@ -270,8 +298,9 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N assert updated_agent["description"] == "test description" assert updated_agent["preamble"] == "test preamble" assert updated_agent["temperature"] == 0.5 - assert updated_agent["model"] == AgentModel.COMMAND_R - assert updated_agent["deployment"] == AgentDeployment.COHERE_PLATFORM + assert updated_agent["model"] == "command-r-plus" + assert updated_agent["deployment"] == 
ModelDeploymentName.CoherePlatform + assert updated_agent["tools"] == [ToolName.Search_File, ToolName.Read_File] def test_update_nonexistent_agent(session_client: TestClient, session: Session) -> None: @@ -285,7 +314,113 @@ def test_update_nonexistent_agent(session_client: TestClient, session: Session) assert response.json() == {"detail": "Agent with ID: 456 not found."} -def test_update_agent_wrong_model_deployment_enums( +def test_update_agent_invalid_model( + session_client: TestClient, session: Session +) -> None: + agent = get_factory("Agent", session).create( + name="test agent", + version=1, + description="test description", + preamble="test preamble", + temperature=0.5, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, + ) + + request_json = { + "model": "not a real model", + "deployment": ModelDeploymentName.CoherePlatform, + } + + response = session_client.put( + f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"} + ) + assert response.status_code == 400 + assert response.json() == { + "detail": "Model not a real model not found for deployment Cohere Platform." + } + + +def test_update_agent_invalid_deployment( + session_client: TestClient, session: Session +) -> None: + agent = get_factory("Agent", session).create( + name="test agent", + version=1, + description="test description", + preamble="test preamble", + temperature=0.5, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, + ) + + request_json = { + "model": "command-r", + "deployment": "not a real deployment", + } + + response = session_client.put( + f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"} + ) + assert response.status_code == 400 + assert response.json() == { + "detail": "Deployment not a real deployment not found or is not available." + } + + +def test_update_agent_model_without_deployment( + session_client: TestClient, session: Session +) -> None: + agent = get_factory("Agent", session).create( + name="test agent", + version=1, + description="test description", + preamble="test preamble", + temperature=0.5, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, + ) + + request_json = { + "model": "command-r", + } + + response = session_client.put( + f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"} + ) + assert response.status_code == 400 + assert response.json() == { + "detail": "If updating an agent's model, the deployment must also be provided." + } + + +def test_update_agent_deployment_without_model( + session_client: TestClient, session: Session +) -> None: + agent = get_factory("Agent", session).create( + name="test agent", + version=1, + description="test description", + preamble="test preamble", + temperature=0.5, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, + ) + + request_json = { + "deployment": ModelDeploymentName.CoherePlatform, + } + + response = session_client.put( + f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"} + ) + assert response.status_code == 400 + assert response.json() == { + "detail": "If updating an agent's deployment type, the model must also be provided." 
+ } + + +def test_update_agent_invalid_tool( session_client: TestClient, session: Session ) -> None: agent = get_factory("Agent", session).create( @@ -294,19 +429,21 @@ def test_update_agent_wrong_model_deployment_enums( description="test description", preamble="test preamble", temperature=0.5, - model=AgentModel.COMMAND_R, - deployment=AgentDeployment.COHERE_PLATFORM, + model="command-r-plus", + deployment=ModelDeploymentName.CoherePlatform, ) request_json = { "model": "not a real model", "deployment": "not a real deployment", + "tools": [ToolName.Calculator, "not a real tool"], } response = session_client.put( f"/v1/agents/{agent.id}", json=request_json, headers={"User-Id": "123"} ) - assert response.status_code == 422 + assert response.status_code == 400 + assert response.json() == {"detail": "Tool not a real tool not found."} def test_delete_agent(session_client: TestClient, session: Session) -> None: diff --git a/src/interfaces/coral_web/src/cohere-client/generated/index.ts b/src/interfaces/coral_web/src/cohere-client/generated/index.ts index 7c8cd12d17..5b99db276d 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/index.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/index.ts @@ -8,8 +8,6 @@ export { OpenAPI } from './core/OpenAPI'; export type { OpenAPIConfig } from './core/OpenAPI'; export type { Agent } from './models/Agent'; -export { AgentDeployment } from './models/AgentDeployment'; -export { AgentModel } from './models/AgentModel'; export type { Body_upload_file_v1_conversations_upload_file_post } from './models/Body_upload_file_v1_conversations_upload_file_post'; export { Category } from './models/Category'; export type { ChatMessage } from './models/ChatMessage'; @@ -55,7 +53,6 @@ export type { StreamToolResult } from './models/StreamToolResult'; export type { Tool } from './models/Tool'; export type { ToolCall } from './models/ToolCall'; export { ToolInputType } from './models/ToolInputType'; -export { ToolName } from './models/ToolName'; export type { UpdateAgent } from './models/UpdateAgent'; export type { UpdateConversation } from './models/UpdateConversation'; export type { UpdateDeploymentEnv } from './models/UpdateDeploymentEnv'; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts index 9cd2f76beb..eb21166bc6 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/Agent.ts @@ -1,14 +1,7 @@ /* generated using openapi-typescript-codegen -- do no edit */ - /* istanbul ignore file */ - /* tslint:disable */ - /* eslint-disable */ -import type { AgentDeployment } from './AgentDeployment'; -import type { AgentModel } from './AgentModel'; -import type { ToolName } from './ToolName'; - export type Agent = { user_id: string; id: string; @@ -19,7 +12,7 @@ export type Agent = { description: string | null; preamble: string | null; temperature: number; - tools: Array; - model: AgentModel; - deployment: AgentDeployment; + tools: Array; + model: string; + deployment: string; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/AgentDeployment.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/AgentDeployment.ts deleted file mode 100644 index 78d833a765..0000000000 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/AgentDeployment.ts +++ /dev/null @@ -1,10 +0,0 @@ -/* generated using openapi-typescript-codegen -- do 
no edit */ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ -export enum AgentDeployment { - COHERE_PLATFORM = 'Cohere Platform', - SAGE_MAKER = 'SageMaker', - AZURE = 'Azure', - BEDROCK = 'Bedrock', -} diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/AgentModel.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/AgentModel.ts deleted file mode 100644 index 8e0b56d0f7..0000000000 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/AgentModel.ts +++ /dev/null @@ -1,10 +0,0 @@ -/* generated using openapi-typescript-codegen -- do no edit */ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ -export enum AgentModel { - COMMAND_R = 'command-r', - COMMAND_R_PLUS = 'command-r-plus', - COMMAND_LIGHT = 'command-light', - COMMAND = 'command', -} diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts index 13870aba77..45566c4ff1 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/CreateAgent.ts @@ -1,21 +1,14 @@ /* generated using openapi-typescript-codegen -- do no edit */ - /* istanbul ignore file */ - /* tslint:disable */ - /* eslint-disable */ -import type { AgentDeployment } from './AgentDeployment'; -import type { AgentModel } from './AgentModel'; -import type { ToolName } from './ToolName'; - export type CreateAgent = { name: string; version?: number | null; description?: string | null; preamble?: string | null; temperature?: number | null; - model: AgentModel; - deployment?: AgentDeployment | null; - tools?: Array | null; + model: string; + deployment: string; + tools?: Array | null; }; diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts deleted file mode 100644 index 14b40f9d5e..0000000000 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/ToolName.ts +++ /dev/null @@ -1,12 +0,0 @@ -/* generated using openapi-typescript-codegen -- do no edit */ -/* istanbul ignore file */ -/* tslint:disable */ -/* eslint-disable */ -export enum ToolName { - WIKIPEDIA = 'Wikipedia', - SEARCH_FILE = 'search_file', - READ_DOCUMENT = 'read_document', - PYTHON_INTERPRETER = 'Python_Interpreter', - CALCULATOR = 'Calculator', - INTERNET_SEARCH = 'Internet_Search', -} diff --git a/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts b/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts index 7462d81c05..d71d5a7161 100644 --- a/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts +++ b/src/interfaces/coral_web/src/cohere-client/generated/models/UpdateAgent.ts @@ -1,21 +1,14 @@ /* generated using openapi-typescript-codegen -- do no edit */ - /* istanbul ignore file */ - /* tslint:disable */ - /* eslint-disable */ -import type { AgentDeployment } from './AgentDeployment'; -import type { AgentModel } from './AgentModel'; -import type { ToolName } from './ToolName'; - export type UpdateAgent = { name?: string | null; version?: number | null; description?: string | null; preamble?: string | null; temperature?: number | null; - model?: AgentModel | null; - deployment?: AgentDeployment | null; - tools?: Array | null; + model?: string | null; + deployment?: string | null; + tools?: Array | null; }; From 
916b83e050a47a094130897a0680fb4e2e801aa4 Mon Sep 17 00:00:00 2001
From: Luísa Moura
Date: Mon, 17 Jun 2024 11:21:04 -0400
Subject: [PATCH 2/4] Chat: add multistep (#206)

* Chat: add multistep
* fix type error
* add comments
* merge events
* improve code
* fix non streaming chat and refactor
* fix custom tools
* fix tests
* fix chat history
* fix more tests
* fix rerank tests
* fix chat history bugs
* comment
---
 poetry.lock                                    |   4 +-
 pyproject.toml                                 |   2 +-
 src/backend/chat/collate.py                    |  14 +-
 src/backend/chat/custom/custom.py              | 308 +++++++-----
 src/backend/chat/enums.py                      |   1 +
 src/backend/model_deployments/azure.py         |  14 +-
 src/backend/model_deployments/base.py          |  10 -
 src/backend/model_deployments/bedrock.py       |   3 +-
 .../model_deployments/cohere_platform.py       |  26 +-
 src/backend/schemas/chat.py                    |  61 ++-
 src/backend/schemas/cohere_chat.py             |   4 +
 src/backend/schemas/tool.py                    |   6 +
 src/backend/services/chat.py                   | 451 +++++++++++-------
 .../mock_deployments/mock_azure.py             |  40 +-
 .../mock_deployments/mock_bedrock.py           |  12 +-
 .../mock_deployments/mock_cohere_platform.py   |  12 +-
 .../mock_deployments/mock_sagemaker.py         |  12 +-
 .../tests/model_deployments/test_azure.py      |  24 -
 .../tests/model_deployments/test_bedrock.py    |  24 -
 .../model_deployments/test_cohere_platform.py  |  24 -
 .../tests/model_deployments/test_sagemaker.py  |  24 -
 src/backend/tests/tools/test_collate.py        |  51 +-
 22 files changed, 638 insertions(+), 489 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 9c3c3c27c5..7a10dc2be8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
 
 [[package]]
 name = "aiohttp"
@@ -5639,4 +5639,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools",
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.11"
-content-hash = "1eed87886503afb8de203b8f97dc54116b1a48bb67ee1336fd8c80909d2de127"
+content-hash = "66aee79e01207051434dd576d1ef5c9b5c91e2f5cc92a1827c6328ef6139b2c7"
diff --git a/pyproject.toml b/pyproject.toml
index 8e1d8da6a5..d65a16cc9c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,7 +23,7 @@ sse-starlette = "^2.0.0"
 boto3 = "^1.0.0"
 httpx = "^0.27.0"
 chromadb = "^0.4.16"
-cohere = "^5.5.6"
+cohere = "^5.5.7"
 llama-index = "^0.10.11"
 inquirer = "^3.2.4"
 langchain-community = "^0.0.32"
diff --git a/src/backend/chat/collate.py b/src/backend/chat/collate.py
index a25f7494e0..9c98c0c1ca 100644
--- a/src/backend/chat/collate.py
+++ b/src/backend/chat/collate.py
@@ -1,9 +1,9 @@
-from itertools import zip_longest
+import json
 from typing import Any, Dict, List
 
 from backend.model_deployments.base import BaseDeployment
 
-RELEVANCE_THRESHOLD = 0.5
+RELEVANCE_THRESHOLD = 0.3
 
 
 def rerank_and_chunk(
@@ -45,9 +45,9 @@ def rerank_and_chunk(
     reranked_results = {}
     for tool_call_hashable, tool_result in unified_tool_results.items():
         tool_call = tool_result["call"]
-        query = tool_call.parameters.get("query") or tool_call.parameters.get(
-            "search_query"
-        )
+        query = tool_call.get("parameters").get("query") or tool_call.get(
+            "parameters"
+        ).get("search_query")
 
         # Only rerank if there is a query
         if not query:
@@ -122,3 +122,7 @@ def chunk(content, compact_mode=False, soft_word_cut_off=100, hard_word_cut_off=
             chunks.append(current_chunk.strip())
 
     return chunks
+
+
+def to_dict(obj):
+    return json.loads(json.dumps(obj, default=lambda o: o.__dict__))
diff --git
a/src/backend/chat/custom/custom.py b/src/backend/chat/custom/custom.py index 470c2c1d4d..4c42bf334c 100644 --- a/src/backend/chat/custom/custom.py +++ b/src/backend/chat/custom/custom.py @@ -5,18 +5,18 @@ from fastapi import HTTPException from backend.chat.base import BaseChat -from backend.chat.collate import rerank_and_chunk +from backend.chat.collate import rerank_and_chunk, to_dict from backend.chat.custom.utils import get_deployment from backend.chat.enums import StreamEvent from backend.config.tools import AVAILABLE_TOOLS, ToolName from backend.crud.file import get_files_by_conversation_id -from backend.model_deployments.base import BaseDeployment from backend.schemas.chat import ChatMessage from backend.schemas.cohere_chat import CohereChatRequest -from backend.schemas.tool import Category, Tool +from backend.schemas.tool import Tool from backend.services.logger import get_logger logger = get_logger() +MAX_STEPS = 15 class CustomChat(BaseChat): @@ -42,151 +42,209 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: status_code=400, detail="Both tools and documents cannot be provided." ) - # If a direct answer is generated instead of tool calls, the chat will not be called again - # Instead, the direct answer will be returned from the stream - stream = self.handle_managed_tools(chat_request, deployment_model, **kwargs) + self.chat_request = chat_request + self.is_first_start = True + should_break = False - first_event, generated_direct_answer = next(stream) + for step in range(MAX_STEPS): + logger.info(f"Step {step + 1}") + stream = self.call_chat(self.chat_request, deployment_model, **kwargs) - if generated_direct_answer: - yield first_event - for event, _ in stream: - yield event - else: - chat_request = first_event - invoke_method = ( - deployment_model.invoke_chat_stream - if kwargs.get("stream", True) - else deployment_model.invoke_chat - ) + for event in stream: + result = self.handle_event(event, chat_request) - yield from invoke_method(chat_request) + if result: + yield result - def handle_managed_tools( - self, - chat_request: CohereChatRequest, - deployment_model: BaseDeployment, - **kwargs: Any, - ) -> Generator[Any, None, None]: - """ - This function handles the managed tools. + if event[ + "event_type" + ] == StreamEvent.STREAM_END and self.is_final_event( + event, chat_request + ): + should_break = True + break - Args: - chat_request (CohereChatRequest): The chat request - deployment_model (BaseDeployment): The deployment model - **kwargs (Any): The keyword arguments + if should_break: + break - Returns: - Generator[Any, None, None]: The tool results or the chat response, and a boolean indicating if a direct answer was generated - """ - tools = [ - Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump()) - for tool in chat_request.tools - if AVAILABLE_TOOLS.get(tool.name) - ] + def is_final_event( + self, event: Dict[str, Any], chat_request: CohereChatRequest + ) -> bool: + # The event is final if: + # 1. It is a stream end event with no tool calls - direct answer + # 2. 
It is a stream end event with tool calls, but no managed tools - tool calls generation only + if "response" in event: + response = event["response"] + else: + return True - if not tools: - yield chat_request, False + return not ("tool_calls" in response and response["tool_calls"]) or ( + "tool_calls" in response + and response["tool_calls"] + and chat_request.tools + and not self.get_managed_tools(self.chat_request) + ) - for event, should_return in self.get_tool_results( - chat_request.message, - chat_request.chat_history, - tools, - kwargs.get("conversation_id"), - deployment_model, - kwargs, + def handle_event( + self, event: Dict[str, Any], chat_request: CohereChatRequest + ) -> Dict[str, Any]: + # All events other than stream start and stream end are returned + if ( + event["event_type"] != StreamEvent.STREAM_START + and event["event_type"] != StreamEvent.STREAM_END ): - if should_return: - yield event, True - else: - chat_request.tool_results = event - chat_request.tools = tools - yield chat_request, False - - def get_tool_results( - self, - message: str, - chat_history: List[Dict[str, str]], - tools: list[Tool], - conversation_id: str, - deployment_model: BaseDeployment, - kwargs: Any, - ) -> Any: - """ - Invokes the tools and returns the results. If no tools calls are generated, it returns the chat response - as a direct answer. + return event + + # Only the first occurrence of stream start is returned + if event["event_type"] == StreamEvent.STREAM_START: + if self.is_first_start: + self.is_first_start = False + return event + + # Only the final occurrence of stream end is returned + # The final event is the one that does not contain tool calls + if event["event_type"] == StreamEvent.STREAM_END: + if self.is_final_event(event, chat_request): + return event + + return None + + def is_not_direct_answer(self, event: Dict[str, Any]) -> bool: + # If the event contains tool calls, it is not a direct answer + return ( + event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION + and "tool_calls" in event + ) - Args: - message (str): The message to be processed - chat_history (List[Dict[str, str]]): The chat history - tools (list[Tool]): The tools to be invoked - conversation_id (str): The conversation ID - deployment_model (BaseDeployment): The deployment model - kwargs (Any): The keyword arguments + def call_chat(self, chat_request, deployment_model, **kwargs: Any): + managed_tools = self.get_managed_tools(chat_request) - Returns: - Any: The tool results or the chat response, and a boolean indicating if a direct answer was generated + # If tools are managed and not zero shot tools, replace the tools in the chat request + if len(managed_tools) == len(chat_request.tools): + chat_request.tools = managed_tools - """ + # Get the tool calls stream and either return a direct answer or continue + tool_calls_stream = self.get_tool_calls( + managed_tools, chat_request.chat_history, deployment_model, **kwargs + ) + is_direct_answer, new_chat_history, stream = self.handle_tool_calls_stream( + tool_calls_stream + ) + + for event in stream: + yield event + + if is_direct_answer: + return + + # If the stream contains tool calls, call the tools and update the chat history + tool_results = self.call_tools(new_chat_history, deployment_model, **kwargs) + chat_request.tool_results = [result for result in tool_results] + chat_request.chat_history = new_chat_history + + # Remove the message if tool results are present + if tool_results: + chat_request.message = "" + + for event in 
deployment_model.invoke_chat_stream(chat_request): + if event["event_type"] != StreamEvent.STREAM_START: + yield event + if event["event_type"] == StreamEvent.STREAM_END: + chat_request.chat_history = event["response"].get("chat_history", []) + + # Update the chat request and restore the message + self.chat_request = chat_request + + def call_tools(self, chat_history, deployment_model, **kwargs: Any): tool_results = [] + if not hasattr(chat_history[-1], "tool_results"): + logging.warning("No tool calls found in chat history.") + return tool_results + + tool_calls = chat_history[-1].tool_calls + logger.info(f"Tool calls: {tool_calls}") + + # TODO: Call tools in parallel + for tool_call in tool_calls: + tool = AVAILABLE_TOOLS.get(tool_call["name"]) + if not tool: + logging.warning(f"Couldn't find tool {tool_call['name']}") + continue + + outputs = tool.implementation().call( + parameters=tool_call.get("parameters"), + session=kwargs.get("session"), + model_deployment=deployment_model, + user_id=kwargs.get("user_id"), + ) + + # If the tool returns a list of outputs, append each output to the tool_results list + # Otherwise, append the single output to the tool_results list + outputs = outputs if isinstance(outputs, list) else [outputs] + for output in outputs: + tool_results.append({"call": tool_call, "outputs": [output]}) + + tool_results = rerank_and_chunk(tool_results, deployment_model) + return tool_results + + def handle_tool_calls_stream(self, tool_results_stream): + # Process the stream and return the chat history, and a copy of the stream and a flag indicating if the response is a direct answer + stream, stream_copy = tee(tool_results_stream) + is_direct_answer = True + + chat_history = [] + for event in stream: + if event["event_type"] == StreamEvent.STREAM_END: + stream_chat_history = [] + if "response" in event: + stream_chat_history = event["response"].get("chat_history", []) + elif "chat_history" in event: + stream_chat_history = event["chat_history"] + + for message in stream_chat_history: + if not isinstance(message, dict): + message = to_dict(message) + + chat_history.append( + ChatMessage( + role=message.get("role"), + message=message.get("message", ""), + tool_results=message.get("tool_results", None), + tool_calls=message.get("tool_calls", None), + ) + ) + + elif ( + event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION + and "tool_calls" in event + ): + is_direct_answer = False - # If the tool is Read_File or SearchFile, add the available files to the chat history - # so that the model knows what files are available + return is_direct_answer, chat_history, stream_copy + + def get_managed_tools(self, chat_request: CohereChatRequest): + return [ + Tool(**AVAILABLE_TOOLS.get(tool.name).model_dump()) + for tool in chat_request.tools + if AVAILABLE_TOOLS.get(tool.name) + ] + + def get_tool_calls(self, tools, chat_history, deployment_model, **kwargs: Any): + # If the chat history contains a read or search file tool, add the files to the chat history tool_names = [tool.name for tool in tools] if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names: chat_history = self.add_files_to_chat_history( chat_history, - conversation_id, + kwargs.get("conversation_id"), kwargs.get("session"), kwargs.get("user_id"), ) + self.chat_request.chat_history = chat_history - logger.info(f"Invoking tools: {tools}") - stream = deployment_model.invoke_tools( - message, tools, chat_history=chat_history - ) - - # Invoke tools can return a direct answer or a stream of events with the 
tool calls - # If one of the events is a tool call generation, the tools are invoked, and the results are returned - # Otherwise, the chat response is returned as a direct answer - stream, stream_copy = tee(stream) - - tool_call_found = False - for event in stream: - if event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION: - tool_call_found = True - tool_calls = event["tool_calls"] - - logger.info(f"Tool calls: {tool_calls}") - - # TODO: parallelize tool calls - for tool_call in tool_calls: - tool = AVAILABLE_TOOLS.get(tool_call.name) - if not tool: - logging.warning(f"Couldn't find tool {tool_call.name}") - continue - - outputs = tool.implementation().call( - parameters=tool_call.parameters, - session=kwargs.get("session"), - model_deployment=deployment_model, - user_id=kwargs.get("user_id"), - ) - - # If the tool returns a list of outputs, append each output to the tool_results list - # Otherwise, append the single output to the tool_results list - outputs = outputs if isinstance(outputs, list) else [outputs] - for output in outputs: - tool_results.append({"call": tool_call, "outputs": [output]}) - - tool_results = rerank_and_chunk(tool_results, deployment_model) - logger.info(f"Tool results: {tool_results}") - yield tool_results, False - break + logger.info(f"Available tools: {tools}") + stream = deployment_model.invoke_chat_stream(self.chat_request) - if not tool_call_found: - for event in stream_copy: - yield event, True + return stream def add_files_to_chat_history( self, diff --git a/src/backend/chat/enums.py b/src/backend/chat/enums.py index a95ea884f3..01715f0989 100644 --- a/src/backend/chat/enums.py +++ b/src/backend/chat/enums.py @@ -16,3 +16,4 @@ class StreamEvent(str, Enum): STREAM_END = "stream-end" NON_STREAMED_CHAT_RESPONSE = "non-streamed-chat-response" TOOL_CALLS_GENERATION = "tool-calls-generation" + TOOL_CALLS_CHUNK = "tool-calls-chunk" diff --git a/src/backend/model_deployments/azure.py b/src/backend/model_deployments/azure.py index a3b78e0dff..7da9a57f98 100644 --- a/src/backend/model_deployments/azure.py +++ b/src/backend/model_deployments/azure.py @@ -4,6 +4,7 @@ import cohere from cohere.types import StreamedChatResponse +from backend.chat.collate import to_dict from backend.model_deployments.base import BaseDeployment from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest @@ -65,7 +66,7 @@ def invoke_chat_stream( **kwargs, ) for event in stream: - yield event.__dict__ + yield to_dict(event) def invoke_search_queries( self, @@ -92,14 +93,7 @@ def invoke_rerank( def invoke_tools( self, - message: str, - tools: List[Any], - chat_history: List[Dict[str, str]] | None = None, + chat_request: CohereChatRequest, **kwargs: Any, ) -> Generator[StreamedChatResponse, None, None]: - stream = self.client.chat_stream( - message=message, tools=tools, chat_history=chat_history, **kwargs - ) - - for event in stream: - yield event.__dict__ + yield from self.invoke_chat_stream(chat_request, **kwargs) diff --git a/src/backend/model_deployments/base.py b/src/backend/model_deployments/base.py index f44e2bbc13..61f570dfc3 100644 --- a/src/backend/model_deployments/base.py +++ b/src/backend/model_deployments/base.py @@ -13,7 +13,6 @@ class BaseDeployment: invoke_chat_stream: Generator[StreamedChatResponse, None, None]: Invoke the chat stream. invoke_search_queries: list[str]: Invoke the search queries. invoke_rerank: Any: Invoke the rerank. - invoke_tools: Any: Invoke the tools. 
list_models: List[str]: List all models. is_available: bool: Check if the deployment is available. """ @@ -48,12 +47,3 @@ def invoke_search_queries( def invoke_rerank( self, query: str, documents: List[Dict[str, Any]], **kwargs: Any ) -> Any: ... - - @abstractmethod - def invoke_tools( - self, - message: str, - tools: List[Any], - chat_history: List[Dict[str, str]] | None = None, - **kwargs: Any - ) -> Generator[StreamedChatResponse, None, None]: ... diff --git a/src/backend/model_deployments/bedrock.py b/src/backend/model_deployments/bedrock.py index b2dbb03f38..a14910186f 100644 --- a/src/backend/model_deployments/bedrock.py +++ b/src/backend/model_deployments/bedrock.py @@ -4,6 +4,7 @@ import cohere from cohere.types import StreamedChatResponse +from backend.chat.collate import to_dict from backend.model_deployments.base import BaseDeployment from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest @@ -76,7 +77,7 @@ def invoke_chat_stream( **kwargs, ) for event in stream: - yield event.__dict__ + yield to_dict(event) def invoke_search_queries( self, diff --git a/src/backend/model_deployments/cohere_platform.py b/src/backend/model_deployments/cohere_platform.py index f2f3ba4145..24fdfc7748 100644 --- a/src/backend/model_deployments/cohere_platform.py +++ b/src/backend/model_deployments/cohere_platform.py @@ -1,3 +1,4 @@ +import json import logging import os from typing import Any, Dict, Generator, List @@ -6,6 +7,8 @@ import requests from cohere.types import StreamedChatResponse +from backend.chat.collate import to_dict +from backend.chat.enums import StreamEvent from backend.model_deployments.base import BaseDeployment from backend.model_deployments.utils import get_model_config_var from backend.schemas.cohere_chat import CohereChatRequest @@ -57,22 +60,21 @@ def is_available(cls) -> bool: return all([os.environ.get(var) is not None for var in COHERE_ENV_VARS]) def invoke_chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any: - yield self.client.chat( + response = self.client.chat( **chat_request.model_dump(exclude={"stream"}), - force_single_step=True, **kwargs, ) + yield to_dict(response) def invoke_chat_stream( self, chat_request: CohereChatRequest, **kwargs: Any ) -> Generator[StreamedChatResponse, None, None]: stream = self.client.chat_stream( **chat_request.model_dump(exclude={"stream", "file_ids"}), - force_single_step=True, **kwargs, ) for event in stream: - yield event.__dict__ + yield to_dict(event) def invoke_search_queries( self, @@ -101,19 +103,7 @@ def invoke_rerank( def invoke_tools( self, - message: str, - tools: List[Any], - chat_history: List[Dict[str, str]] | None = None, + chat_request: CohereChatRequest, **kwargs: Any, ) -> Generator[StreamedChatResponse, None, None]: - stream = self.client.chat_stream( - message=message, - tools=tools, - model="command-r", - force_single_step=True, - chat_history=chat_history, - **kwargs, - ) - - for event in stream: - yield event.__dict__ + yield from self.invoke_chat_stream(chat_request, **kwargs) diff --git a/src/backend/schemas/chat.py b/src/backend/schemas/chat.py index 1bfe147196..906a01aa18 100644 --- a/src/backend/schemas/chat.py +++ b/src/backend/schemas/chat.py @@ -8,7 +8,7 @@ from backend.schemas.citation import Citation from backend.schemas.document import Document from backend.schemas.search_query import SearchQuery -from backend.schemas.tool import Tool, ToolCall +from backend.schemas.tool import Tool, ToolCall, ToolCallDelta class 
ChatRole(StrEnum): @@ -45,7 +45,11 @@ class ChatMessage(BaseModel): ) tool_results: List[Dict[str, Any]] | None = Field( title="Results from the tool call.", - default=[], + default=None, + ) + tool_calls: List[Dict[str, Any]] | None = Field( + title="List of tool calls generated for custom tools", + default=None, ) def to_dict(self) -> Dict[str, str]: @@ -143,11 +147,15 @@ class StreamToolCallsGeneration(ChatResponse): event_type: ClassVar[StreamEvent] = StreamEvent.TOOL_CALLS_GENERATION - tool_calls: List[ToolCall] = Field( + tool_calls: List[ToolCall] | None = Field( title="List of tool calls generated for custom tools", default=[], ) + text: str | None = Field( + title="Contents of the chat message.", + ) + class StreamEnd(ChatResponse): response_id: str | None = Field(default=None) @@ -220,24 +228,45 @@ class NonStreamedChatResponse(ChatResponse): ) +class StreamToolCallsChunk(ChatResponse): + event_type: ClassVar[StreamEvent] = StreamEvent.TOOL_CALLS_CHUNK + + tool_call_delta: ToolCallDelta | None = Field( + title="Partial tool call", + default=ToolCallDelta( + name=None, + index=None, + parameters=None, + ), + ) + + text: str | None = Field( + title="Contents of the chat message.", + ) + + +StreamEventType = Union[ + StreamStart, + StreamTextGeneration, + StreamCitationGeneration, + StreamQueryGeneration, + StreamSearchResults, + StreamEnd, + StreamToolInput, + StreamToolResult, + StreamSearchQueriesGeneration, + StreamToolCallsGeneration, + StreamToolCallsChunk, + NonStreamedChatResponse, +] + + class ChatResponseEvent(BaseModel): event: StreamEvent = Field( title="type of stream event", ) - data: Union[ - StreamStart, - StreamTextGeneration, - StreamCitationGeneration, - StreamQueryGeneration, - StreamSearchResults, - StreamEnd, - StreamToolInput, - StreamToolResult, - StreamSearchQueriesGeneration, - StreamToolCallsGeneration, - NonStreamedChatResponse, - ] = Field( + data: StreamEventType = Field( title="Data returned from chat response of a given event type", ) diff --git a/src/backend/schemas/cohere_chat.py b/src/backend/schemas/cohere_chat.py index 9051dbbadf..c6e3aefe9e 100644 --- a/src/backend/schemas/cohere_chat.py +++ b/src/backend/schemas/cohere_chat.py @@ -105,3 +105,7 @@ class CohereChatRequest(BaseChatRequest): default=None, title="A list of results from invoking tools recommended by the model in the previous chat turn. Results are used to produce a text response and will be referenced in citations.", ) + force_single_step: bool | None = Field( + default=None, + title="If set to true, the model will generate a single response in a single step. 
This is useful for generating a response to a single message.", + ) diff --git a/src/backend/schemas/tool.py b/src/backend/schemas/tool.py index f99956a314..29f9870738 100644 --- a/src/backend/schemas/tool.py +++ b/src/backend/schemas/tool.py @@ -35,3 +35,9 @@ class Config: class ToolCall(BaseModel): name: str parameters: dict = {} + + +class ToolCallDelta(BaseModel): + name: str | None + index: int | None + parameters: str | None diff --git a/src/backend/services/chat.py b/src/backend/services/chat.py index c72b22cc9e..3084969985 100644 --- a/src/backend/services/chat.py +++ b/src/backend/services/chat.py @@ -1,4 +1,5 @@ import json +import logging from typing import Any, Generator, List, Union from uuid import uuid4 @@ -26,10 +27,12 @@ NonStreamedChatResponse, StreamCitationGeneration, StreamEnd, + StreamEventType, StreamSearchQueriesGeneration, StreamSearchResults, StreamStart, StreamTextGeneration, + StreamToolCallsChunk, StreamToolCallsGeneration, StreamToolInput, StreamToolResult, @@ -39,7 +42,7 @@ from backend.schemas.conversation import UpdateConversation from backend.schemas.file import UpdateFile from backend.schemas.search_query import SearchQuery -from backend.schemas.tool import ToolCall +from backend.schemas.tool import ToolCall, ToolCallDelta from backend.services.auth.utils import get_header_user_id @@ -329,8 +332,14 @@ def create_chat_history( if chat_request.chat_history is not None: return chat_request.chat_history + if conversation.messages is None: + return [] + + # Don't include the user message that was just sent text_messages = [ - message for message in conversation.messages[:user_message_position] + message + for message in conversation.messages + if message.position < user_message_position ] return [ ChatMessage( @@ -368,6 +377,65 @@ def update_conversation_after_turn( conversation_crud.update_conversation(session, conversation, new_conversation) +def generate_chat_response( + session: DBSessionDep, + model_deployment_stream: Generator[StreamedChatResponse, None, None], + response_message: Message, + conversation_id: str, + user_id: str, + should_store: bool = True, + **kwargs: Any, +) -> NonStreamedChatResponse: + """ + Generate chat response from model deployment non streaming response. + Use the stream to generate the response and all the intermediate steps, then + return only the final step as a non-streamed response. + + Args: + session (DBSessionDep): Database session. + model_deployment_stream (Generator[StreamResponse, None, None]): Model deployment stream. + response_message (Message): Response message object. + conversation_id (str): Conversation ID. + user_id (str): User ID. + should_store (bool): Whether to store the conversation in the database. + **kwargs (Any): Additional keyword arguments. + + Yields: + bytes: Byte representation of chat response event. 
+ """ + stream = generate_chat_stream( + session, + model_deployment_stream, + response_message, + conversation_id, + user_id, + should_store, + **kwargs, + ) + + non_streamed_chat_response = None + for event in stream: + event = json.loads(event) + if event["event"] == StreamEvent.STREAM_END: + data = event["data"] + non_streamed_chat_response = NonStreamedChatResponse( + text=data.get("text", ""), + response_id=response_message.id, + generation_id=response_message.generation_id, + chat_history=data.get("chat_history", []), + finish_reason=data.get("finish_reason", ""), + citations=data.get("citations", []), + search_queries=data.get("search_queries", []), + documents=data.get("documents", []), + search_results=data.get("search_results", []), + event_type=StreamEvent.NON_STREAMED_CHAT_RESPONSE, + conversation_id=conversation_id, + tool_calls=data.get("tool_calls", []), + ) + + return non_streamed_chat_response + + def generate_chat_stream( session: DBSessionDep, model_deployment_stream: Generator[StreamedChatResponse, None, None], @@ -395,107 +463,29 @@ def generate_chat_stream( stream_end_data = { "conversation_id": conversation_id, "response_id": response_message.id, + "text": "", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], + "tool_calls": [], + "tool_results": [], } - # Given a stream of CohereEventStream objects, save the final message to DB and yield byte representations - final_message_text = "" - # Map the user facing document_ids field returned from model to storage ID for document model document_ids_to_document = {} - all_citations = [] stream_event = None for event in model_deployment_stream: - if event["event_type"] == StreamEvent.STREAM_START: - stream_event = StreamStart.model_validate(event) - response_message.generation_id = event["generation_id"] - stream_end_data["generation_id"] = event["generation_id"] - elif event["event_type"] == StreamEvent.TEXT_GENERATION: - final_message_text += event["text"] - stream_event = StreamTextGeneration.model_validate(event) - elif event["event_type"] == StreamEvent.SEARCH_RESULTS: - for document in event["documents"]: - storage_document = Document( - document_id=document.get("id", ""), - text=document.get("text", ""), - title=document.get("title", ""), - url=document.get("url", ""), - tool_name=document.get("tool_name", ""), - # all document fields except for id, tool_name and text - fields={ - k: v - for k, v in document.items() - if k not in ["id", "tool_name", "text"] - }, - user_id=response_message.user_id, - conversation_id=response_message.conversation_id, - message_id=response_message.id, - ) - document_ids_to_document[document["id"]] = storage_document - - documents = list(document_ids_to_document.values()) - response_message.documents = documents - stream_end_data["documents"] = documents - if "search_results" not in event or event["search_results"] is None: - event["search_results"] = [] - stream_event = StreamSearchResults( - **event - | { - "documents": documents, - "search_results": event["search_results"], - }, - ) - elif event["event_type"] == StreamEvent.SEARCH_QUERIES_GENERATION: - search_queries = [] - for search_query in event["search_queries"]: - search_queries.append( - SearchQuery( - text=search_query.text, - generation_id=search_query.generation_id, - ) - ) - stream_event = StreamSearchQueriesGeneration( - **event | {"search_queries": search_queries} - ) - stream_end_data["search_queries"] = search_queries - elif event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION: - 
tool_calls = [] - for tool_call in event["tool_calls"]: - tool_calls.append( - ToolCall( - name=tool_call.name, - parameters=tool_call.parameters, - ) - ) - stream_event = StreamToolCallsGeneration( - **event | {"tool_calls": tool_calls} + stream_event, stream_end_data, response_message, document_ids_to_document = ( + handle_stream_event( + event, + conversation_id, + stream_end_data, + response_message, + document_ids_to_document, ) - stream_end_data["tool_calls"] = tool_calls - elif event["event_type"] == StreamEvent.CITATION_GENERATION: - citations = [] - for event_citation in event["citations"]: - citation = Citation( - text=event_citation.text, - user_id=response_message.user_id, - start=event_citation.start, - end=event_citation.end, - document_ids=event_citation.document_ids, - ) - for document_id in citation.document_ids: - document = document_ids_to_document.get(document_id, None) - if document is not None: - citation.documents.append(document) - citations.append(citation) - stream_event = StreamCitationGeneration(**event | {"citations": citations}) - all_citations.extend(citations) - elif event["event_type"] == StreamEvent.STREAM_END: - response_message.citations = all_citations - response_message.text = final_message_text - - stream_end_data["citations"] = all_citations - stream_end_data["text"] = final_message_text - stream_end = StreamEnd.model_validate(event | stream_end_data) - stream_event = stream_end + ) yield json.dumps( jsonable_encoder( @@ -508,103 +498,214 @@ def generate_chat_stream( if should_store: update_conversation_after_turn( - session, response_message, conversation_id, final_message_text, user_id + session, response_message, conversation_id, stream_end_data["text"], user_id ) -def generate_chat_response( - session: DBSessionDep, - model_deployment_response: Generator[StreamedChatResponse, None, None], +def handle_stream_event( + event: dict[str, Any], + conversation_id: str, + stream_end_data: dict[str, Any], response_message: Message, + document_ids_to_document: dict[str, Document] = {}, +) -> tuple[StreamEventType, dict[str, Any], Message, dict[str, Document]]: + handlers = { + StreamEvent.STREAM_START: handle_stream_start, + StreamEvent.TEXT_GENERATION: handle_stream_text_generation, + StreamEvent.SEARCH_RESULTS: handle_stream_search_results, + StreamEvent.SEARCH_QUERIES_GENERATION: handle_stream_search_queries_generation, + StreamEvent.TOOL_CALLS_GENERATION: handle_stream_tool_calls_generation, + StreamEvent.CITATION_GENERATION: handle_stream_citation_generation, + StreamEvent.TOOL_CALLS_CHUNK: handle_stream_tool_calls_chunk, + StreamEvent.STREAM_END: handle_stream_end, + } + event_type = event["event_type"] + + if event_type not in handlers.keys(): + logging.warning(f"Event type {event_type} not supported") + return None, stream_end_data, response_message, document_ids_to_document + + return handlers[event_type]( + event, + conversation_id, + stream_end_data, + response_message, + document_ids_to_document, + ) + + +def handle_stream_start( + event: dict[str, Any], conversation_id: str, - user_id: str, - should_store: bool = True, - **kwargs: Any, -) -> NonStreamedChatResponse: - """ - Generate chat response from model deployment non streaming response. 
+ stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamStart, dict[str, Any], Message, dict[str, Document]]: + event["conversation_id"] = conversation_id + stream_event = StreamStart.model_validate(event) + response_message.generation_id = event["generation_id"] + stream_end_data["generation_id"] = event["generation_id"] + return stream_event, stream_end_data, response_message, document_ids_to_document + + +def handle_stream_text_generation( + event: dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamTextGeneration, dict[str, Any], Message, dict[str, Document]]: + stream_end_data["text"] += event["text"] + stream_event = StreamTextGeneration.model_validate(event) + return stream_event, stream_end_data, response_message, document_ids_to_document - Args: - session (DBSessionDep): Database session. - model_deployment_response (Any): Model deployment response. - response_message (Message): Response message object. - conversation_id (str): Conversation ID. - user_id (str): User ID. - should_store (bool): Whether to store the conversation in the database. - **kwargs (Any): Additional keyword arguments. - Returns: - NonStreamedChatResponse: Chat response. - """ - model_deployment_response = next(model_deployment_response) - if not isinstance(model_deployment_response, dict): - response = model_deployment_response.__dict__ - else: - response = model_deployment_response - - chat_history = [] - for message in response.get("chat_history", []): - if not isinstance(message, dict): - message = message.__dict__ - - chat_history.append( - ChatMessage( - role=message.get("role", ChatRole.USER), - message=message.get("message"), - tool_results=message.get("tool_results"), - ) +def handle_stream_search_results( + event: dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamSearchResults, dict[str, Any], Message, dict[str, Document]]: + for document in event["documents"]: + storage_document = Document( + document_id=document.get("id", ""), + text=document.get("text", ""), + title=document.get("title", ""), + url=document.get("url", ""), + tool_name=document.get("tool_name", ""), + # all document fields except for id, tool_name and text + fields={ + k: v + for k, v in document.items() + if k not in ["id", "tool_name", "text"] + }, + user_id=response_message.user_id, + conversation_id=response_message.conversation_id, + message_id=response_message.id, ) + document_ids_to_document[document["id"]] = storage_document - documents = [] - if "documents" in response and response["documents"]: - documents = [ - Document( - document_id=document.get("id", ""), - text=document.get("text", ""), - title=document.get("title", ""), - url=document.get("url", ""), + documents = list(document_ids_to_document.values()) + response_message.documents = documents + + stream_end_data["documents"].extend(documents) + if "search_results" not in event or event["search_results"] is None: + event["search_results"] = [] + + stream_event = StreamSearchResults( + **event + | { + "documents": documents, + "search_results": event["search_results"], + }, + ) + stream_end_data["search_results"].extend(event["search_results"]) + return stream_event, stream_end_data, response_message, document_ids_to_document + + +def handle_stream_search_queries_generation( + event: 
dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamSearchQueriesGeneration, dict[str, Any], Message, dict[str, Document]]: + search_queries = [] + for search_query in event["search_queries"]: + search_queries.append( + SearchQuery( + text=search_query.get("text", ""), + generation_id=search_query.get("generation_id", ""), ) - for document in response.get("documents") - ] + ) + stream_event = StreamSearchQueriesGeneration( + **event | {"search_queries": search_queries} + ) + stream_end_data["search_queries"] = search_queries + return stream_event, stream_end_data, response_message, document_ids_to_document + +def handle_stream_tool_calls_generation( + event: dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamToolCallsGeneration, dict[str, Any], Message, dict[str, Document]]: tool_calls = [] - if "tool_calls" in response and response["tool_calls"]: - for tool_call in response.get("tool_calls", []): - tool_calls.append( - ToolCall( - name=tool_call.name, - parameters=tool_call.parameters, - ) + tool_calls_event = event.get("tool_calls", []) + for tool_call in tool_calls_event: + tool_calls.append( + ToolCall( + name=tool_call.get("name"), + parameters=tool_call.get("parameters"), ) + ) + stream_event = StreamToolCallsGeneration(**event | {"tool_calls": tool_calls}) + stream_end_data["tool_calls"].extend(tool_calls) + return stream_event, stream_end_data, response_message, document_ids_to_document - non_streamed_chat_response = NonStreamedChatResponse( - text=response.get("text", ""), - response_id=response.get("response_id", ""), - generation_id=response.get("generation_id", ""), - chat_history=chat_history, - finish_reason=response.get("finish_reason", ""), - citations=response.get("citations", []), - search_queries=response.get("search_queries", []), - documents=documents, - search_results=response.get("search_results", []), - event_type=StreamEvent.NON_STREAMED_CHAT_RESPONSE, - conversation_id=conversation_id, - tool_calls=tool_calls, - ) - - response_message.text = non_streamed_chat_response.text - response_message.generation_id = non_streamed_chat_response.generation_id - if should_store: - update_conversation_after_turn( - session, - response_message, - conversation_id, - non_streamed_chat_response.text, - user_id, +def handle_stream_citation_generation( + event: dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamCitationGeneration, dict[str, Any], Message, dict[str, Document]]: + citations = [] + for event_citation in event["citations"]: + citation = Citation( + text=event_citation.get("text"), + user_id=response_message.user_id, + start=event_citation.get("start"), + end=event_citation.get("end"), + document_ids=event_citation.get("document_ids"), ) + for document_id in citation.document_ids: + document = document_ids_to_document.get(document_id, None) + if document is not None: + citation.documents.append(document) + citations.append(citation) + stream_event = StreamCitationGeneration(**event | {"citations": citations}) + stream_end_data["citations"].extend(citations) + return stream_event, stream_end_data, response_message, document_ids_to_document + + +def handle_stream_tool_calls_chunk( + event: dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + 
response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamToolCallsChunk, dict[str, Any], Message, dict[str, Document]]: + event["text"] = event.get("text", "") + tool_call_delta = event.get("tool_call_delta", None) + if tool_call_delta: + tool_call = ToolCallDelta( + name=tool_call_delta.get("name"), + index=tool_call_delta.get("index"), + parameters=tool_call_delta.get("parameters"), + ) + event["tool_call_delta"] = tool_call - return non_streamed_chat_response + stream_event = StreamToolCallsChunk.model_validate(event) + return stream_event, stream_end_data, response_message, document_ids_to_document + + +def handle_stream_end( + event: dict[str, Any], + _: str, + stream_end_data: dict[str, Any], + response_message: Message, + document_ids_to_document: dict[str, Document], +) -> tuple[StreamEnd, dict[str, Any], Message, dict[str, Document]]: + response_message.citations = stream_end_data["citations"] + response_message.text = stream_end_data["text"] + stream_end = StreamEnd.model_validate(event | stream_end_data) + stream_event = stream_end + return stream_event, stream_end_data, response_message, document_ids_to_document def generate_langchain_chat_stream( diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_azure.py b/src/backend/tests/model_deployments/mock_deployments/mock_azure.py index d6bc9e6907..432f758d74 100644 --- a/src/backend/tests/model_deployments/mock_deployments/mock_azure.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_azure.py @@ -66,11 +66,13 @@ def invoke_chat_stream( }, { "event_type": StreamEvent.STREAM_END, - "generation_id": "test", - "citations": [], - "documents": [], - "search_results": [], - "search_queries": [], + "response": { + "generation_id": "test", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], + }, "finish_reason": "MAX_TOKENS", }, ] @@ -92,6 +94,28 @@ def invoke_rerank( ) -> Any: return None - def invoke_tools(self, message: str, tools: List[Any], **kwargs: Any) -> List[Any]: - # TODO: Add - pass + def invoke_tools( + self, message: str, tools: List[Any], **kwargs: Any + ) -> Generator[StreamedChatResponse, None, None]: + events = [ + { + "event_type": StreamEvent.STREAM_START, + "generation_id": "test", + }, + { + "event_type": StreamEvent.TEXT_GENERATION, + "text": "This is a test.", + }, + { + "event_type": StreamEvent.STREAM_END, + "generation_id": "test", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], + "finish_reason": "MAX_TOKENS", + }, + ] + + for event in events: + yield event diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py b/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py index a2522970e3..8e12266776 100644 --- a/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_bedrock.py @@ -64,11 +64,13 @@ def invoke_chat_stream( }, { "event_type": StreamEvent.STREAM_END, - "generation_id": "test", - "citations": [], - "documents": [], - "search_results": [], - "search_queries": [], + "response": { + "generation_id": "test", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], + }, "finish_reason": "MAX_TOKENS", }, ] diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py b/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py index 574a12577b..ef1cae372d 100644 --- 
a/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_cohere_platform.py @@ -64,11 +64,13 @@ def invoke_chat_stream( }, { "event_type": StreamEvent.STREAM_END, - "generation_id": "test", - "citations": [], - "documents": [], - "search_results": [], - "search_queries": [], + "response": { + "generation_id": "test", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], + }, "finish_reason": "MAX_TOKENS", }, ] diff --git a/src/backend/tests/model_deployments/mock_deployments/mock_sagemaker.py b/src/backend/tests/model_deployments/mock_deployments/mock_sagemaker.py index 528c1aba8e..8e614151d2 100644 --- a/src/backend/tests/model_deployments/mock_deployments/mock_sagemaker.py +++ b/src/backend/tests/model_deployments/mock_deployments/mock_sagemaker.py @@ -38,11 +38,13 @@ def invoke_chat_stream( }, { "event_type": StreamEvent.STREAM_END, - "generation_id": "test", - "citations": [], - "documents": [], - "search_results": [], - "search_queries": [], + "response": { + "generation_id": "test", + "citations": [], + "documents": [], + "search_results": [], + "search_queries": [], + }, "finish_reason": "MAX_TOKENS", }, ] diff --git a/src/backend/tests/model_deployments/test_azure.py b/src/backend/tests/model_deployments/test_azure.py index d31774bb0a..3c35978bf7 100644 --- a/src/backend/tests/model_deployments/test_azure.py +++ b/src/backend/tests/model_deployments/test_azure.py @@ -16,7 +16,6 @@ def test_streamed_chat( mock_available_model_deployments, ): deployment = mock_azure_deployment.return_value - deployment.invoke_chat_stream = MagicMock() response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -28,29 +27,6 @@ def test_streamed_chat( assert response.status_code == 200 assert type(deployment) is MockAzureDeployment - deployment.invoke_chat_stream.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) def test_non_streamed_chat( diff --git a/src/backend/tests/model_deployments/test_bedrock.py b/src/backend/tests/model_deployments/test_bedrock.py index 1650591d49..e9d1defa4d 100644 --- a/src/backend/tests/model_deployments/test_bedrock.py +++ b/src/backend/tests/model_deployments/test_bedrock.py @@ -15,7 +15,6 @@ def test_streamed_chat( mock_available_model_deployments, ): deployment = mock_bedrock_deployment.return_value - deployment.invoke_chat_stream = MagicMock() response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -27,29 +26,6 @@ def test_streamed_chat( assert response.status_code == 200 assert type(deployment) is MockBedrockDeployment - deployment.invoke_chat_stream.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) def test_non_streamed_chat( diff --git 
a/src/backend/tests/model_deployments/test_cohere_platform.py b/src/backend/tests/model_deployments/test_cohere_platform.py index db60c93c4f..53cf462a0e 100644 --- a/src/backend/tests/model_deployments/test_cohere_platform.py +++ b/src/backend/tests/model_deployments/test_cohere_platform.py @@ -16,7 +16,6 @@ def test_streamed_chat( mock_available_model_deployments, ): deployment = mock_cohere_deployment.return_value - deployment.invoke_chat_stream = MagicMock() response = session_client_chat.post( "/v1/chat-stream", headers={ @@ -28,29 +27,6 @@ def test_streamed_chat( assert response.status_code == 200 assert type(deployment) is MockCohereDeployment - deployment.invoke_chat_stream.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) def test_non_streamed_chat( diff --git a/src/backend/tests/model_deployments/test_sagemaker.py b/src/backend/tests/model_deployments/test_sagemaker.py index a2f53c7398..526ec59ee0 100644 --- a/src/backend/tests/model_deployments/test_sagemaker.py +++ b/src/backend/tests/model_deployments/test_sagemaker.py @@ -16,7 +16,6 @@ def test_streamed_chat( mock_available_model_deployments, ): deployment = mock_sagemaker_deployment.return_value - deployment.invoke_chat_stream = MagicMock() response = session_client_chat.post( "/v1/chat-stream", headers={"User-Id": user.id, "Deployment-Name": ModelDeploymentName.SageMaker}, @@ -25,29 +24,6 @@ def test_streamed_chat( assert response.status_code == 200 assert type(deployment) is MockSageMakerDeployment - deployment.invoke_chat_stream.assert_called_once_with( - CohereChatRequest( - message="Hello", - chat_history=[], - conversation_id="", - documents=[], - model="command-r", - temperature=None, - k=None, - p=None, - preamble=None, - file_ids=None, - tools=[], - search_queries_only=False, - deployment=None, - max_tokens=10, - seed=None, - stop_sequences=None, - presence_penalty=None, - frequency_penalty=None, - prompt_truncation="AUTO_PRESERVE_ORDER", - ) - ) @pytest.mark.skip("Non-streamed chat is not supported for SageMaker yet") diff --git a/src/backend/tests/tools/test_collate.py b/src/backend/tests/tools/test_collate.py index 6f0455cfe3..7e34e7e5ea 100644 --- a/src/backend/tests/tools/test_collate.py +++ b/src/backend/tests/tools/test_collate.py @@ -15,24 +15,61 @@ @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_rerank() -> None: model = CohereDeployment(model_config={}) + outputs = [ + { + "text": "Mount Everest is Earth's highest mountain above sea level, located in the Mahalangur Himal sub-range of the Himalayas" + }, + { + "text": "There are four components - or parts - of the blood: red blood cells, white blood cells, plasma and platelets." 
+ }, + { + "text": "'My Man Rocks Me (with One Steady Roll)' by Trixie Smith was issued in 1922, the first record to refer to 'rocking' and 'rolling' in a secular context" + }, + ] tool_results = [ { - "call": ToolCall(parameters={"query": "mountain"}, name="retriever"), - "outputs": [{"text": "hill"}, {"text": "goat"}, {"text": "cable"}], + "call": { + "parameters": {"query": "what is the highest mountain in the world?"}, + "name": "retriever", + }, + "outputs": outputs, + }, + { + "call": { + "parameters": {"query": "What are the 4 major components of blood?"}, + "name": "retriever", + }, + "outputs": outputs, }, { - "call": ToolCall(parameters={"query": "computer"}, name="retriever"), - "outputs": [{"text": "cable"}, {"text": "software"}, {"text": "penguin"}], + "call": { + "parameters": {"query": "When was 1st Olympics in history?"}, + "name": "retriever", + }, + "outputs": outputs, }, ] expected_output = [ { - "call": ToolCall(name="retriever", parameters={"query": "mountain"}), - "outputs": [], + "call": { + "parameters": {"query": "what is the highest mountain in the world?"}, + "name": "retriever", + }, + "outputs": [outputs[0]], + }, + { + "call": { + "parameters": {"query": "What are the 4 major components of blood?"}, + "name": "retriever", + }, + "outputs": [outputs[1]], }, { - "call": ToolCall(name="retriever", parameters={"query": "computer"}), + "call": { + "parameters": {"query": "When was 1st Olympics in history?"}, + "name": "retriever", + }, "outputs": [], }, ] From a6c04d36469d915f0b7848d278555bdc7755b2b1 Mon Sep 17 00:00:00 2001 From: Tianjing Li Date: Mon, 17 Jun 2024 11:28:35 -0400 Subject: [PATCH 3/4] Add logic to return Null on error_message if tool is available (#209) * Add logic to return Null on error_message if too lis available * lint --- src/backend/config/tools.py | 7 +++++-- src/backend/tests/routers/test_tool.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/src/backend/config/tools.py b/src/backend/config/tools.py index aadc569d99..e28592f884 100644 --- a/src/backend/config/tools.py +++ b/src/backend/config/tools.py @@ -152,17 +152,20 @@ def get_available_tools() -> dict[ToolName, dict]: key: value for key, value in ALL_TOOLS.items() if key in langchain_tools } + tools = ALL_TOOLS.copy() if use_community_tools: try: from community.config.tools import COMMUNITY_TOOLS tools = ALL_TOOLS.copy() tools.update(COMMUNITY_TOOLS) - return tools except ImportError: logging.warning("Community tools are not available. 
Skipping.") - return ALL_TOOLS + for tool in tools.values(): + tool.error_message = tool.error_message if not tool.is_available else None + + return tools AVAILABLE_TOOLS = get_available_tools() diff --git a/src/backend/tests/routers/test_tool.py b/src/backend/tests/routers/test_tool.py index f08582d710..ac6f181875 100644 --- a/src/backend/tests/routers/test_tool.py +++ b/src/backend/tests/routers/test_tool.py @@ -22,6 +22,14 @@ def test_list_tools(session_client: TestClient, session: Session) -> None: assert tool["description"] == tool_definition.description +def test_list_tools_error_message_none_if_available(client: TestClient) -> None: + response = client.get("/v1/tools") + assert response.status_code == 200 + for tool in response.json(): + if tool["is_available"]: + assert tool["error_message"] is None + + def test_list_tools_with_agent(session_client: TestClient, session: Session) -> None: agent = get_factory("Agent", session).create( name="test agent", tools=[ToolName.Wiki_Retriever_LangChain] From 99e2994444b3236215e5304e017e7ca1950a6aa3 Mon Sep 17 00:00:00 2001 From: Khalil Najjar Date: Mon, 17 Jun 2024 17:28:57 +0200 Subject: [PATCH 4/4] feat(assistants): Add assistant left side navigation (#220) * feat(assistant): add sidebar with Base Assistant * feat(assistant): add cohere colors picker * feat(assistant): AgentsSidePanel collapsed UI * docstring * code feedback --- .../src/components/Agents/AddAgentButton.tsx | 40 ------ .../src/components/Agents/AgentCard.tsx | 81 ++++++++++++ .../src/components/Agents/AgentsList.tsx | 43 ++++++ .../src/components/Agents/AgentsSidePanel.tsx | 109 ++++++++++++++++ .../src/components/Agents/BaseAgentButton.tsx | 16 --- .../src/components/Agents/Layout.tsx | 35 +---- .../src/components/Agents/LeftPanel.tsx | 14 -- .../src/components/Conversation/Header.tsx | 6 +- .../ConversationList/ConversationCard.tsx | 1 + .../coral_web/src/components/KebabMenu.tsx | 122 ++++++++---------- .../coral_web/src/pages/agents/index.tsx | 4 +- .../coral_web/src/pages/agents/new/index.tsx | 4 +- .../coral_web/src/stores/persistedStore.ts | 3 +- .../src/stores/slices/settingsSlice.ts | 11 ++ .../coral_web/src/utils/getCohereColor.ts | 24 ++++ 15 files changed, 343 insertions(+), 170 deletions(-) delete mode 100644 src/interfaces/coral_web/src/components/Agents/AddAgentButton.tsx create mode 100644 src/interfaces/coral_web/src/components/Agents/AgentCard.tsx create mode 100644 src/interfaces/coral_web/src/components/Agents/AgentsList.tsx create mode 100644 src/interfaces/coral_web/src/components/Agents/AgentsSidePanel.tsx delete mode 100644 src/interfaces/coral_web/src/components/Agents/BaseAgentButton.tsx delete mode 100644 src/interfaces/coral_web/src/components/Agents/LeftPanel.tsx create mode 100644 src/interfaces/coral_web/src/utils/getCohereColor.ts diff --git a/src/interfaces/coral_web/src/components/Agents/AddAgentButton.tsx b/src/interfaces/coral_web/src/components/Agents/AddAgentButton.tsx deleted file mode 100644 index 33f7b9713a..0000000000 --- a/src/interfaces/coral_web/src/components/Agents/AddAgentButton.tsx +++ /dev/null @@ -1,40 +0,0 @@ -import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react'; -import Link from 'next/link'; - -import { Icon, Text } from '@/components/Shared'; - -/** - * @description renders a button to add a new agent. - */ -export const AddAgentButton: React.FC = () => { - return ( - - -
- -
-
- - - - Create new agent - -
- - - Add an existing agent - - -
- ); -}; diff --git a/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx b/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx new file mode 100644 index 0000000000..0aa17b7fda --- /dev/null +++ b/src/interfaces/coral_web/src/components/Agents/AgentCard.tsx @@ -0,0 +1,81 @@ +import { Transition } from '@headlessui/react'; +import Link from 'next/link'; + +import { KebabMenu } from '@/components/KebabMenu'; +import { CoralLogo, Text, Tooltip } from '@/components/Shared'; +import { cn } from '@/utils'; +import { getCohereColor } from '@/utils/getCohereColor'; + +type Props = { + isExpanded: boolean; + name: string; + isBaseAgent?: boolean; + id?: string; +}; + +/** + * @description This component renders an agent card. + * It shows the agent's name and a colored icon with the first letter of the agent's name. + * If the agent is a base agent, it shows the Coral logo instead. + */ +export const AgentCard: React.FC = ({ name, id, isBaseAgent, isExpanded }) => { + return ( + + +
+ {isBaseAgent && } + {!isBaseAgent && ( + + {name[0]} + + )} +
+ + {name} + + + {}, + iconName: 'hide', + }, + ]} + /> + + +
+ ); +}; diff --git a/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx b/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx new file mode 100644 index 0000000000..530bc0c61f --- /dev/null +++ b/src/interfaces/coral_web/src/components/Agents/AgentsList.tsx @@ -0,0 +1,43 @@ +import { Transition } from '@headlessui/react'; + +import { AgentCard } from '@/components/Agents/AgentCard'; +import { Text } from '@/components/Shared'; +import { useSettingsStore } from '@/stores'; + +/** + * @description This component renders a list of agents. + * It shows the most recent agents and the base agents. + */ +export const AgentsList: React.FC = () => { + const { + settings: { isAgentsSidePanelOpen }, + } = useSettingsStore(); + + return ( +
+ + + Most recent + + + + + + +
+ ); +}; diff --git a/src/interfaces/coral_web/src/components/Agents/AgentsSidePanel.tsx b/src/interfaces/coral_web/src/components/Agents/AgentsSidePanel.tsx new file mode 100644 index 0000000000..2eefb0b4c4 --- /dev/null +++ b/src/interfaces/coral_web/src/components/Agents/AgentsSidePanel.tsx @@ -0,0 +1,109 @@ +import { Transition } from '@headlessui/react'; +import Link from 'next/link'; + +import IconButton from '@/components/IconButton'; +import { Button, Icon, IconProps, Logo, Tooltip } from '@/components/Shared'; +import { env } from '@/env.mjs'; +import { useSettingsStore } from '@/stores'; +import { cn } from '@/utils'; + +/** + * @description This component renders the agents side panel. + * It contains the logo and a button to expand or collapse the panel. + * It also renders the children components that are passed to it. + */ +export const AgentsSidePanel: React.FC = ({ children }) => { + const { + settings: { isAgentsSidePanelOpen }, + setIsAgentsSidePanelOpen, + } = useSettingsStore(); + + const navigationItems: { + label: string; + icon: IconProps['name']; + href?: string; + onClick?: () => void; + }[] = [ + { label: 'Create Assistant ', icon: 'add', href: '/agents/new' }, + { label: 'Sign Out', icon: 'profile', onClick: () => void 0 }, + ]; + + return ( +
+
+ {isAgentsSidePanelOpen && ( + +
+ +
+ + )} + setIsAgentsSidePanelOpen(!isAgentsSidePanelOpen)} + className={cn('transition delay-100 duration-200 ease-in-out', { + 'rotate-180 transform text-secondary-700': isAgentsSidePanelOpen, + })} + /> +
+
{children}
+ + {navigationItems.map(({ label, icon, href, onClick }) => ( + + + + ))} + + + {navigationItems.map(({ label, icon, href, onClick }) => ( +
+ ); +}; diff --git a/src/interfaces/coral_web/src/components/Agents/BaseAgentButton.tsx b/src/interfaces/coral_web/src/components/Agents/BaseAgentButton.tsx deleted file mode 100644 index 0b9f559175..0000000000 --- a/src/interfaces/coral_web/src/components/Agents/BaseAgentButton.tsx +++ /dev/null @@ -1,16 +0,0 @@ -import Link from 'next/link'; - -import { CoralLogo } from '@/components/Shared'; - -/** - * @description renders a button to navigate to the default knowledge agent page. - */ -export const BaseAgentButton: React.FC = () => { - return ( - -
- -
- - ); -}; diff --git a/src/interfaces/coral_web/src/components/Agents/Layout.tsx b/src/interfaces/coral_web/src/components/Agents/Layout.tsx index f55fa9f780..c8f2ee67ed 100644 --- a/src/interfaces/coral_web/src/components/Agents/Layout.tsx +++ b/src/interfaces/coral_web/src/components/Agents/Layout.tsx @@ -1,14 +1,10 @@ import { Transition } from '@headlessui/react'; import { capitalize } from 'lodash'; -import React, { Children, PropsWithChildren, useContext } from 'react'; +import React, { Children, PropsWithChildren } from 'react'; +import { AgentsSidePanel } from '@/components/Agents/AgentsSidePanel'; import { ConfigurationDrawer } from '@/components/Conversation/ConfigurationDrawer'; -import { DeploymentsDropdown } from '@/components/DeploymentsDropdown'; -import { EditEnvVariablesButton } from '@/components/EditEnvVariablesButton'; -import { Banner } from '@/components/Shared'; -import { NavigationBar } from '@/components/Shared/NavigationBar/NavigationBar'; import { PageHead } from '@/components/Shared/PageHead'; -import { BannerContext } from '@/context/BannerContext'; import { useIsDesktop } from '@/hooks/breakpoint'; import { useSettingsStore } from '@/stores'; import { cn } from '@/utils/cn'; @@ -21,14 +17,13 @@ type Props = { } & PropsWithChildren; /** - * This component is in charge of layout out the entire page. - * It shows the navigation bar, the left drawer and main content. - * On small devices (e.g. mobile), the left drawer and main section are stacked vertically. + * @description This component is in charge of layout out the entire page. + It shows the navigation bar, the left drawer and main content. + On small devices (e.g. mobile), the left drawer and main section are stacked vertically. */ export const Layout: React.FC = ({ title = 'Chat', children }) => { - const { message: bannerMessage } = useContext(BannerContext); const { - settings: { isConvListPanelOpen, isMobileConvListPanelOpen }, + settings: { isMobileConvListPanelOpen }, } = useSettingsStore(); const isDesktop = useIsDesktop(); @@ -52,24 +47,8 @@ export const Layout: React.FC = ({ title = 'Chat', children }) => { <>
- - - - - - - {bannerMessage && {bannerMessage}} -
-
- {leftElement} -
+ {leftElement} { - return ( -
- - -
- ); -}; diff --git a/src/interfaces/coral_web/src/components/Conversation/Header.tsx b/src/interfaces/coral_web/src/components/Conversation/Header.tsx index e72e0c197a..6738e431df 100644 --- a/src/interfaces/coral_web/src/components/Conversation/Header.tsx +++ b/src/interfaces/coral_web/src/components/Conversation/Header.tsx @@ -132,7 +132,11 @@ export const Header: React.FC = ({ conversationId, isStreaming }) => { - + = ({ isActive, conversation, flip {conversationLink}
= ({ items, className = '' }) => { - const [referenceElement, setReferenceElement] = useState(null); - const [popperElement, setPopperElement] = useState(null); - const { styles, attributes } = usePopper(referenceElement, popperElement, { - modifiers: [ - { - // Positions the menu relative to the kebab button - name: 'offset', - options: { - offset: [0, 4], - }, - }, - { - // Offsets the menu if it overflows and will be cutoff - name: 'preventOverflow', - options: { - padding: 16, - }, - }, - ], - }); - +export const KebabMenu: React.FC = ({ items, anchor, className = '' }) => { return ( {({ open }) => ( <> - e.stopPropagation()} > - - + = ({ items, className = '' }) => { 'transition-opacity ease-in-out', { hidden: !open } )} - ref={setPopperElement} - style={styles.popper} - {...attributes.popper} > {items.map( - ({ label, iconName, icon, onClick, className, visible = true }, index) => - visible && ( - - {!!iconName && ( - - )} - {!!icon && icon} - {label} - - ) + ({ label, iconName, icon, onClick, href, className, visible = true }, index) => { + return ( + visible && ( + { + e.stopPropagation(); + onClick?.(); + }} + {...(href ? { href, shallow: true } : {})} + > + {!!iconName && ( + + )} + {!!icon && icon} + {label} + + ) + ); + } )} diff --git a/src/interfaces/coral_web/src/pages/agents/index.tsx b/src/interfaces/coral_web/src/pages/agents/index.tsx index c7f20e9f64..bbdaf03186 100644 --- a/src/interfaces/coral_web/src/pages/agents/index.tsx +++ b/src/interfaces/coral_web/src/pages/agents/index.tsx @@ -3,8 +3,8 @@ import { GetServerSideProps, NextPage } from 'next'; import { useContext, useEffect } from 'react'; import { CohereClient } from '@/cohere-client'; +import { AgentsList } from '@/components/Agents/AgentsList'; import { Layout, LeftSection, MainSection } from '@/components/Agents/Layout'; -import { LeftPanel } from '@/components/Agents/LeftPanel'; import Conversation from '@/components/Conversation'; import { BannerContext } from '@/context/BannerContext'; import { useListAllDeployments } from '@/hooks/deployments'; @@ -53,7 +53,7 @@ const AgentsPage: NextPage = () => { return ( - + diff --git a/src/interfaces/coral_web/src/pages/agents/new/index.tsx b/src/interfaces/coral_web/src/pages/agents/new/index.tsx index cf2545a753..e87c89b3cf 100644 --- a/src/interfaces/coral_web/src/pages/agents/new/index.tsx +++ b/src/interfaces/coral_web/src/pages/agents/new/index.tsx @@ -2,8 +2,8 @@ import { QueryClient, dehydrate } from '@tanstack/react-query'; import { GetServerSideProps, NextPage } from 'next'; import { CohereClient } from '@/cohere-client'; +import { AgentsList } from '@/components/Agents/AgentsList'; import { Layout, LeftSection, MainSection } from '@/components/Agents/Layout'; -import { LeftPanel } from '@/components/Agents/LeftPanel'; import { appSSR } from '@/pages/_app'; type Props = {}; @@ -12,7 +12,7 @@ const AgentsNewPage: NextPage = () => { return ( - + Create a new agent diff --git a/src/interfaces/coral_web/src/stores/persistedStore.ts b/src/interfaces/coral_web/src/stores/persistedStore.ts index 5e10a73b51..196d55dea5 100644 --- a/src/interfaces/coral_web/src/stores/persistedStore.ts +++ b/src/interfaces/coral_web/src/stores/persistedStore.ts @@ -18,7 +18,7 @@ const usePersistedStore = create()( }), { name: 'settings', - version: 1, + version: 2, } ) ); @@ -48,6 +48,7 @@ export const useSettingsStore = () => { settings: state.settings, setSettings: state.setSettings, setIsConvListPanelOpen: state.setIsConvListPanelOpen, + setIsAgentsSidePanelOpen: 
state.setIsAgentsSidePanelOpen, }), shallow ); diff --git a/src/interfaces/coral_web/src/stores/slices/settingsSlice.ts b/src/interfaces/coral_web/src/stores/slices/settingsSlice.ts index 7e70fb95db..280f84724d 100644 --- a/src/interfaces/coral_web/src/stores/slices/settingsSlice.ts +++ b/src/interfaces/coral_web/src/stores/slices/settingsSlice.ts @@ -1,12 +1,14 @@ import { StateCreator } from 'zustand'; const INITIAL_STATE: Required = { + isAgentsSidePanelOpen: true, isConfigDrawerOpen: false, isConvListPanelOpen: true, isMobileConvListPanelOpen: false, }; type State = { + isAgentsSidePanelOpen: boolean; isConfigDrawerOpen: boolean; isConvListPanelOpen: boolean; isMobileConvListPanelOpen: boolean; @@ -15,6 +17,7 @@ type State = { type Actions = { setSettings: (settings: Partial) => void; setIsConvListPanelOpen: (isOpen: boolean) => void; + setIsAgentsSidePanelOpen: (isOpen: boolean) => void; }; export type SettingsStore = { @@ -30,6 +33,14 @@ export const createSettingsSlice: StateCreator ({ + settings: { + ...state.settings, + isAgentsSidePanelOpen: isOpen, + }, + })); + }, setIsConvListPanelOpen(isOpen) { set((state) => ({ settings: { diff --git a/src/interfaces/coral_web/src/utils/getCohereColor.ts b/src/interfaces/coral_web/src/utils/getCohereColor.ts new file mode 100644 index 0000000000..e0bb174ede --- /dev/null +++ b/src/interfaces/coral_web/src/utils/getCohereColor.ts @@ -0,0 +1,24 @@ +export const COLOR_LIST = [ + 'bg-quartz-500', + 'bg-green-500', + 'bg-primary-500', + 'bg-quartz-700', + 'bg-green-700', + 'bg-primary-700', +]; + +/** + * @description Get a color from the Cohere color palette, when no index is provided, a random color is returned + * @param id - id for generating a constant color in the palette + * @returns color from the Cohere color palette + */ +export const getCohereColor = (id?: string) => { + if (id === undefined) { + const randomIndex = Math.floor(Math.random() * COLOR_LIST.length); + + return COLOR_LIST[randomIndex]; + } + + const index = id.charCodeAt(0) % COLOR_LIST.length; + return COLOR_LIST[index]; +};
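
Throughout these patches the deployments stop yielding event.__dict__ and instead route every SDK stream event through to_dict from backend.chat.collate. That helper's implementation is not part of this diff; the sketch below is only an assumption about what such a normalizer typically does, namely flattening SDK or pydantic objects into plain, JSON-serializable dicts, which a bare __dict__ does not do for nested fields.

from typing import Any


def to_dict(obj: Any) -> Any:
    # Sketch only: the real helper lives in backend.chat.collate and may differ.
    # Goal: a plain, JSON-serializable structure for every stream event.
    if hasattr(obj, "model_dump"):  # pydantic v2 models
        return obj.model_dump()
    if isinstance(obj, dict):
        return {key: to_dict(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_dict(value) for value in obj]
    if hasattr(obj, "__dict__"):  # plain SDK objects
        return {key: to_dict(value) for key, value in vars(obj).items()}
    return obj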
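
The new StreamToolCallsChunk event and ToolCallDelta schema carry tool calls one fragment at a time (an optional name, an index, and a string-typed parameters field). How those fragments are meant to be reassembled is not spelled out in this patch; the following is a minimal consumer-side sketch under the assumption that each chunk identifies its tool call by index and that the parameters fragments for an index concatenate into a JSON object.

import json
from collections import defaultdict

from backend.schemas.tool import ToolCall, ToolCallDelta


def accumulate_tool_calls(deltas: list[ToolCallDelta]) -> list[ToolCall]:
    # Assumed contract: the first chunk for an index carries the name, and the
    # parameters fragments for that index concatenate into a JSON object.
    names: dict[int, str] = {}
    fragments: defaultdict[int, str] = defaultdict(str)

    for delta in deltas:
        if delta.index is None:
            continue
        if delta.name:
            names[delta.index] = delta.name
        if delta.parameters:
            fragments[delta.index] += delta.parameters

    return [
        ToolCall(name=names[index], parameters=json.loads(fragments[index] or "{}"))
        for index in sorted(names)
    ]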
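
The refactored generate_chat_response in src/backend/services/chat.py still accepts a streaming generator, but it drains the stream internally and returns only the STREAM_END payload as a NonStreamedChatResponse. A minimal usage sketch follows; the deployment variable and the surrounding call site are assumptions for illustration, while the generate_chat_response signature itself matches this patch.

from backend.schemas.chat import NonStreamedChatResponse
from backend.schemas.cohere_chat import CohereChatRequest
from backend.services.chat import generate_chat_response


def run_non_streamed_turn(
    session,            # DBSessionDep in the real routers
    deployment,         # any BaseDeployment implementation
    chat_request: CohereChatRequest,
    response_message,   # Message row created for this turn
    conversation_id: str,
    user_id: str,
) -> NonStreamedChatResponse:
    # The deployment still streams; generate_chat_response replays the stream,
    # stores intermediate state, and surfaces only the final STREAM_END data.
    stream = deployment.invoke_chat_stream(chat_request)
    return generate_chat_response(
        session,
        stream,
        response_message,
        conversation_id,
        user_id,
        should_store=True,
    )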