From ea4a66066070b49667b69154589617b0218435f4 Mon Sep 17 00:00:00 2001 From: Eugene P <144219719+EugeneLightsOn@users.noreply.github.com> Date: Tue, 22 Oct 2024 20:48:48 +0200 Subject: [PATCH] Improve agent creation flow (#814) * TLK-1771 - Initial commit * TLK-1771 - Improve agent creation flow * TLK-1771 - Improve agent creation flow - lint * TLK-1771 - Improve agent creation flow - lint * TLK-1507 - AWS Copilot deployment to ECS - review fixes * TLK-1771 - Improve agent creation flow --- src/backend/config/deployments.py | 14 ++++++ src/backend/crud/deployment.py | 36 ++++++++++++++ src/backend/crud/model.py | 49 +++++++++++++++++++- src/backend/services/request_validators.py | 24 +++++++++- src/backend/tests/unit/factories/model.py | 2 +- src/backend/tests/unit/routers/test_agent.py | 32 +++++++++++++ src/backend/tests/unit/routers/test_chat.py | 10 ++-- 7 files changed, 159 insertions(+), 8 deletions(-) diff --git a/src/backend/config/deployments.py b/src/backend/config/deployments.py index 3070a3e924..2397ce9eff 100644 --- a/src/backend/config/deployments.py +++ b/src/backend/config/deployments.py @@ -123,4 +123,18 @@ def get_default_deployment(**kwargs) -> BaseDeployment: return fallback +def find_config_by_deployment_id(deployment_id: str) -> Deployment: + for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values(): + if deployment.id == deployment_id: + return deployment + return None + + +def find_config_by_deployment_name(deployment_name: str) -> Deployment: + for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values(): + if deployment.name == deployment_name: + return deployment + return None + + AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments() diff --git a/src/backend/crud/deployment.py b/src/backend/crud/deployment.py index 9de13a5fd0..6c2090291a 100644 --- a/src/backend/crud/deployment.py +++ b/src/backend/crud/deployment.py @@ -1,9 +1,15 @@ +import os + from sqlalchemy.orm import Session from backend.database_models import AgentDeploymentModel, Deployment from backend.model_deployments.utils import class_name_validator +from backend.schemas.deployment import Deployment as DeploymentSchema from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate from backend.services.transaction import validate_transaction +from community.config.deployments import ( + AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS, +) @validate_transaction @@ -184,3 +190,33 @@ def delete_deployment(db: Session, deployment_id: str) -> None: deployment = db.query(Deployment).filter(Deployment.id == deployment_id) deployment.delete() db.commit() + + +@validate_transaction +def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment: + """ + Create a new deployment by config. + + Args: + db (Session): Database session. + deployment (str): Deployment data to be created. + deployment_config (DeploymentSchema): Deployment config. + + Returns: + Deployment: Created deployment. + """ + deployment = Deployment( + name=deployment_config.name, + description="", + default_deployment_config= { + env_var: os.environ.get(env_var, "") + for env_var in deployment_config.env_vars + }, + deployment_class_name=deployment_config.deployment_class.__name__, + is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS + ) + db.add(deployment) + db.commit() + db.refresh(deployment) + return deployment + diff --git a/src/backend/crud/model.py b/src/backend/crud/model.py index 3a55d73de2..84122891a1 100644 --- a/src/backend/crud/model.py +++ b/src/backend/crud/model.py @@ -1,7 +1,8 @@ from sqlalchemy.orm import Session -from backend.database_models import AgentDeploymentModel +from backend.database_models import AgentDeploymentModel, Deployment from backend.database_models.model import Model +from backend.schemas.deployment import Deployment as DeploymentSchema from backend.schemas.model import ModelCreate, ModelUpdate from backend.services.transaction import validate_transaction @@ -38,6 +39,20 @@ def get_model(db: Session, model_id: str) -> Model | None: return db.query(Model).filter(Model.id == model_id).first() +def get_model_by_name(db: Session, model_name: str) -> Model | None: + """ + Get a model by name. + + Args: + db (Session): Database session. + model_name (str): Model name. + + Returns: + Model: Model with the given name. + """ + return db.query(Model).filter(Model.name == model_name).first() + + def get_models(db: Session, offset: int = 0, limit: int = 100) -> list[Model]: """ List all models. @@ -140,3 +155,35 @@ def get_models_by_agent_id( .offset(offset) .all() ) + + +def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model: + """ + Create a new model by config if present + + Args: + db (Session): Database session. + deployment (Deployment): Deployment data. + deployment_config (DeploymentSchema): Deployment config data. + model (str): Model data. + + Returns: + Model: Created model. + """ + deployment_config_models = deployment_config.models + deployment_db_models = get_models_by_deployment_id(db, deployment.id) + model_to_return = None + for deployment_config_model in deployment_config_models: + model_in_db = any(record.name == deployment_config_model for record in deployment_db_models) + if not model_in_db: + new_model = Model( + name=deployment_config_model, + cohere_name=deployment_config_model, + deployment_id=deployment.id, + ) + db.add(new_model) + db.commit() + if model == new_model.name: + model_to_return = new_model + + return model_to_return diff --git a/src/backend/services/request_validators.py b/src/backend/services/request_validators.py index 20d22ffa9c..badb2b4369 100644 --- a/src/backend/services/request_validators.py +++ b/src/backend/services/request_validators.py @@ -3,11 +3,16 @@ from fastapi import HTTPException, Request import backend.crud.user as user_crud -from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS +from backend.config.deployments import ( + AVAILABLE_MODEL_DEPLOYMENTS, + find_config_by_deployment_id, + find_config_by_deployment_name, +) from backend.config.tools import AVAILABLE_TOOLS from backend.crud import agent as agent_crud from backend.crud import conversation as conversation_crud from backend.crud import deployment as deployment_crud +from backend.crud import model as model_crud from backend.crud import organization as organization_crud from backend.database_models.database import DBSessionDep from backend.model_deployments.utils import class_name_validator @@ -34,6 +39,19 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep deployment_db = deployment_crud.get_deployment_by_name(session, deployment) if not deployment_db: deployment_db = deployment_crud.get_deployment(session, deployment) + + # Check deployment config settings availability + deployment_config = find_config_by_deployment_id(deployment) + if not deployment_config: + deployment_config = find_config_by_deployment_name(deployment) + if not deployment_config: + raise HTTPException( + status_code=400, + detail=f"Deployment {deployment} not found or is not available in the Database.", + ) + + if not deployment_db: + deployment_db = deployment_crud.create_deployment_by_config(session, deployment_config) if not deployment_db: raise HTTPException( status_code=400, @@ -48,6 +66,10 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep ), None, ) + if not deployment_model: + deployment_model = model_crud.create_model_by_config( + session, deployment_db, deployment_config, model + ) if not deployment_model: raise HTTPException( status_code=404, diff --git a/src/backend/tests/unit/factories/model.py b/src/backend/tests/unit/factories/model.py index 7050475a99..f359a253fb 100644 --- a/src/backend/tests/unit/factories/model.py +++ b/src/backend/tests/unit/factories/model.py @@ -11,6 +11,6 @@ class Meta: deployment = factory.SubFactory(DeploymentFactory) deployment_id = factory.SelfAttribute("deployment.id") - name = factory.Faker("name") + name = "command-r-plus" cohere_name = factory.Faker("name") description = factory.Faker("text") diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index 0e4645b42a..b047318a82 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -1,15 +1,22 @@ +import os +import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session from backend.config.deployments import ModelDeploymentName from backend.config.tools import ToolName from backend.crud import agent as agent_crud +from backend.crud import deployment as deployment_crud from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.database_models.snapshot import Snapshot from backend.tests.unit.factories import get_factory +is_cohere_env_set = ( + os.environ.get("COHERE_API_KEY") is not None + and os.environ.get("COHERE_API_KEY") != "" +) def test_create_agent_missing_name( session_client: TestClient, session: Session, user @@ -96,6 +103,31 @@ def test_create_agent_invalid_deployment( } +@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") +def test_create_agent_deployment_not_in_db( + session_client: TestClient, session: Session, user +) -> None: + request_json = { + "name": "test agent", + "description": "test description", + "preamble": "test preamble", + "temperature": 0.5, + "model": "command-r-plus", + "deployment": ModelDeploymentName.CoherePlatform, + } + cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) + deployment_crud.delete_deployment(session, cohere_deployment.id) + response = session_client.post( + "/v1/agents", json=request_json, headers={"User-Id": user.id} + ) + cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) + deployment_models = cohere_deployment.models + deployment_models_list = [model.name for model in deployment_models] + assert response.status_code == 200 + assert cohere_deployment + assert "command-r-plus" in deployment_models_list + + def test_create_agent_invalid_tool( session_client: TestClient, session: Session, user ) -> None: diff --git a/src/backend/tests/unit/routers/test_chat.py b/src/backend/tests/unit/routers/test_chat.py index 6c61dc0080..7e8d06ea2e 100644 --- a/src/backend/tests/unit/routers/test_chat.py +++ b/src/backend/tests/unit/routers/test_chat.py @@ -236,20 +236,20 @@ def test_streaming_chat_with_tools_not_in_agent_tools( }, json={ "message": "Who is a tallest nba player", - "tools": [{"name": "web_search"}], + "tools": [{"name": "tavily_web_search"}], "agent_id": agent.id, }, ) assert response.status_code == 200 - validate_chat_streaming_tool_cals_response(response, ["web_search"]) + validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"]) @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_streaming_chat_with_agent_tools_and_empty_request_tools( session_client_chat: TestClient, session_chat: Session, user: User ): - agent = get_factory("Agent", session_chat).create(user=user, tools=["web_search"]) + agent = get_factory("Agent", session_chat).create(user=user, tools=["tavily_web_search"]) deployment = get_factory("Deployment", session_chat).create() model = get_factory("Model", session_chat).create(deployment=deployment) get_factory("AgentDeploymentModel", session_chat).create( @@ -273,7 +273,7 @@ def test_streaming_chat_with_agent_tools_and_empty_request_tools( ) assert response.status_code == 200 - validate_chat_streaming_tool_cals_response(response, ["web_search"]) + validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"]) @pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") @@ -369,7 +369,7 @@ def test_streaming_fail_chat_missing_message( "loc": ["body", "message"], "msg": "Field required", "input": {}, - "url": "https://errors.pydantic.dev/2.8/v/missing", + "url": "https://errors.pydantic.dev/2.9/v/missing", } ] }