Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve agent creation flow #814

Merged
merged 8 commits into from
Oct 22, 2024
14 changes: 14 additions & 0 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,18 @@ def get_default_deployment(**kwargs) -> BaseDeployment:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment | None:
    """
    Find a deployment config by its ID.

    Args:
        deployment_id (str): Deployment ID to look up.

    Returns:
        Deployment | None: First entry of AVAILABLE_MODEL_DEPLOYMENTS whose ``id``
        equals ``deployment_id``, or None when no entry matches.
    """
    return next(
        (
            deployment
            for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values()
            if deployment.id == deployment_id
        ),
        None,
    )


def find_config_by_deployment_name(deployment_name: str) -> Deployment | None:
    """
    Find a deployment config by its name.

    Args:
        deployment_name (str): Deployment name to look up.

    Returns:
        Deployment | None: First entry of AVAILABLE_MODEL_DEPLOYMENTS whose ``name``
        equals ``deployment_name``, or None when no entry matches.
    """
    return next(
        (
            deployment
            for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values()
            if deployment.name == deployment_name
        ),
        None,
    )


# Built once at import time by get_available_deployments(); the
# find_config_by_* helpers above scan its values.
AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
36 changes: 36 additions & 0 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import os

from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel, Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
)


@validate_transaction
Expand Down Expand Up @@ -184,3 +190,33 @@ def delete_deployment(db: Session, deployment_id: str) -> None:
deployment = db.query(Deployment).filter(Deployment.id == deployment_id)
deployment.delete()
db.commit()


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
    """
    Create a new deployment by config.

    Args:
        db (Session): Database session.
        deployment_config (DeploymentSchema): Deployment config.

    Returns:
        Deployment: Created deployment.
    """
    # Snapshot the current value of every env var the config declares;
    # variables that are not set are stored as "".
    default_config = {
        env_var: os.environ.get(env_var, "")
        for env_var in deployment_config.env_vars
    }
    deployment = Deployment(
        name=deployment_config.name,
        description="",
        default_deployment_config=default_config,
        deployment_class_name=deployment_config.deployment_class.__name__,
        is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS,
    )
    db.add(deployment)
    db.commit()
    # Refresh the instance from the database before handing it back.
    db.refresh(deployment)
    return deployment

49 changes: 48 additions & 1 deletion src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel
from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.transaction import validate_transaction

Expand Down Expand Up @@ -38,6 +39,20 @@ def get_model(db: Session, model_id: str) -> Model | None:
return db.query(Model).filter(Model.id == model_id).first()


def get_model_by_name(db: Session, model_name: str) -> Model | None:
    """
    Look up a single model by its name.

    Args:
        db (Session): Database session.
        model_name (str): Name to match against ``Model.name``.

    Returns:
        Model | None: First model whose name matches, or None if there is none.
    """
    name_matches = Model.name == model_name
    query = db.query(Model).filter(name_matches)
    return query.first()


