Skip to content

Commit

Permalink
[coral-web] rename schema fields and update openapi client (cohere-ai…
Browse files Browse the repository at this point in the history
…#189)

* update

* update client
  • Loading branch information
scott-cohere authored and ClaytonSmith committed Jun 11, 2024
1 parent c468520 commit a897f73
Show file tree
Hide file tree
Showing 28 changed files with 375 additions and 100 deletions.
12 changes: 7 additions & 5 deletions src/backend/database_models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
from backend.database_models.base import Base


class Deployment(StrEnum):
class AgentDeployment(StrEnum):
COHERE_PLATFORM = "Cohere Platform"
SAGE_MAKER = "SageMaker"
AZURE = "Azure"
BEDROCK = "Bedrock"


class Model(StrEnum):
class AgentModel(StrEnum):
COMMAND_R = "command-r"
COMMAND_R_PLUS = "command-r-plus"
COMMAND_LIGHT = "command-light"
Expand All @@ -33,9 +33,11 @@ class Agent(Base):
# TODO @scott-cohere: eventually switch to Fkey when new deployment tables are implemented
# TODO @scott-cohere: deployments have different names for models, need to implement mapping later
# enum place holders
model: Mapped[Model] = mapped_column(Enum(Model, native_enum=False), nullable=False)
deployment: Mapped[Deployment] = mapped_column(
Enum(Deployment, native_enum=False), nullable=False
model: Mapped[AgentModel] = mapped_column(
Enum(AgentModel, native_enum=False), nullable=False
)
deployment: Mapped[AgentDeployment] = mapped_column(
Enum(AgentDeployment, native_enum=False), nullable=False
)

user_id: Mapped[str] = mapped_column(Text, nullable=False)
Expand Down
14 changes: 7 additions & 7 deletions src/backend/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pydantic import BaseModel

from backend.database_models.agent import Deployment, Model
from backend.database_models.agent import AgentDeployment, AgentModel


class AgentBase(BaseModel):
Expand All @@ -22,8 +22,8 @@ class Agent(AgentBase):
temperature: float
# tools: List[Tool]

model: Model
deployment: Deployment
model: AgentModel
deployment: AgentDeployment

class Config:
from_attributes = True
Expand All @@ -36,8 +36,8 @@ class CreateAgent(BaseModel):
description: Optional[str] = None
preamble: Optional[str] = None
temperature: Optional[float] = None
model: Model
deployment: Deployment
model: AgentModel
deployment: AgentDeployment

class Config:
from_attributes = True
Expand All @@ -50,8 +50,8 @@ class UpdateAgent(BaseModel):
description: Optional[str] = None
preamble: Optional[str] = None
temperature: Optional[float] = None
model: Optional[Model] = None
deployment: Optional[Deployment] = None
model: Optional[AgentModel] = None
deployment: Optional[AgentDeployment] = None
# tools: Optional[List[Tool]] = None

class Config:
Expand Down
42 changes: 21 additions & 21 deletions src/backend/tests/crud/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy.exc import IntegrityError

from backend.crud import agent as agent_crud
from backend.database_models.agent import Agent, Deployment, Model
from backend.database_models.agent import Agent, AgentDeployment, AgentModel
from backend.schemas.agent import UpdateAgent
from backend.tests.factories import get_factory

Expand All @@ -17,8 +17,8 @@ def test_create_agent(session, user):
description="test",
preamble="test",
temperature=0.5,
model=Model.COMMAND_R_PLUS,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
)

agent = agent_crud.create_agent(session, agent_data)
Expand All @@ -28,8 +28,8 @@ def test_create_agent(session, user):
assert agent.description == "test"
assert agent.preamble == "test"
assert agent.temperature == 0.5
assert agent.model == Model.COMMAND_R_PLUS
assert agent.deployment == Deployment.COHERE_PLATFORM
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM

agent = agent_crud.get_agent(session, agent.id)
assert agent.user_id == user.id
Expand All @@ -38,16 +38,16 @@ def test_create_agent(session, user):
assert agent.description == "test"
assert agent.preamble == "test"
assert agent.temperature == 0.5
assert agent.model == Model.COMMAND_R_PLUS
assert agent.deployment == Deployment.COHERE_PLATFORM
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM


def test_create_agent_empty_non_required_fields(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
deployment=Deployment.COHERE_PLATFORM,
model=Model.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
)

agent = agent_crud.create_agent(session, agent_data)
Expand All @@ -57,8 +57,8 @@ def test_create_agent_empty_non_required_fields(session, user):
assert agent.description == ""
assert agent.preamble == ""
assert agent.temperature == 0.3
assert agent.model == Model.COMMAND_R_PLUS
assert agent.deployment == Deployment.COHERE_PLATFORM
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM

agent = agent_crud.get_agent(session, agent.id)
assert agent.user_id == user.id
Expand All @@ -67,15 +67,15 @@ def test_create_agent_empty_non_required_fields(session, user):
assert agent.description == ""
assert agent.preamble == ""
assert agent.temperature == 0.3
assert agent.model == Model.COMMAND_R_PLUS
assert agent.deployment == Deployment.COHERE_PLATFORM
assert agent.model == AgentModel.COMMAND_R_PLUS
assert agent.deployment == AgentDeployment.COHERE_PLATFORM


def test_create_agent_missing_name(session, user):
agent_data = Agent(
user_id=user.id,
model=Model.COMMAND_R_PLUS,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
)

with pytest.raises(IntegrityError):
Expand All @@ -86,7 +86,7 @@ def test_create_agent_missing_model(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
deployment=Deployment.COHERE_PLATFORM,
deployment=AgentDeployment.COHERE_PLATFORM,
)

with pytest.raises(IntegrityError):
Expand All @@ -97,7 +97,7 @@ def test_create_agent_missing_deployment(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
model=Model.COMMAND_R_PLUS,
model=AgentModel.COMMAND_R_PLUS,
)

with pytest.raises(IntegrityError):
Expand All @@ -107,8 +107,8 @@ def test_create_agent_missing_deployment(session, user):
def test_create_agent_missing_user_id(session):
agent_data = Agent(
name="test",
model=Model.COMMAND_R_PLUS,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
)

with pytest.raises(IntegrityError):
Expand All @@ -127,8 +127,8 @@ def test_create_agent_duplicate_name_version(session, user):
description="test",
preamble="test",
temperature=0.5,
model=Model.COMMAND_R_PLUS,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
deployment=AgentDeployment.COHERE_PLATFORM,
)

