Skip to content

Commit

Permalink
Improve agent creation flow (#814)
Browse files Browse the repository at this point in the history
* TLK-1771 - Initial commit

* TLK-1771 - Improve agent creation flow

* TLK-1771 - Improve agent creation flow - lint

* TLK-1771 - Improve agent creation flow - lint

* TLK-1507 - AWS Copilot deployment to ECS - review fixes

* TLK-1771 - Improve agent creation flow
  • Loading branch information
EugeneLightsOn authored Oct 22, 2024
1 parent 2900eb0 commit ea4a660
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 8 deletions.
14 changes: 14 additions & 0 deletions src/backend/config/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,18 @@ def get_default_deployment(**kwargs) -> BaseDeployment:
return fallback


def find_config_by_deployment_id(deployment_id: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.id == deployment_id:
return deployment
return None


def find_config_by_deployment_name(deployment_name: str) -> Deployment:
for deployment in AVAILABLE_MODEL_DEPLOYMENTS.values():
if deployment.name == deployment_name:
return deployment
return None


AVAILABLE_MODEL_DEPLOYMENTS = get_available_deployments()
36 changes: 36 additions & 0 deletions src/backend/crud/deployment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import os

from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel, Deployment
from backend.model_deployments.utils import class_name_validator
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.deployment import DeploymentCreate, DeploymentUpdate
from backend.services.transaction import validate_transaction
from community.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS as COMMUNITY_DEPLOYMENTS,
)


@validate_transaction
Expand Down Expand Up @@ -184,3 +190,33 @@ def delete_deployment(db: Session, deployment_id: str) -> None:
deployment = db.query(Deployment).filter(Deployment.id == deployment_id)
deployment.delete()
db.commit()


@validate_transaction
def create_deployment_by_config(db: Session, deployment_config: DeploymentSchema) -> Deployment:
"""
Create a new deployment by config.
Args:
db (Session): Database session.
deployment (str): Deployment data to be created.
deployment_config (DeploymentSchema): Deployment config.
Returns:
Deployment: Created deployment.
"""
deployment = Deployment(
name=deployment_config.name,
description="",
default_deployment_config= {
env_var: os.environ.get(env_var, "")
for env_var in deployment_config.env_vars
},
deployment_class_name=deployment_config.deployment_class.__name__,
is_community=deployment_config.name in COMMUNITY_DEPLOYMENTS
)
db.add(deployment)
db.commit()
db.refresh(deployment)
return deployment

49 changes: 48 additions & 1 deletion src/backend/crud/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from sqlalchemy.orm import Session

from backend.database_models import AgentDeploymentModel
from backend.database_models import AgentDeploymentModel, Deployment
from backend.database_models.model import Model
from backend.schemas.deployment import Deployment as DeploymentSchema
from backend.schemas.model import ModelCreate, ModelUpdate
from backend.services.transaction import validate_transaction

Expand Down Expand Up @@ -38,6 +39,20 @@ def get_model(db: Session, model_id: str) -> Model | None:
return db.query(Model).filter(Model.id == model_id).first()


def get_model_by_name(db: Session, model_name: str) -> Model | None:
"""
Get a model by name.
Args:
db (Session): Database session.
model_name (str): Model name.
Returns:
Model: Model with the given name.
"""
return db.query(Model).filter(Model.name == model_name).first()


def get_models(db: Session, offset: int = 0, limit: int = 100) -> list[Model]:
"""
List all models.
Expand Down Expand Up @@ -140,3 +155,35 @@ def get_models_by_agent_id(
.offset(offset)
.all()
)


def create_model_by_config(db: Session, deployment: Deployment, deployment_config: DeploymentSchema, model: str) -> Model:
"""
Create a new model by config if present
Args:
db (Session): Database session.
deployment (Deployment): Deployment data.
deployment_config (DeploymentSchema): Deployment config data.
model (str): Model data.
Returns:
Model: Created model.
"""
deployment_config_models = deployment_config.models
deployment_db_models = get_models_by_deployment_id(db, deployment.id)
model_to_return = None
for deployment_config_model in deployment_config_models:
model_in_db = any(record.name == deployment_config_model for record in deployment_db_models)
if not model_in_db:
new_model = Model(
name=deployment_config_model,
cohere_name=deployment_config_model,
deployment_id=deployment.id,
)
db.add(new_model)
db.commit()
if model == new_model.name:
model_to_return = new_model

return model_to_return
24 changes: 23 additions & 1 deletion src/backend/services/request_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from fastapi import HTTPException, Request

import backend.crud.user as user_crud
from backend.config.deployments import AVAILABLE_MODEL_DEPLOYMENTS
from backend.config.deployments import (
AVAILABLE_MODEL_DEPLOYMENTS,
find_config_by_deployment_id,
find_config_by_deployment_name,
)
from backend.config.tools import AVAILABLE_TOOLS
from backend.crud import agent as agent_crud
from backend.crud import conversation as conversation_crud
from backend.crud import deployment as deployment_crud
from backend.crud import model as model_crud
from backend.crud import organization as organization_crud
from backend.database_models.database import DBSessionDep
from backend.model_deployments.utils import class_name_validator
Expand All @@ -34,6 +39,19 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
deployment_db = deployment_crud.get_deployment_by_name(session, deployment)
if not deployment_db:
deployment_db = deployment_crud.get_deployment(session, deployment)

# Check deployment config settings availability
deployment_config = find_config_by_deployment_id(deployment)
if not deployment_config:
deployment_config = find_config_by_deployment_name(deployment)
if not deployment_config:
raise HTTPException(
status_code=400,
detail=f"Deployment {deployment} not found or is not available in the Database.",
)

if not deployment_db:
deployment_db = deployment_crud.create_deployment_by_config(session, deployment_config)
if not deployment_db:
raise HTTPException(
status_code=400,
Expand All @@ -48,6 +66,10 @@ def validate_deployment_model(deployment: str, model: str, session: DBSessionDep
),
None,
)
if not deployment_model:
deployment_model = model_crud.create_model_by_config(
session, deployment_db, deployment_config, model
)
if not deployment_model:
raise HTTPException(
status_code=404,
Expand Down
2 changes: 1 addition & 1 deletion src/backend/tests/unit/factories/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ class Meta:

deployment = factory.SubFactory(DeploymentFactory)
deployment_id = factory.SelfAttribute("deployment.id")
name = factory.Faker("name")
name = "command-r-plus"
cohere_name = factory.Faker("name")
description = factory.Faker("text")
32 changes: 32 additions & 0 deletions src/backend/tests/unit/routers/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,22 @@
import os

import pytest
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.config.deployments import ModelDeploymentName
from backend.config.tools import ToolName
from backend.crud import agent as agent_crud
from backend.crud import deployment as deployment_crud
from backend.database_models.agent import Agent
from backend.database_models.agent_tool_metadata import AgentToolMetadata
from backend.database_models.snapshot import Snapshot
from backend.tests.unit.factories import get_factory

is_cohere_env_set = (
os.environ.get("COHERE_API_KEY") is not None
and os.environ.get("COHERE_API_KEY") != ""
)

def test_create_agent_missing_name(
session_client: TestClient, session: Session, user
Expand Down Expand Up @@ -96,6 +103,31 @@ def test_create_agent_invalid_deployment(
}


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_create_agent_deployment_not_in_db(
session_client: TestClient, session: Session, user
) -> None:
request_json = {
"name": "test agent",
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"model": "command-r-plus",
"deployment": ModelDeploymentName.CoherePlatform,
}
cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
deployment_crud.delete_deployment(session, cohere_deployment.id)
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": user.id}
)
cohere_deployment = deployment_crud.get_deployment_by_name(session, ModelDeploymentName.CoherePlatform)
deployment_models = cohere_deployment.models
deployment_models_list = [model.name for model in deployment_models]
assert response.status_code == 200
assert cohere_deployment
assert "command-r-plus" in deployment_models_list


def test_create_agent_invalid_tool(
session_client: TestClient, session: Session, user
) -> None:
Expand Down
10 changes: 5 additions & 5 deletions src/backend/tests/unit/routers/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,20 +236,20 @@ def test_streaming_chat_with_tools_not_in_agent_tools(
},
json={
"message": "Who is a tallest nba player",
"tools": [{"name": "web_search"}],
"tools": [{"name": "tavily_web_search"}],
"agent_id": agent.id,
},
)

assert response.status_code == 200
validate_chat_streaming_tool_cals_response(response, ["web_search"])
validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"])


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
def test_streaming_chat_with_agent_tools_and_empty_request_tools(
session_client_chat: TestClient, session_chat: Session, user: User
):
agent = get_factory("Agent", session_chat).create(user=user, tools=["web_search"])
agent = get_factory("Agent", session_chat).create(user=user, tools=["tavily_web_search"])
deployment = get_factory("Deployment", session_chat).create()
model = get_factory("Model", session_chat).create(deployment=deployment)
get_factory("AgentDeploymentModel", session_chat).create(
Expand All @@ -273,7 +273,7 @@ def test_streaming_chat_with_agent_tools_and_empty_request_tools(
)

assert response.status_code == 200
validate_chat_streaming_tool_cals_response(response, ["web_search"])
validate_chat_streaming_tool_cals_response(response, ["tavily_web_search"])


@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
Expand Down Expand Up @@ -369,7 +369,7 @@ def test_streaming_fail_chat_missing_message(
"loc": ["body", "message"],
"msg": "Field required",
"input": {},
"url": "https://errors.pydantic.dev/2.8/v/missing",
"url": "https://errors.pydantic.dev/2.9/v/missing",
}
]
}
Expand Down

0 comments on commit ea4a660

Please sign in to comment.