Skip to content

Commit 6494660

Browse files
committed
New test for models endpoint
1 parent f22de7a commit 6494660

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

tests/unit/app/endpoints/test_models.py

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
"""Unit tests for the /models REST API endpoint."""
22

3-
from unittest.mock import Mock
4-
53
import pytest
64

75
from fastapi import HTTPException, Request, status
86

7+
from llama_stack_client import APIConnectionError
8+
99
from app.endpoints.models import models_endpoint_handler
1010
from configuration import AppConfig
1111

@@ -142,7 +142,7 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
142142
cfg.init_from_dict(config_dict)
143143

144144
# Mock the LlamaStack client
145-
mock_client = Mock()
145+
mock_client = mocker.Mock()
146146
mock_client.models.list.return_value = []
147147
mock_lsc = mocker.patch("client.LlamaStackClientHolder.get_client")
148148
mock_lsc.return_value = mock_client
@@ -157,3 +157,50 @@ def test_models_endpoint_handler_unable_to_retrieve_models_list(mocker):
157157
)
158158
response = models_endpoint_handler(request)
159159
assert response is not None
160+
161+
162+
def test_models_endpoint_llama_stack_connection_error(mocker):
163+
"""Test the model endpoint when LlamaStack connection fails."""
164+
# configuration for tests
165+
config_dict = {
166+
"name": "foo",
167+
"service": {
168+
"host": "localhost",
169+
"port": 8080,
170+
"auth_enabled": False,
171+
"workers": 1,
172+
"color_log": True,
173+
"access_log": True,
174+
},
175+
"llama_stack": {
176+
"api_key": "xyzzy",
177+
"url": "http://x.y.com:1234",
178+
"use_as_library_client": False,
179+
},
180+
"user_data_collection": {
181+
"feedback_enabled": False,
182+
},
183+
"customization": None,
184+
}
185+
186+
# mock LlamaStackClientHolder to raise APIConnectionError
187+
# when models.list() method is called
188+
mock_client = mocker.Mock()
189+
mock_client.models.list.side_effect = APIConnectionError(request=None)
190+
mock_client_holder = mocker.patch("app.endpoints.models.LlamaStackClientHolder")
191+
mock_client_holder.return_value.get_client.return_value = mock_client
192+
193+
cfg = AppConfig()
194+
cfg.init_from_dict(config_dict)
195+
196+
request = Request(
197+
scope={
198+
"type": "http",
199+
"headers": [(b"authorization", b"Bearer invalid-token")],
200+
}
201+
)
202+
203+
with pytest.raises(HTTPException) as e:
204+
models_endpoint_handler(request)
205+
assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
206+
assert e.value.detail["response"] == "Unable to connect to Llama Stack"

tests/unit/test_client.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ def test_async_client_get_client_method() -> None:
2424

2525
with pytest.raises(
2626
RuntimeError,
27-
match="AsyncLlamaStackClient has not been initialised. Ensure 'load\\(..\\)' has been called.",
27+
match=(
28+
"AsyncLlamaStackClient has not been initialised. "
29+
"Ensure 'load\\(..\\)' has been called."
30+
),
2831
):
2932
client.get_client()
3033

0 commit comments

Comments
 (0)