From c460ddea61023175347530d700021ed423e18b40 Mon Sep 17 00:00:00 2001 From: EugeneP Date: Wed, 16 Oct 2024 19:17:56 +0200 Subject: [PATCH] TLK-1771 - Improve agent creation flow --- src/backend/tests/unit/routers/test_agent.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/backend/tests/unit/routers/test_agent.py b/src/backend/tests/unit/routers/test_agent.py index ffe83c831b..b047318a82 100644 --- a/src/backend/tests/unit/routers/test_agent.py +++ b/src/backend/tests/unit/routers/test_agent.py @@ -1,4 +1,6 @@ +import os +import pytest from fastapi.testclient import TestClient from sqlalchemy.orm import Session @@ -6,12 +8,15 @@ from backend.config.tools import ToolName from backend.crud import agent as agent_crud from backend.crud import deployment as deployment_crud -from backend.crud import model as model_crud from backend.database_models.agent import Agent from backend.database_models.agent_tool_metadata import AgentToolMetadata from backend.database_models.snapshot import Snapshot from backend.tests.unit.factories import get_factory +is_cohere_env_set = ( + os.environ.get("COHERE_API_KEY") is not None + and os.environ.get("COHERE_API_KEY") != "" +) def test_create_agent_missing_name( session_client: TestClient, session: Session, user @@ -98,6 +103,7 @@ def test_create_agent_invalid_deployment( } +@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set") def test_create_agent_deployment_not_in_db( session_client: TestClient, session: Session, user ) -> None: @@ -110,18 +116,16 @@ def test_create_agent_deployment_not_in_db( "deployment": ModelDeploymentName.CoherePlatform, } cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) - assert cohere_deployment deployment_crud.delete_deployment(session, cohere_deployment.id) - cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) - assert not cohere_deployment response = session_client.post( "/v1/agents", json=request_json, headers={"User-Id": user.id} ) cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform) - model_command_r_plus = model_crud.get_model_by_name(session, "command-r-plus") + deployment_models = cohere_deployment.models + deployment_models_list = [model.name for model in deployment_models] assert response.status_code == 200 assert cohere_deployment - assert model_command_r_plus + assert "command-r-plus" in deployment_models_list def test_create_agent_invalid_tool(