with pytest.raises(IntegrityError):
Expand Down
18 changes: 9 additions & 9 deletions src/backend/tests/factories/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import factory

from backend.database_models.agent import Agent, Deployment, Model
from backend.database_models.agent import Agent, AgentDeployment, AgentModel

from .base import BaseFactory

Expand All @@ -20,18 +20,18 @@ class Meta:
model = factory.Faker(
"random_element",
elements=(
Model.COMMAND_R,
Model.COMMAND_R_PLUS,
Model.COMMAND_LIGHT,
Model.COMMAND,
AgentModel.COMMAND_R,
AgentModel.COMMAND_R_PLUS,
AgentModel.COMMAND_LIGHT,
AgentModel.COMMAND,
),
)
deployment = factory.Faker(
"random_element",
elements=(
Deployment.COHERE_PLATFORM,
Deployment.SAGE_MAKER,
Deployment.AZURE,
Deployment.BEDROCK,
AgentDeployment.COHERE_PLATFORM,
AgentDeployment.SAGE_MAKER,
AgentDeployment.AZURE,
AgentDeployment.BEDROCK,
),
)
46 changes: 23 additions & 23 deletions src/backend/tests/routers/test_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from fastapi.testclient import TestClient
from sqlalchemy.orm import Session

from backend.database_models.agent import Agent, Deployment, Model
from backend.database_models.agent import Agent, AgentDeployment, AgentModel
from backend.tests.factories import get_factory


Expand All @@ -12,8 +12,8 @@ def test_create_agent(session_client: TestClient, session: Session) -> None:
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"model": Model.COMMAND_R,
"deployment": Deployment.COHERE_PLATFORM,
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
}

response = session_client.post(
Expand Down Expand Up @@ -48,8 +48,8 @@ def test_create_agent_missing_name(
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"model": Model.COMMAND_R,
"deployment": Deployment.COHERE_PLATFORM,
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
}
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": "123"}
Expand All @@ -65,7 +65,7 @@ def test_create_agent_missing_model(
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"deployment": Deployment.COHERE_PLATFORM,
"deployment": AgentDeployment.COHERE_PLATFORM,
}
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": "123"}
Expand All @@ -81,7 +81,7 @@ def test_create_agent_missing_deployment(
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"model": Model.COMMAND_R,
"model": AgentModel.COMMAND_R,
}
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": "123"}
Expand All @@ -94,8 +94,8 @@ def test_create_agent_missing_user_id_header(
) -> None:
request_json = {
"name": "test agent",
"model": Model.COMMAND_R,
"deployment": Deployment.COHERE_PLATFORM,
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
}
response = session_client.post("/v1/agents", json=request_json)
assert response.status_code == 401
Expand All @@ -106,8 +106,8 @@ def test_create_agent_missing_non_required_fields(
) -> None:
request_json = {
"name": "test agent",
"model": Model.COMMAND_R,
"deployment": Deployment.COHERE_PLATFORM,
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
}

print(request_json)
Expand Down Expand Up @@ -216,8 +216,8 @@ def test_update_agent(session_client: TestClient, session: Session) -> None:
description="test description",
preamble="test preamble",
temperature=0.5,
model=Model.COMMAND_R,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R,
deployment=AgentDeployment.COHERE_PLATFORM,
)