def get_models(db: Session, offset: int = 0, limit: int = 100) -> list[Model]:
"""
List all models.
Expand Down Expand Up @@ -140,3 +155,35 @@ def get_models_by_agent_id(
.offset(offset)
.all()
)


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model | None:
    """
    Create the models listed in a deployment config that are missing from the database.

    Args:
        db (Session): Database session.
        deployment (Deployment): Deployment the models belong to.
        deployment_config (DeploymentSchema): Deployment config listing model names.
        model (str): Name of the model the caller is interested in.

    Returns:
        Model | None: The model named ``model`` if it was created by this call,
        otherwise None (including when a model with that name already existed).
    """
    # Collect the already-stored names once instead of rescanning the DB
    # records for every model in the config.
    known_names = {
        record.name for record in get_models_by_deployment_id(db, deployment.id)
    }
    model_to_return = None
    for config_model_name in deployment_config.models:
        if config_model_name in known_names:
            continue
        new_model = Model(
            name=config_model_name,
            cohere_name=config_model_name,
            deployment_id=deployment.id,
        )
        db.add(new_model)
        db.commit()
        # Remember the name so a repeated entry in the config is not inserted twice.
        known_names.add(config_model_name)
        if model == config_model_name:
            model_to_return = new_model

    return model_to_return
24 changes: 23 additions & 1 deletion src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from fastapi import HTTPException, Request

import backend.crud.user as user_crud
from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
find_config_by_deployment_id,
find_config_by_deployment_name,
)
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import agent as agent_crud
from backend.crud import conversation as conversation_crud
from backend.crud import deployment as deployment_crud
from backend.crud import model as model_crud
from backend.crud import organization as organization_crud
from backend.database_models.database import DBSessionDep
from backend.model_deployments.utils import class_name_validator
Expand All @@ -34,6 +39,19 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
deployment_db = deployment_crud.get_deployment_by_name(session, deployment)
if not deployment_db:
deployment_db = deployment_crud.get_deployment(session, deployment)

# Check deployment config settings availability
deployment_config = find_config_by_deployment_id(deployment)
if not deployment_config:
deployment_config = find_config_by_deployment_name(deployment)
if not deployment_config:
raise HTTPException(
status_code=400,
detail=f"Deployment {deployment} not found or is not available in the Database.",
)

if not deployment_db:
deployment_db = deployment_crud.create_deployment_by_config(session, deployment_config)
if not deployment_db:
raise HTTPException(
status_code=400,
Expand All @@ -48,6 +66,10 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
),
None,
)
if not deployment_model:
deployment_model = model_crud.create_model_by_config(
session, deployment_db, deployment_config, model
)
if not deployment_model:
raise HTTPException(
status_code=404,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/factories/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ class Meta:

deployment = factory.SubFactory(DeploymentFactory)
deployment_id = factory.SelfAttribute("deployment.id")
name = factory.Faker("name")
name = "command-r-plus"
cohere_name = factory.Faker("name")
description = factory.Faker("text")
32 changes: 32 additions & 0 deletions src/backend/tests/unit/routers/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import os

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import ToolName
from backend.crud import agent as agent_crud
from backend.crud import deployment as deployment_crud
from backend.database_models.agent import Agent
from backend.database_models.agent_tool_metadata import AgentToolMetadata
from backend.database_models.snapshot import Snapshot
from backend.tests.unit.factories import get_factory

# True only when COHERE_API_KEY is present in the environment and non-empty.
is_cohere_env_set = bool(os.environ.get("COHERE_API_KEY"))

def test_create_agent_missing_name(
session_client: TestClient, session: Session, user
Expand Down Expand Up @@ -96,6 +103,31 @@ def test_create_agent_invalid_deployment(
}


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_create_agent_deployment_not_in_db(
    session_client: TestClient, session: Session, user
) -> None:
    """Creating an agent succeeds and stores the deployment and model when the deployment row is missing."""
    request_json = {
        "name": "test agent",
        "description": "test description",
        "preamble": "test preamble",
        "temperature": 0.5,
        "model": "command-r-plus",
        "deployment": ModelDeploymentName.CoherePlatform,
    }
    # Make sure the deployment is absent so the request has to create it.
    cohere_deployment = deployment_crud.get_deployment_by_name(
        session, ModelDeploymentName.CoherePlatform
    )
    if cohere_deployment:
        deployment_crud.delete_deployment(session, cohere_deployment.id)

    response = session_client.post(
        "/v1/agents", json=request_json, headers={"User-Id": user.id}
    )
    assert response.status_code == 200

    cohere_deployment = deployment_crud.get_deployment_by_name(
        session, ModelDeploymentName.CoherePlatform
    )
    # Check existence before touching .models so a missing deployment fails
    # as an assertion rather than an AttributeError.
    assert cohere_deployment
    deployment_models_list = [model.name for model in cohere_deployment.models]
    assert "command-r-plus" in deployment_models_list


def test_create_agent_invalid_tool(
session_client: TestClient, session: Session, user
) -> None:
Expand Down
10 changes: 5 additions & 5 deletions src/backend/tests/unit/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,20 @@ def test_streaming_chat_with_tools_not_in_agent_tools(
},
json={
"message": "Who is a tallest nba player",
"tools": [{"name": "web_search"}],
"tools": [{"name": "tavily_web_search"}],
"agent_id": agent.id,
},
)

assert response.status_code == 200
validate_chat_streaming_tool_cals_response(response, ["web_search"])
validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"])


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_streaming_chat_with_agent_tools_and_empty_request_tools(
session_client_chat: TestClient, session_chat: Session, user: User
):
agent = get_factory("Agent", session_chat).create(user=user, tools=["web_search"])
agent = get_factory("Agent", session_chat).create(user=user, tools=["tavily_web_search"])
deployment = get_factory("Deployment", session_chat).create()
model = get_factory("Model", session_chat).create(deployment=deployment)
get_factory("AgentDeploymentModel", session_chat).create(
Expand All @@ -273,7 +273,7 @@ def test_streaming_chat_with_agent_tools_and_empty_request_tools(
)

assert response.status_code == 200
validate_chat_streaming_tool_cals_response(response, ["web_search"])
validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"])


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
Expand Down Expand Up @@ -369,7 +369,7 @@ def test_streaming_fail_chat_missing_message(
"loc": ["body", "message"],
"msg": "Field required",
"input": {},
"url": "https://errors.pydantic.dev/2.8/v/missing",
"url": "https://errors.pydantic.dev/2.9/v/missing",
}
]
}
Expand Down
Loading