Skip to content

Commit

Permalink
[backend] make deployment field optional in API and DB (cohere-ai#213)
Browse files Browse the repository at this point in the history
* changes

* saving changes

* lint
  • Loading branch information
scott-cohere authored Jun 14, 2024
1 parent 24c4164 commit 3cd9a44
Show file tree
Hide file tree
Showing 14 changed files with 93 additions and 61 deletions.
4 changes: 3 additions & 1 deletion src/backend/database_models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ class Agent(Base):
Enum(AgentModel, native_enum=False), nullable=False
)
deployment: Mapped[AgentDeployment] = mapped_column(
Enum(AgentDeployment, native_enum=False), nullable=False
Enum(AgentDeployment, native_enum=False),
default=AgentDeployment.COHERE_PLATFORM,
nullable=False,
)

user_id: Mapped[str] = mapped_column(Text, nullable=False)
Expand Down
2 changes: 1 addition & 1 deletion src/backend/schemas/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class CreateAgent(BaseModel):
preamble: Optional[str] = None
temperature: Optional[float] = None
model: AgentModel
deployment: AgentDeployment
deployment: Optional[AgentDeployment] = None
tools: Optional[list[ToolName]] = None

class Config:
Expand Down
12 changes: 0 additions & 12 deletions src/backend/tests/crud/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ def test_create_agent_empty_non_required_fields(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
deployment=AgentDeployment.COHERE_PLATFORM,
model=AgentModel.COMMAND_R_PLUS,
)

Expand Down Expand Up @@ -99,17 +98,6 @@ def test_create_agent_missing_model(session, user):
_ = agent_crud.create_agent(session, agent_data)


def test_create_agent_missing_deployment(session, user):
agent_data = Agent(
user_id=user.id,
name="test",
model=AgentModel.COMMAND_R_PLUS,
)

with pytest.raises(IntegrityError):
_ = agent_crud.create_agent(session, agent_data)