request_json = {
Expand All @@ -226,8 +226,8 @@ def test_update_agent(session_client: TestClient, session: Session) -> None:
"description": "updated description",
"preamble": "updated preamble",
"temperature": 0.7,
"model": Model.COMMAND_R_PLUS,
"deployment": Deployment.SAGE_MAKER,
"model": AgentModel.COMMAND_R_PLUS,
"deployment": AgentDeployment.SAGE_MAKER,
}

response = session_client.put(
Expand All @@ -240,8 +240,8 @@ def test_update_agent(session_client: TestClient, session: Session) -> None:
assert updated_agent["description"] == "updated description"
assert updated_agent["preamble"] == "updated preamble"
assert updated_agent["temperature"] == 0.7
assert updated_agent["model"] == Model.COMMAND_R_PLUS
assert updated_agent["deployment"] == Deployment.SAGE_MAKER
assert updated_agent["model"] == AgentModel.COMMAND_R_PLUS
assert updated_agent["deployment"] == AgentDeployment.SAGE_MAKER


def test_partial_update_agent(session_client: TestClient, session: Session) -> None:
Expand All @@ -251,8 +251,8 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N
description="test description",
preamble="test preamble",
temperature=0.5,
model=Model.COMMAND_R,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R,
deployment=AgentDeployment.COHERE_PLATFORM,
)

request_json = {
Expand All @@ -269,8 +269,8 @@ def test_partial_update_agent(session_client: TestClient, session: Session) -> N
assert updated_agent["description"] == "test description"
assert updated_agent["preamble"] == "test preamble"
assert updated_agent["temperature"] == 0.5
assert updated_agent["model"] == Model.COMMAND_R
assert updated_agent["deployment"] == Deployment.COHERE_PLATFORM
assert updated_agent["model"] == AgentModel.COMMAND_R
assert updated_agent["deployment"] == AgentDeployment.COHERE_PLATFORM


def test_update_nonexistent_agent(session_client: TestClient, session: Session) -> None:
Expand All @@ -293,8 +293,8 @@ def test_update_agent_wrong_model_deployment_enums(
description="test description",
preamble="test preamble",
temperature=0.5,
model=Model.COMMAND_R,
deployment=Deployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R,
deployment=AgentDeployment.COHERE_PLATFORM,
)

request_json = {
Expand Down
7 changes: 7 additions & 0 deletions src/interfaces/coral_web/src/cohere-client/generated/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ export { CancelablePromise, CancelError } from './core/CancelablePromise';
export { OpenAPI } from './core/OpenAPI';
export type { OpenAPIConfig } from './core/OpenAPI';

export type { Agent } from './models/Agent';
export { AgentDeployment } from './models/AgentDeployment';
export { AgentModel } from './models/AgentModel';
export type { Auth } from './models/Auth';
export type { Body_upload_file_v1_conversations_upload_file_post } from './models/Body_upload_file_v1_conversations_upload_file_post';
export { Category } from './models/Category';
export type { ChatMessage } from './models/ChatMessage';
Expand All @@ -17,7 +21,9 @@ export { CohereChatPromptTruncation } from './models/CohereChatPromptTruncation'
export type { CohereChatRequest } from './models/CohereChatRequest';
export type { Conversation } from './models/Conversation';
export type { ConversationWithoutMessages } from './models/ConversationWithoutMessages';
export type { CreateAgent } from './models/CreateAgent';
export type { CreateUser } from './models/CreateUser';
export type { DeleteAgent } from './models/DeleteAgent';
export type { DeleteConversation } from './models/DeleteConversation';
export type { DeleteFile } from './models/DeleteFile';
export type { DeleteUser } from './models/DeleteUser';
Expand Down Expand Up @@ -47,6 +53,7 @@ export type { StreamToolResult } from './models/StreamToolResult';
export type { Tool } from './models/Tool';
export type { ToolCall } from './models/ToolCall';
export { ToolInputType } from './models/ToolInputType';
export type { UpdateAgent } from './models/UpdateAgent';
export type { UpdateConversation } from './models/UpdateConversation';
export type { UpdateDeploymentEnv } from './models/UpdateDeploymentEnv';
export type { UpdateFile } from './models/UpdateFile';
Expand Down
Loading

0 comments on commit a897f73

Please sign in to comment.