Skip to content

Commit c19b2ad

Browse files
authored
Merge pull request #1069 from julep-ai/x/model-validation
fix(agents-api): add model validation for agent and chat endpoints
2 parents 175b9ae + 828ac95 commit c19b2ad

File tree

9 files changed

+107
-4
lines changed

9 files changed

+107
-4
lines changed

agents-api/agents_api/clients/litellm.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import wraps
22
from typing import Literal
33

4+
import aiohttp
45
from beartype import beartype
56
from litellm import acompletion as _acompletion
67
from litellm import aembedding as _aembedding
@@ -109,3 +110,26 @@ async def aembedding(
109110
for item in embedding_list
110111
if len(item["embedding"]) >= dimensions
111112
]
113+
114+
115+
@beartype
async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]:
    """
    Fetches the list of available models from the LiteLLM server.

    Args:
        custom_api_key: Optional key sent in the ``x-api-key`` header. When it
            is omitted (or empty), ``litellm_master_key`` is sent instead.

    Returns:
        list[dict]: The value of the ``data`` field of the JSON body returned
        by ``GET {litellm_url}/models``.

    Raises:
        aiohttp.ClientResponseError: If the server replies with an error
            status (raised by ``response.raise_for_status()``).
    """

    headers = {
        "accept": "application/json",
        "x-api-key": custom_api_key or litellm_master_key,
    }

    # The session is created without a base_url, so the request URL must be
    # absolute. A bare "/models" path (previously used whenever a custom key
    # was supplied) cannot be resolved and fails before any request is sent.
    url = f"{litellm_url}/models"

    async with aiohttp.ClientSession() as session, session.get(
        url=url,
        headers=headers,
    ) as response:
        response.raise_for_status()
        data = await response.json()
        return data["data"]

agents-api/agents_api/routers/agents/create_agent.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from ...dependencies.developer_id import get_developer_id
1212
from ...queries.agents.create_agent import create_agent as create_agent_query
13+
from ..utils.model_validation import validate_model
1314
from .router import router
1415

1516