def test_create_agent_missing_user_id(session):
agent_data = Agent(
name="test",
Expand Down
19 changes: 0 additions & 19 deletions src/backend/tests/routers/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,6 @@ def test_create_agent_missing_model(
assert response.status_code == 422


def test_create_agent_missing_deployment(
session_client: TestClient, session: Session
) -> None:
request_json = {
"name": "test agent",
"description": "test description",
"preamble": "test preamble",
"temperature": 0.5,
"model": AgentModel.COMMAND_R,
}
response = session_client.post(
"/v1/agents", json=request_json, headers={"User-Id": "123"}
)
assert response.status_code == 422


def test_create_agent_missing_user_id_header(
session_client: TestClient, session: Session
) -> None:
Expand All @@ -111,7 +95,6 @@ def test_create_agent_missing_non_required_fields(
request_json = {
"name": "test agent",
"model": AgentModel.COMMAND_R,
"deployment": AgentDeployment.COHERE_PLATFORM,
}

print(request_json)
Expand All @@ -128,7 +111,6 @@ def test_create_agent_missing_non_required_fields(
assert response_agent["preamble"] == ""
assert response_agent["temperature"] == 0.3
assert response_agent["model"] == request_json["model"]
assert response_agent["deployment"] == request_json["deployment"]

agent = session.get(Agent, response_agent["id"])
assert agent is not None
Expand All @@ -138,7 +120,6 @@ def test_create_agent_missing_non_required_fields(
assert agent.preamble == ""
assert agent.temperature == 0.3
assert agent.model == request_json["model"]
assert agent.deployment == request_json["deployment"]


def test_create_agent_wrong_model_deployment_enums(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ export type { OpenAPIConfig } from './core/OpenAPI';
export type { Agent } from './models/Agent';
export { AgentDeployment } from './models/AgentDeployment';
export { AgentModel } from './models/AgentModel';
export type { Auth } from './models/Auth';
export type { Body_upload_file_v1_conversations_upload_file_post } from './models/Body_upload_file_v1_conversations_upload_file_post';
export { Category } from './models/Category';
export type { ChatMessage } from './models/ChatMessage';
Expand All @@ -31,9 +30,12 @@ export type { Deployment } from './models/Deployment';
export type { Document } from './models/Document';
export type { File } from './models/File';
export type { HTTPValidationError } from './models/HTTPValidationError';
export type { JWTResponse } from './models/JWTResponse';
export type { LangchainChatRequest } from './models/LangchainChatRequest';
export type { ListAuthStrategy } from './models/ListAuthStrategy';
export type { ListFile } from './models/ListFile';
export type { Login } from './models/Login';
export type { Logout } from './models/Logout';
export type { ManagedTool } from './models/ManagedTool';
export type { Message } from './models/Message';
export { MessageAgent } from './models/MessageAgent';
Expand All @@ -53,6 +55,7 @@ export type { StreamToolResult } from './models/StreamToolResult';
export type { Tool } from './models/Tool';
export type { ToolCall } from './models/ToolCall';
export { ToolInputType } from './models/ToolInputType';
export { ToolName } from './models/ToolName';
export type { UpdateAgent } from './models/UpdateAgent';
export type { UpdateConversation } from './models/UpdateConversation';
export type { UpdateDeploymentEnv } from './models/UpdateDeploymentEnv';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/* eslint-disable */
import type { AgentDeployment } from './AgentDeployment';
import type { AgentModel } from './AgentModel';
import type { ToolName } from './ToolName';

export type Agent = {
user_id: string;
Expand All @@ -18,6 +19,7 @@ export type Agent = {
description: string | null;
preamble: string | null;
temperature: number;
tools: Array<ToolName>;
model: AgentModel;
deployment: AgentDeployment;
};
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/* eslint-disable */
import type { AgentDeployment } from './AgentDeployment';
import type { AgentModel } from './AgentModel';
import type { ToolName } from './ToolName';

export type CreateAgent = {
name: string;
Expand All @@ -15,5 +16,6 @@ export type CreateAgent = {
preamble?: string | null;
temperature?: number | null;
model: AgentModel;
deployment: AgentDeployment;
deployment?: AgentDeployment | null;
tools?: Array<ToolName> | null;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type JWTResponse = {
token: string;
};
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Auth = {
export type ListAuthStrategy = {
strategy: string;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export type Logout = {};
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ export type StreamEnd = {
search_results?: Array<Record<string, any>>;
search_queries?: Array<SearchQuery>;
tool_calls?: Array<ToolCall>;
finish_reason: string;
finish_reason?: string | null;
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
/* generated using openapi-typescript-codegen -- do no edit */
/* istanbul ignore file */
/* tslint:disable */
/* eslint-disable */
export enum ToolName {
WIKIPEDIA = 'Wikipedia',
SEARCH_FILE = 'search_file',
READ_DOCUMENT = 'read_document',
PYTHON_INTERPRETER = 'Python_Interpreter',
CALCULATOR = 'Calculator',
INTERNET_SEARCH = 'Internet_Search',
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
/* eslint-disable */
import type { AgentDeployment } from './AgentDeployment';
import type { AgentModel } from './AgentModel';
import type { ToolName } from './ToolName';

export type UpdateAgent = {
name?: string | null;
Expand All @@ -16,4 +17,5 @@ export type UpdateAgent = {
temperature?: number | null;
model?: AgentModel | null;
deployment?: AgentDeployment | null;
tools?: Array<ToolName> | null;
};
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import type { CancelablePromise } from '../core/CancelablePromise';
import { OpenAPI } from '../core/OpenAPI';
import { request as __request } from '../core/request';
import type { Agent } from '../models/Agent';
import type { Auth } from '../models/Auth';
import type { Body_upload_file_v1_conversations_upload_file_post } from '../models/Body_upload_file_v1_conversations_upload_file_post';
import type { ChatResponseEvent } from '../models/ChatResponseEvent';
import type { CohereChatRequest } from '../models/CohereChatRequest';
Expand All @@ -23,9 +22,12 @@ import type { DeleteFile } from '../models/DeleteFile';
import type { DeleteUser } from '../models/DeleteUser';
import type { Deployment } from '../models/Deployment';
import type { File } from '../models/File';
import type { JWTResponse } from '../models/JWTResponse';
import type { LangchainChatRequest } from '../models/LangchainChatRequest';
import type { ListAuthStrategy } from '../models/ListAuthStrategy';
import type { ListFile } from '../models/ListFile';
import type { Login } from '../models/Login';
import type { Logout } from '../models/Logout';
import type { ManagedTool } from '../models/ManagedTool';
import type { NonStreamedChatResponse } from '../models/NonStreamedChatResponse';
import type { UpdateAgent } from '../models/UpdateAgent';
Expand All @@ -44,10 +46,10 @@ export class DefaultService {
*
* Returns:
* List[dict]: List of dictionaries containing the enabled auth strategy names.
* @returns any Successful Response
* @returns ListAuthStrategy Successful Response
* @throws ApiError
*/
public static getStrategiesV1AuthStrategiesGet(): CancelablePromise<any> {
public static getStrategiesV1AuthStrategiesGet(): CancelablePromise<Array<ListAuthStrategy>> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/auth_strategies',
Expand All @@ -74,7 +76,11 @@ export class DefaultService {
* @returns any Successful Response
* @throws ApiError
*/
public static loginV1LoginPost({ requestBody }: { requestBody: Login }): CancelablePromise<any> {
public static loginV1LoginPost({
requestBody,
}: {
requestBody: Login;
}): CancelablePromise<JWTResponse | null> {
return __request(OpenAPI, {
method: 'POST',
url: '/v1/login',
Expand All @@ -86,35 +92,47 @@ export class DefaultService {
});
}
/**
* Authenticate
* Authentication endpoint used for OAuth strategies. Logs the user in the redirect environment and then
* sets the current session with the user returned from the auth token.
* Google Authenticate
* Callback authentication endpoint used for Google OAuth after redirecting to
* the service's login screen.
*
* Args:
* request (Request): current Request object.
* login (Login): Login payload.
*
* Returns:
* RedirectResponse: On success.
*
* Raises:
* HTTPException: If authentication fails, or strategy is invalid.
* @returns any Successful Response
* @returns JWTResponse Successful Response
* @throws ApiError
*/
public static authenticateV1AuthPost({
requestBody,
}: {
requestBody: Auth;
}): CancelablePromise<any> {
public static googleAuthenticateV1GoogleAuthGet(): CancelablePromise<JWTResponse> {
return __request(OpenAPI, {
method: 'POST',
url: '/v1/auth',
body: requestBody,
mediaType: 'application/json',
errors: {
422: `Validation Error`,
},
method: 'GET',
url: '/v1/google/auth',
});
}
/**
* Oidc Authenticate
* Callback authentication endpoint used for OIDC after redirecting to
* the service's login screen.
*
* Args:
* request (Request): current Request object.
*
* Returns:
* RedirectResponse: On success.
*
* Raises:
* HTTPException: If authentication fails, or strategy is invalid.
* @returns JWTResponse Successful Response
* @throws ApiError
*/
public static oidcAuthenticateV1OidcAuthGet(): CancelablePromise<JWTResponse> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/oidc/auth',
});
}
/**
Expand All @@ -126,10 +144,10 @@ export class DefaultService {
*
* Returns:
* dict: Empty on success
* @returns any Successful Response
* @returns Logout Successful Response
* @throws ApiError
*/
public static logoutV1LogoutGet(): CancelablePromise<any> {
public static logoutV1LogoutGet(): CancelablePromise<Logout> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/logout',
Expand Down Expand Up @@ -663,10 +681,20 @@ export class DefaultService {
* @returns ManagedTool Successful Response
* @throws ApiError
*/
public static listToolsV1ToolsGet(): CancelablePromise<Array<ManagedTool>> {
public static listToolsV1ToolsGet({
agentId,
}: {
agentId?: string | null;
}): CancelablePromise<Array<ManagedTool>> {
return __request(OpenAPI, {
method: 'GET',
url: '/v1/tools',
query: {
agent_id: agentId,
},
errors: {
422: `Validation Error`,
},
});
}
/**
Expand Down

0 comments on commit 3cd9a44

Please sign in to comment.