File tree 9 files changed +107
-4
lines changed
9 files changed +107
-4
lines changed Original file line number Diff line number Diff line change 1
1
from functools import wraps
2
2
from typing import Literal
3
3
4
+ import aiohttp
4
5
from beartype import beartype
5
6
from litellm import acompletion as _acompletion
6
7
from litellm import aembedding as _aembedding
@@ -109,3 +110,26 @@ async def aembedding(
109
110
for item in embedding_list
110
111
if len (item ["embedding" ]) >= dimensions
111
112
]
113
+
114
+
115
+ @beartype
116
+ async def get_model_list (* , custom_api_key : str | None = None ) -> list [dict ]:
117
+ """
118
+ Fetches the list of available models from the LiteLLM server.
119
+
120
+ Returns:
121
+ list[dict]: A list of model information dictionaries
122
+ """
123
+
124
+ headers = {
125
+ "accept" : "application/json" ,
126
+ "x-api-key" : custom_api_key or litellm_master_key
127
+ }
128
+
129
+ async with aiohttp .ClientSession () as session , session .get (
130
+ url = f"{ litellm_url } /models" if not custom_api_key else "/models" ,
131
+ headers = headers
132
+ ) as response :
133
+ response .raise_for_status ()
134
+ data = await response .json ()
135
+ return data ["data" ]
Original file line number Diff line number Diff line change 10
10
)
11
11
from ...dependencies .developer_id import get_developer_id
12
12
from ...queries .agents .create_agent import create_agent as create_agent_query
13
+ from ..utils .model_validation import validate_model
13
14
from .router import router
14
15
15
16
@@ -18,7 +19,10 @@ async def create_agent(
18
19
x_developer_id : Annotated [UUID , Depends (get_developer_id )],
19
20
data : CreateAgentRequest ,
20
21
) -> ResourceCreatedResponse :
21
- # TODO: Validate model name
22
+
23
+ if data .model :
24
+ await validate_model (data .model )
25
+
22
26
agent = await create_agent_query (
23
27
developer_id = x_developer_id ,
24
28
data = data ,
Original file line number Diff line number Diff line change 12
12
from ...queries .agents .create_or_update_agent import (
13
13
create_or_update_agent as create_or_update_agent_query ,
14
14
)
15
+ from ..utils .model_validation import validate_model
15
16
from .router import router
16
17
17
18
@@ -21,7 +22,10 @@ async def create_or_update_agent(
21
22
data : CreateOrUpdateAgentRequest ,
22
23
x_developer_id : Annotated [UUID , Depends (get_developer_id )],
23
24
) -> ResourceCreatedResponse :
24
- # TODO: Validate model name
25
+
26
+ if data .model :
27
+ await validate_model (data .model )
28
+
25
29
agent = await create_or_update_agent_query (
26
30
developer_id = x_developer_id ,
27
31
agent_id = agent_id ,
Original file line number Diff line number Diff line change 7
7
from ...autogen .openapi_model import PatchAgentRequest , ResourceUpdatedResponse
8
8
from ...dependencies .developer_id import get_developer_id
9
9
from ...queries .agents .patch_agent import patch_agent as patch_agent_query
10
+ from ..utils .model_validation import validate_model
10
11
from .router import router
11
12
12
13
@@ -21,6 +22,10 @@ async def patch_agent(
21
22
agent_id : UUID ,
22
23
data : PatchAgentRequest ,
23
24
) -> ResourceUpdatedResponse :
25
+
26
+ if data .model :
27
+ await validate_model (data .model )
28
+
24
29
return await patch_agent_query (
25
30
agent_id = agent_id ,
26
31
developer_id = x_developer_id ,
Original file line number Diff line number Diff line change 7
7
from ...autogen .openapi_model import ResourceUpdatedResponse , UpdateAgentRequest
8
8
from ...dependencies .developer_id import get_developer_id
9
9
from ...queries .agents .update_agent import update_agent as update_agent_query
10
+ from ..utils .model_validation import validate_model
10
11
from .router import router
11
12
12
13
@@ -20,7 +21,11 @@ async def update_agent(
20
21
x_developer_id : Annotated [UUID , Depends (get_developer_id )],
21
22
agent_id : UUID ,
22
23
data : UpdateAgentRequest ,
23
- ) -> ResourceUpdatedResponse :
24
+ ) -> ResourceUpdatedResponse :
25
+
26
+ if data .model :
27
+ await validate_model (data .model )
28
+
24
29
return await update_agent_query (
25
30
developer_id = x_developer_id ,
26
31
agent_id = agent_id ,
Original file line number Diff line number Diff line change 23
23
from ...queries .chat .prepare_chat_context import prepare_chat_context
24
24
from ...queries .entries .create_entries import create_entries
25
25
from ...queries .sessions .count_sessions import count_sessions as count_sessions_query
26
+ from ..utils .model_validation import validate_model
26
27
from .metrics import total_tokens_per_user
27
28
from .router import router
28
29
@@ -55,6 +56,10 @@ async def chat(
55
56
Returns:
56
57
ChatResponse: The chat response.
57
58
"""
59
+
60
+ if chat_input .model :
61
+ await validate_model (chat_input .model )
62
+
58
63
# check if the developer is paid
59
64
if "paid" not in developer .tags :
60
65
# get the session length
Original file line number Diff line number Diff line change
1
+ from fastapi import HTTPException
2
+ from starlette .status import HTTP_400_BAD_REQUEST
3
+
4
+ from ...clients .litellm import get_model_list
5
+
6
+
7
+ async def validate_model (model_name : str ) -> None :
8
+ """
9
+ Validates if a given model name is available in LiteLLM.
10
+ Raises HTTPException if model is not available.
11
+ """
12
+ models = await get_model_list ()
13
+ available_models = [model ["id" ] for model in models ]
14
+
15
+ if model_name not in available_models :
16
+ raise HTTPException (
17
+ status_code = HTTP_400_BAD_REQUEST ,
18
+ detail = f"Model { model_name } not available. Available models: { available_models } "
19
+ )
Original file line number Diff line number Diff line change 2
2
import random
3
3
import string
4
4
import sys
5
+ from unittest .mock import patch
5
6
from uuid import UUID
6
7
7
8
from agents_api .autogen .openapi_model import (
@@ -440,10 +441,18 @@ async def test_tool(
440
441
return tool
441
442
442
443
444
+ SAMPLE_MODELS = [
445
+ {"id" : "gpt-4" },
446
+ {"id" : "gpt-3.5-turbo" },
447
+ {"id" : "gpt-4o-mini" },
448
+ ]
449
+
450
+
443
451
@fixture (scope = "global" )
444
452
def client (_dsn = pg_dsn ):
445
453
with TestClient (app = app ) as client :
446
- yield client
454
+ with patch ("agents_api.routers.utils.model_validation.get_model_list" , return_value = SAMPLE_MODELS ):
455
+ yield client
447
456
448
457
449
458
@fixture (scope = "global" )
Original file line number Diff line number Diff line change
1
+ from unittest .mock import patch
2
+
3
+ from agents_api .routers .utils .model_validation import validate_model
4
+ from fastapi import HTTPException
5
+ from ward import raises , test
6
+
7
+ from tests .fixtures import SAMPLE_MODELS
8
+
9
+
10
+ @test ("validate_model: succeeds when model is available" )
11
+ async def _ ():
12
+ # Use async context manager for patching
13
+ with patch ("agents_api.routers.utils.model_validation.get_model_list" ) as mock_get_models :
14
+ mock_get_models .return_value = SAMPLE_MODELS
15
+ await validate_model ("gpt-4o-mini" )
16
+ mock_get_models .assert_called_once ()
17
+
18
+
19
+ @test ("validate_model: fails when model is unavailable" )
20
+ async def _ ():
21
+ with patch ("agents_api.routers.utils.model_validation.get_model_list" ) as mock_get_models :
22
+ mock_get_models .return_value = SAMPLE_MODELS
23
+ with raises (HTTPException ) as exc :
24
+ await validate_model ("non-existent-model" )
25
+
26
+ assert exc .raised .status_code == 400
27
+ assert "Model non-existent-model not available" in exc .raised .detail
28
+ mock_get_models .assert_called_once ()
You can’t perform that action at this time.
0 commit comments