@@ -18,7 +19,10 @@ async def create_agent(
1819
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
1920
data: CreateAgentRequest,
2021
) -> ResourceCreatedResponse:
21-
# TODO: Validate model name
22+
23+
if data.model:
24+
await validate_model(data.model)
25+
2226
agent = await create_agent_query(
2327
developer_id=x_developer_id,
2428
data=data,

agents-api/agents_api/routers/agents/create_or_update_agent.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ...queries.agents.create_or_update_agent import (
1313
create_or_update_agent as create_or_update_agent_query,
1414
)
15+
from ..utils.model_validation import validate_model
1516
from .router import router
1617

1718

@@ -21,7 +22,10 @@ async def create_or_update_agent(
2122
data: CreateOrUpdateAgentRequest,
2223
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
2324
) -> ResourceCreatedResponse:
24-
# TODO: Validate model name
25+
26+
if data.model:
27+
await validate_model(data.model)
28+
2529
agent = await create_or_update_agent_query(
2630
developer_id=x_developer_id,
2731
agent_id=agent_id,

agents-api/agents_api/routers/agents/patch_agent.py

+5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ...autogen.openapi_model import PatchAgentRequest, ResourceUpdatedResponse
88
from ...dependencies.developer_id import get_developer_id
99
from ...queries.agents.patch_agent import patch_agent as patch_agent_query
10+
from ..utils.model_validation import validate_model
1011
from .router import router
1112

1213

@@ -21,6 +22,10 @@ async def patch_agent(
2122
agent_id: UUID,
2223
data: PatchAgentRequest,
2324
) -> ResourceUpdatedResponse:
25+
26+
if data.model:
27+
await validate_model(data.model)
28+
2429
return await patch_agent_query(
2530
agent_id=agent_id,
2631
developer_id=x_developer_id,

agents-api/agents_api/routers/agents/update_agent.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ...autogen.openapi_model import ResourceUpdatedResponse, UpdateAgentRequest
88
from ...dependencies.developer_id import get_developer_id
99
from ...queries.agents.update_agent import update_agent as update_agent_query
10+
from ..utils.model_validation import validate_model
1011
from .router import router
1112

1213

@@ -20,7 +21,11 @@ async def update_agent(
2021
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
2122
agent_id: UUID,
2223
data: UpdateAgentRequest,
23-
) -> ResourceUpdatedResponse:
24+
) -> ResourceUpdatedResponse:
25+
26+
if data.model:
27+
await validate_model(data.model)
28+
2429
return await update_agent_query(
2530
developer_id=x_developer_id,
2631
agent_id=agent_id,

agents-api/agents_api/routers/sessions/chat.py

+5
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from ...queries.chat.prepare_chat_context import prepare_chat_context
2424
from ...queries.entries.create_entries import create_entries
2525
from ...queries.sessions.count_sessions import count_sessions as count_sessions_query
26+
from ..utils.model_validation import validate_model
2627
from .metrics import total_tokens_per_user
2728
from .router import router
2829

@@ -55,6 +56,10 @@ async def chat(
5556
Returns:
5657
ChatResponse: The chat response.
5758
"""
59+
60+
if chat_input.model:
61+
await validate_model(chat_input.model)
62+
5863
# check if the developer is paid
5964
if "paid" not in developer.tags:
6065
# get the session length
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from fastapi import HTTPException
2+
from starlette.status import HTTP_400_BAD_REQUEST
3+
4+
from ...clients.litellm import get_model_list
5+
6+
7+
async def validate_model(model_name: str) -> None:
    """
    Check that ``model_name`` is one of the models reported by LiteLLM.

    The model list is fetched with ``get_model_list()`` and the ``id`` of each
    entry is compared against ``model_name``.

    Raises:
        HTTPException: With status 400 when ``model_name`` is not among the
            returned model ids; the detail lists the ids that are available.
    """
    available_models = [entry["id"] for entry in await get_model_list()]

    # Guard clause: a known model needs no further handling.
    if model_name in available_models:
        return

    raise HTTPException(
        status_code=HTTP_400_BAD_REQUEST,
        detail=f"Model {model_name} not available. Available models: {available_models}",
    )

agents-api/tests/fixtures.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import random
33
import string
44
import sys
5+
from unittest.mock import patch
56
from uuid import UUID
67

78
from agents_api.autogen.openapi_model import (
@@ -440,10 +441,18 @@ async def test_tool(
440441
return tool
441442

442443

444+
SAMPLE_MODELS = [
445+
{"id": "gpt-4"},
446+
{"id": "gpt-3.5-turbo"},
447+
{"id": "gpt-4o-mini"},
448+
]
449+
450+
443451
@fixture(scope="global")
444452
def client(_dsn=pg_dsn):
445453
with TestClient(app=app) as client:
446-
yield client
454+
with patch("agents_api.routers.utils.model_validation.get_model_list", return_value=SAMPLE_MODELS):
455+
yield client
447456

448457

449458
@fixture(scope="global")
+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
from unittest.mock import patch
2+
3+
from agents_api.routers.utils.model_validation import validate_model
4+
from fastapi import HTTPException
5+
from ward import raises, test
6+
7+
from tests.fixtures import SAMPLE_MODELS
8+
9+
10+
@test("validate_model: succeeds when model is available")
async def _():
    # Swap the model-list lookup for a mock that yields the sample models;
    # a regular (synchronous) `with` block is what applies the patch here.
    with patch(
        "agents_api.routers.utils.model_validation.get_model_list",
        return_value=SAMPLE_MODELS,
    ) as mock_get_models:
        # "gpt-4o-mini" is one of the SAMPLE_MODELS ids, so no exception is expected.
        await validate_model("gpt-4o-mini")
        mock_get_models.assert_called_once()
17+
18+
19+
@test("validate_model: fails when model is unavailable")
async def _():
    # Serve the sample model list from a mock instead of calling LiteLLM.
    with patch(
        "agents_api.routers.utils.model_validation.get_model_list",
        return_value=SAMPLE_MODELS,
    ) as mock_get_models:
        with raises(HTTPException) as exc:
            await validate_model("non-existent-model")

        # The rejection must be a 400 whose detail names the missing model.
        error = exc.raised
        assert error.status_code == 400
        assert "Model non-existent-model not available" in error.detail
        mock_get_models.assert_called_once()

0 commit comments

Comments
 (0)