async def _apply_search_filter_to_models(
    all_models: List[Dict[str, Any]],
    search: str,
    page: int,
    size: int,
    prisma_client: Optional[Any],
    proxy_config: Any,
) -> Tuple[List[Dict[str, Any]], Optional[int]]:
    """
    Filter models by a case-insensitive substring match on `model_name`.

    Models already present in `all_models` (the router view) are filtered in
    memory. When a database client is available, the proxy model table is also
    queried for matching rows that are not already represented in the router,
    and enough of them are fetched to be able to fill pages 1..`page`.

    The returned list is NOT paginated; the caller slices it.

    Args:
        all_models: Models to filter (router format dictionaries).
        search: Search term. Empty / whitespace-only disables filtering.
        page: Current page number (1-based).
        size: Page size.
        prisma_client: Prisma client used for the database lookup. When None,
            only the in-memory models are searched.
        proxy_config: Object exposing `decrypt_model_list_from_db`, used to
            convert database rows into router format.

    Returns:
        Tuple of (filtered_models, total_count). `total_count` is None when no
        search term was supplied; otherwise it is the number of matching
        router models plus the number of matching database-only models. If the
        database lookup fails, only the router matches and their count are
        returned.
    """
    if not search or not search.strip():
        return all_models, None

    search_lower = search.lower().strip()

    # `model_name` can be absent or explicitly None; treat both as "".
    filtered_router_models = [
        m for m in all_models
        if search_lower in (m.get("model_name") or "").lower()
    ]

    # Split the matches into config models (only counted) and db-backed models
    # (ids tracked so the database query can exclude them).
    config_models_count = 0
    db_model_ids_in_router = set()
    for m in filtered_router_models:
        model_info = m.get("model_info") or {}
        model_id = model_info.get("id")
        if model_info.get("db_model", False) and model_id:
            db_model_ids_in_router.add(model_id)
        else:
            config_models_count += 1

    router_models_count = config_models_count + len(db_model_ids_in_router)

    # Without a database there is nothing more to look up. Returning here
    # avoids logging an AttributeError traceback on every search request.
    if prisma_client is None:
        return filtered_router_models, router_models_count

    db_models: List[Dict[str, Any]] = []
    search_total_count = router_models_count
    # Number of rows required to be able to slice any page up to `page`.
    models_needed_for_page = size * page

    try:
        db_where_condition: Dict[str, Any] = {
            "model_name": {
                "contains": search_lower,
                "mode": "insensitive",
            }
        }
        # Exclude models already in the router so they are not counted twice.
        if db_model_ids_in_router:
            db_where_condition["model_id"] = {
                "not": {"in": list(db_model_ids_in_router)}
            }

        db_models_total_count = await prisma_client.db.litellm_proxymodeltable.count(
            where=db_where_condition
        )
        search_total_count = router_models_count + db_models_total_count

        # Only fetch what the router matches cannot already cover.
        models_to_fetch = min(
            models_needed_for_page - router_models_count,
            db_models_total_count,
        )
        if models_to_fetch > 0:
            db_models_raw = await prisma_client.db.litellm_proxymodeltable.find_many(
                where=db_where_condition,
                take=models_to_fetch,
                # Explicit ordering keeps `take` deterministic, so the same
                # rows land on the same page across requests.
                order={"model_name": "asc"},
            )

            # Convert database rows to router format.
            for db_model in db_models_raw:
                decrypted_models = proxy_config.decrypt_model_list_from_db([db_model])
                if decrypted_models:
                    db_models.extend(decrypted_models)
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error querying database models with search: {str(e)}"
        )
        # Fall back to router-only results; drop any partially converted rows
        # so the returned list and the returned count stay consistent.
        db_models = []
        search_total_count = router_models_count

    return filtered_router_models + db_models, search_total_count
def _paginate_models_response(
    all_models: List[Dict[str, Any]],
    page: int,
    size: int,
    total_count: Optional[int],
    search: Optional[str],
) -> Dict[str, Any]:
    """
    Slice one page out of `all_models` and build the response payload.

    Args:
        all_models: Full list of models to page through
        page: 1-based page number
        size: Number of models per page
        total_count: Count to report; when None, len(all_models) is used
        search: Search term, only included in the debug log line

    Returns:
        Dict with the keys "data", "total_count", "current_page",
        "total_pages" and "size"
    """
    effective_total = len(all_models) if total_count is None else total_count

    offset = size * (page - 1)

    if effective_total > 0:
        # Ceiling division without importing math.
        page_count = -(-effective_total // size)
    else:
        page_count = 0

    window = all_models[offset : offset + size]

    verbose_proxy_logger.debug(
        f"Pagination: skip={offset}, take={size}, total_count={effective_total}, total_pages={page_count}, search={search}"
    )

    response: Dict[str, Any] = {}
    response["data"] = window
    response["total_count"] = effective_total
    response["current_page"] = page
    response["total_pages"] = page_count
    response["size"] = size
    return response
@@ -7760,6 +7911,16 @@ async def model_info_v2( if model is not None: all_models = [m for m in all_models if m["model_name"] == model] + # Apply search filter if provided + all_models, search_total_count = await _apply_search_filter_to_models( + all_models=all_models, + search=search or "", + page=page, + size=size, + prisma_client=prisma_client, + proxy_config=proxy_config, + ) + if user_models_only: all_models = await non_admin_all_models( all_models=all_models, @@ -7775,7 +7936,8 @@ async def model_info_v2( llm_router=llm_router, all_models=all_models, ) - # fill in model info based on config.yaml and litellm model_prices_and_context_window.json + + # Fill in model info based on config.yaml and litellm model_prices_and_context_window.json for i, _model in enumerate(all_models): all_models[i] = _enrich_model_info_with_litellm_data( model=_model, debug=debug if debug is not None else False, llm_router=llm_router @@ -7783,25 +7945,13 @@ async def model_info_v2( verbose_proxy_logger.debug("all_models: %s", all_models) - total_count = len(all_models) - - skip = (page - 1) * size - - total_pages = -(-total_count // size) if total_count > 0 else 0 - - paginated_models = all_models[skip : skip + size] - - verbose_proxy_logger.debug( - f"Pagination: skip={skip}, take={size}, total_count={total_count}, total_pages={total_pages}" + return _paginate_models_response( + all_models=all_models, + page=page, + size=size, + total_count=search_total_count, + search=search, ) - - return { - "data": paginated_models, - "total_count": total_count, - "current_page": page, - "total_pages": total_pages, - "size": size, - } @router.get( diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index cb519e9f509..34dc0d843ad 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -3443,6 +3443,497 @@ async def test_model_info_v2_pagination_edge_cases(monkeypatch): 
@pytest.mark.asyncio
async def test_model_info_v2_search_config_models(monkeypatch):
    """
    Exercise the `search` query parameter of `/v2/model/info` with config models.

    Config models (from config.yaml) either omit `db_model` in `model_info`
    or set it to False.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    def config_model(name, **extra_info):
        # Router-format entry whose id and underlying model equal its name.
        return {
            "model_name": name,
            "litellm_params": {"model": name},
            "model_info": {"id": name, **extra_info},
        }

    router_models = [
        config_model("gpt-4-turbo"),  # no db_model flag = config model
        config_model("gpt-3.5-turbo", db_model=False),  # explicitly a config model
        config_model("claude-3-opus"),  # no db_model flag = config model
        config_model("gemini-pro"),  # no db_model flag = config model
    ]

    # Router stub exposing the config models
    fake_router = MagicMock()
    fake_router.model_list = router_models

    # Plain MagicMock: its query methods are not awaitable.
    # NOTE(review): presumably the search helper's error fallback is what
    # yields the router-only counts asserted below — confirm.
    fake_prisma = MagicMock()

    fake_get_config = AsyncMock(return_value={})

    # Authenticated user stub
    fake_user = MagicMock(spec=UserAPIKeyAuth)
    fake_user.user_id = "test-user"
    fake_user.api_key = "test-key"
    fake_user.team_models = []
    fake_user.models = []

    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", fake_router)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", fake_prisma)
    monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None)
    monkeypatch.setattr(proxy_config, "get_config", fake_get_config)

    # Swap the auth dependency for the duration of the test
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = lambda: fake_user

    client = TestClient(app)

    def search_models(term):
        # Issue the search request, require HTTP 200, hand back the JSON body.
        response = client.get("/v2/model/info", params={"search": term})
        assert response.status_code == 200
        return response.json()

    try:
        # "gpt" matches gpt-4-turbo and gpt-3.5-turbo
        data = search_models("gpt")
        assert data["total_count"] == 2  # Only config models matching search
        assert len(data["data"]) == 2
        model_names = [m["model_name"] for m in data["data"]]
        assert "gpt-4-turbo" in model_names
        assert "gpt-3.5-turbo" in model_names
        assert "claude-3-opus" not in model_names
        assert "gemini-pro" not in model_names

        # "claude" matches claude-3-opus only
        data = search_models("claude")
        assert data["total_count"] == 1
        assert len(data["data"]) == 1
        assert data["data"][0]["model_name"] == "claude-3-opus"

        # Upper-case term behaves like the lower-case one
        data = search_models("GPT")
        assert data["total_count"] == 2
        assert len(data["data"]) == 2

        # Substring in the middle/end of the name
        data = search_models("turbo")
        assert data["total_count"] == 2
        assert len(data["data"]) == 2
        model_names = [m["model_name"] for m in data["data"]]
        assert "gpt-4-turbo" in model_names
        assert "gpt-3.5-turbo" in model_names

        # Term that matches nothing
        data = search_models("nonexistent")
        assert data["total_count"] == 0
        assert len(data["data"]) == 0

    finally:
        app.dependency_overrides = saved_overrides
@pytest.mark.asyncio
async def test_model_info_v2_search_db_models(monkeypatch):
    """
    Exercise the `search` query parameter of `/v2/model/info` with db models.

    DB models carry `db_model=True` and an `id` in `model_info`. Two of them
    are already loaded in the router; two more exist only in the (stubbed)
    database table.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth

    # DB models that the router already knows about
    router_db_models = [
        {
            "model_name": "db-gpt-4",
            "litellm_params": {"model": "gpt-4"},
            "model_info": {"id": "db-model-1", "db_model": True},
        },
        {
            "model_name": "db-claude-3",
            "litellm_params": {"model": "claude-3"},
            "model_info": {"id": "db-model-2", "db_model": True},
        },
    ]

    fake_router = MagicMock()
    fake_router.model_list = router_db_models

    # Rows that live only in the database
    db_rows = [
        MagicMock(
            model_id="db-model-3",
            model_name="db-gemini-pro",
            litellm_params='{"model": "gemini-pro"}',
            model_info='{"id": "db-model-3", "db_model": true}',
        ),
        MagicMock(
            model_id="db-model-4",
            model_name="db-gpt-3.5",
            litellm_params='{"model": "gpt-3.5-turbo"}',
            model_info='{"id": "db-model-4", "db_model": true}',
        ),
    ]

    def rows_matching(where):
        # Emulate the `where` filter: name contains the term, id not excluded.
        needle = where.get("model_name", {}).get("contains", "")
        skipped_ids = where.get("model_id", {}).get("not", {}).get("in", [])
        return [
            row
            for row in db_rows
            if needle.lower() in row.model_name.lower()
            and row.model_id not in skipped_ids
        ]

    async def fake_count(*args, **kwargs):
        return len(rows_matching(kwargs.get("where", {})))

    async def fake_find_many(*args, **kwargs):
        candidates = rows_matching(kwargs.get("where", {}))
        limit = kwargs.get("take", 10)
        selected = []
        for row in candidates:
            selected.append(row)
            if len(selected) >= limit:
                break
        return selected

    mock_db_count = AsyncMock(side_effect=fake_count)
    mock_db_find_many = AsyncMock(side_effect=fake_find_many)

    fake_prisma = MagicMock()
    fake_prisma.db.litellm_proxymodeltable.count = mock_db_count
    fake_prisma.db.litellm_proxymodeltable.find_many = mock_db_find_many

    def fake_decrypt(db_models_list):
        # Stand-in for proxy_config.decrypt_model_list_from_db: rows -> router format.
        return [
            {
                "model_name": row.model_name,
                "litellm_params": {"model": row.model_name.replace("db-", "")},
                "model_info": {"id": row.model_id, "db_model": True},
            }
            for row in db_models_list
        ]

    fake_get_config = AsyncMock(return_value={})

    # Authenticated user stub
    fake_user = MagicMock(spec=UserAPIKeyAuth)
    fake_user.user_id = "test-user"
    fake_user.api_key = "test-key"
    fake_user.team_models = []
    fake_user.models = []

    monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", fake_router)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", fake_prisma)
    monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None)
    monkeypatch.setattr(proxy_config, "get_config", fake_get_config)
    monkeypatch.setattr(proxy_config, "decrypt_model_list_from_db", fake_decrypt)

    # Swap the auth dependency for the duration of the test
    saved_overrides = app.dependency_overrides.copy()
    app.dependency_overrides[user_api_key_auth] = lambda: fake_user

    client = TestClient(app)

    def search_models(term):
        # Issue the search request, require HTTP 200, hand back the JSON body.
        response = client.get("/v2/model/info", params={"search": term})
        assert response.status_code == 200
        return response.json()

    try:
        # "gpt": db-gpt-4 from the router + db-gpt-3.5 from the database = 2
        data = search_models("gpt")
        assert data["total_count"] == 2
        assert len(data["data"]) == 2
        model_names = [m["model_name"] for m in data["data"]]
        assert "db-gpt-4" in model_names
        assert "db-gpt-3.5" in model_names

        # The database must have been consulted, with the expected name filter
        mock_db_count.assert_called()
        call_args = mock_db_count.call_args
        assert call_args is not None
        where_condition = call_args[1]["where"]
        assert "model_name" in where_condition
        assert where_condition["model_name"]["contains"] == "gpt"
        assert where_condition["model_name"]["mode"] == "insensitive"

        # "claude": only db-claude-3, which comes from the router
        data = search_models("claude")
        assert data["total_count"] == 1
        assert len(data["data"]) == 1
        assert data["data"][0]["model_name"] == "db-claude-3"

        # "gemini": only db-gemini-pro, which comes from the database
        data = search_models("gemini")
        assert data["total_count"] == 1
        assert len(data["data"]) == 1
        assert data["data"][0]["model_name"] == "db-gemini-pro"

        # Upper-case term behaves like the lower-case one
        data = search_models("GPT")
        assert data["total_count"] == 2

    finally:
        app.dependency_overrides = saved_overrides
@pytest.mark.asyncio
async def test_apply_search_filter_to_models(monkeypatch):
    """
    Unit-test the _apply_search_filter_to_models helper directly.

    Covers: empty search term, router-only matches, database-only matches,
    case-insensitivity, and the fallback when the database query raises.
    """
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.proxy_server import _apply_search_filter_to_models, proxy_config

    # Two config models and one db model that is already in the router
    router_models = [
        {"model_name": "gpt-4-turbo", "model_info": {"id": "gpt-4-turbo"}},
        {"model_name": "db-gpt-3.5", "model_info": {"id": "db-model-1", "db_model": True}},
        {"model_name": "claude-3-opus", "model_info": {"id": "claude-3-opus"}},
    ]

    prisma_stub = MagicMock()
    db_table = MagicMock()
    prisma_stub.db.litellm_proxymodeltable = db_table

    # A row that exists only in the database
    gemini_row = MagicMock(
        model_id="db-model-2",
        model_name="db-gemini-pro",
        litellm_params='{"model": "gemini-pro"}',
        model_info='{"id": "db-model-2", "db_model": true}',
    )

    # decrypt_model_list_from_db always yields the router-format gemini entry
    decrypt_stub = MagicMock(
        return_value=[
            {"model_name": "db-gemini-pro", "model_info": {"id": "db-model-2", "db_model": True}}
        ]
    )
    monkeypatch.setattr(proxy_config, "decrypt_model_list_from_db", decrypt_stub)

    async def run_search(term):
        # Every case uses a fresh copy of the router models, page 1, size 50.
        return await _apply_search_filter_to_models(
            all_models=router_models.copy(),
            search=term,
            page=1,
            size=50,
            prisma_client=prisma_stub,
            proxy_config=proxy_config,
        )

    # Case 1: no search term -> input returned unchanged, no total
    result_models, total_count = await run_search("")
    assert result_models == router_models
    assert total_count is None

    # Case 2: "gpt" matches two router models, nothing in the database
    db_table.count = AsyncMock(return_value=0)
    db_table.find_many = AsyncMock(return_value=[])

    result_models, total_count = await run_search("gpt")
    assert len(result_models) == 2
    model_names = [m["model_name"] for m in result_models]
    assert "gpt-4-turbo" in model_names
    assert "db-gpt-3.5" in model_names
    assert "claude-3-opus" not in model_names
    assert total_count == 2  # Only router models match

    # Case 3: "gemini" matches nothing in the router, one row in the database
    db_table.count = AsyncMock(return_value=1)
    db_table.find_many = AsyncMock(return_value=[gemini_row])

    result_models, total_count = await run_search("gemini")
    assert total_count == 1  # Router models (0) + DB models (1)
    assert len(result_models) == 1
    assert result_models[0]["model_name"] == "db-gemini-pro"

    # Case 4: upper-case term; database reports no matches
    db_table.count = AsyncMock(return_value=0)
    db_table.find_many = AsyncMock(return_value=[])

    result_models, total_count = await run_search("GPT")
    assert len(result_models) == 2
    model_names = [m["model_name"] for m in result_models]
    assert "gpt-4-turbo" in model_names
    assert "db-gpt-3.5" in model_names

    # Case 5: database count raises -> router matches and their count are kept
    db_table.count = AsyncMock(side_effect=Exception("DB error"))
    db_table.find_many = AsyncMock(return_value=[])

    result_models, total_count = await run_search("gpt")
    assert len(result_models) == 2
    assert total_count == 2  # Fallback to router models count
def test_paginate_models_response():
    """
    Unit-test the _paginate_models_response helper.

    Checks the page slice, the reported totals and the page-count arithmetic.
    """
    from litellm.proxy.proxy_server import _paginate_models_response

    # 25 models named model-0 .. model-24
    models = []
    for i in range(25):
        models.append({"model_name": f"model-{i}", "model_info": {"id": f"model-{i}"}})

    def paginate(items, page, total_count=None, search=None):
        # Every case in this test uses a page size of 10.
        return _paginate_models_response(
            all_models=items,
            page=page,
            size=10,
            total_count=total_count,
            search=search,
        )

    # Case 1: first page
    result = paginate(models, 1)
    assert result["total_count"] == 25
    assert result["current_page"] == 1
    assert result["total_pages"] == 3  # ceil(25/10) = 3
    assert result["size"] == 10
    assert len(result["data"]) == 10
    assert result["data"][0]["model_name"] == "model-0"

    # Case 2: second page
    result = paginate(models, 2)
    assert result["current_page"] == 2
    assert len(result["data"]) == 10
    assert result["data"][0]["model_name"] == "model-10"

    # Case 3: last, partially filled page
    result = paginate(models, 3)
    assert result["current_page"] == 3
    assert len(result["data"]) == 5  # Only 5 models left
    assert result["data"][0]["model_name"] == "model-20"

    # Case 4: explicit total_count (search scenario) larger than the list
    result = paginate(models[:10], 1, total_count=50, search="test")
    assert result["total_count"] == 50
    assert result["total_pages"] == 5  # ceil(50/10) = 5
    assert len(result["data"]) == 10

    # Case 5: nothing to paginate
    result = paginate([], 1, total_count=0)
    assert result["total_count"] == 0
    assert result["total_pages"] == 0
    assert len(result["data"]) == 0

    # Case 6: requested page lies beyond the available data
    result = paginate(models[:10], 5, total_count=10)
    assert result["current_page"] == 5
    assert len(result["data"]) == 0  # No data for page 5
search }), }, }), - queryFn: async () => await modelInfoCall(accessToken!, userId!, userRole!, page, size), + queryFn: async () => await modelInfoCall(accessToken!, userId!, userRole!, page, size, search), enabled: Boolean(accessToken && userId && userRole), }); }; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.test.tsx b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.test.tsx index 8a2298361c5..813a365d367 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.test.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.test.tsx @@ -11,7 +11,7 @@ const mockUseModelsInfo = vi.fn(() => ({ })) as any; vi.mock("../../hooks/models/useModels", () => ({ - useModelsInfo: (page?: number, size?: number) => mockUseModelsInfo(page, size), + useModelsInfo: (page?: number, size?: number, search?: string) => mockUseModelsInfo(page, size, search), })); // Mock the useModelCostMap hook @@ -74,7 +74,6 @@ describe("AllModelsTab", () => { const mockSetSelectedModelGroup = vi.fn(); const mockSetSelectedModelId = vi.fn(); const mockSetSelectedTeamId = vi.fn(); - const mockSetEditModel = vi.fn(); const defaultProps = { selectedModelGroup: "all", @@ -83,7 +82,6 @@ describe("AllModelsTab", () => { availableModelAccessGroups: ["sales-team", "engineering-team"], setSelectedModelId: mockSetSelectedModelId, setSelectedTeamId: mockSetSelectedTeamId, - setEditModel: mockSetEditModel, }; const mockUseAuthorized = { @@ -176,8 +174,10 @@ describe("AllModelsTab", () => { render(); + // Component shows API total_count (2), not filtered count + // Since default is "personal" team and models don't have direct_access, they're filtered out await waitFor(() => { - expect(screen.getByText("Showing 0 results")).toBeInTheDocument(); + expect(screen.getByText("Showing 1 - 2 of 2 results")).toBeInTheDocument(); }); }); @@ -235,8 +235,10 
@@ describe("AllModelsTab", () => { render(); + // Component shows API total_count (2), not filtered count + // Since default is "personal" team and models don't have direct_access, they're filtered out await waitFor(() => { - expect(screen.getByText("Showing 0 results")).toBeInTheDocument(); + expect(screen.getByText("Showing 1 - 2 of 2 results")).toBeInTheDocument(); }); }); @@ -280,8 +282,9 @@ describe("AllModelsTab", () => { render(); + // Component shows API total_count (2), but only 1 model has direct_access await waitFor(() => { - expect(screen.getByText("Showing 1 - 1 of 1 results")).toBeInTheDocument(); + expect(screen.getByText("Showing 1 - 2 of 2 results")).toBeInTheDocument(); }); }); @@ -419,14 +422,15 @@ describe("AllModelsTab", () => { ); // Set up mock to return page1Data for page 1 - mockUseModelsInfo.mockImplementation((page: number = 1) => { + mockUseModelsInfo.mockImplementation((page: number = 1, size?: number, search?: string) => { return { data: page1Data, isLoading: false, error: null }; }); render(); await waitFor(() => { - expect(screen.getByText("Showing 1 - 1 of 1 results")).toBeInTheDocument(); + // Component calculates: ((1-1)*50)+1 = 1, Math.min(1*50, 2) = 2 + expect(screen.getByText("Showing 1 - 2 of 2 results")).toBeInTheDocument(); }); // Check that Previous button is disabled on first page @@ -471,7 +475,7 @@ describe("AllModelsTab", () => { 50, // size ); - mockUseModelsInfo.mockImplementation(() => { + mockUseModelsInfo.mockImplementation((page?: number, size?: number, search?: string) => { return { data: singlePageData, isLoading: false, error: null }; }); diff --git a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.tsx b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.tsx index 04300b7fd16..478a5338b66 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.tsx +++ 
b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/components/AllModelsTab.tsx @@ -8,10 +8,11 @@ import { getDisplayModelName } from "@/components/view_model/model_name_display" import { InfoCircleOutlined } from "@ant-design/icons"; import { PaginationState } from "@tanstack/react-table"; import { Grid, Select, SelectItem, TabPanel, Text } from "@tremor/react"; +import { Skeleton } from "antd"; +import debounce from "lodash/debounce"; import { useEffect, useMemo, useState } from "react"; import { useModelsInfo } from "../../hooks/models/useModels"; import { transformModelData } from "../utils/modelDataTransformer"; -import { Skeleton } from "antd"; type ModelViewMode = "all" | "current_team"; interface AllModelsTabProps { @@ -36,6 +37,7 @@ const AllModelsTab = ({ const { data: teams } = useTeams(); const [modelNameSearch, setModelNameSearch] = useState(""); + const [debouncedSearch, setDebouncedSearch] = useState(""); const [modelViewMode, setModelViewMode] = useState("current_team"); const [currentTeam, setCurrentTeam] = useState("personal"); const [showFilters, setShowFilters] = useState(false); @@ -48,7 +50,26 @@ const AllModelsTab = ({ pageSize: 50, }); - const { data: rawModelData, isLoading: isLoadingModelsInfo } = useModelsInfo(currentPage, pageSize); + // Debounce search input + const debouncedUpdateSearch = useMemo( + () => + debounce((value: string) => { + setDebouncedSearch(value); + // Reset to page 1 when search changes + setCurrentPage(1); + setPagination((prev: PaginationState) => ({ ...prev, pageIndex: 0 })); + }, 200), + [] + ); + + useEffect(() => { + debouncedUpdateSearch(modelNameSearch); + return () => { + debouncedUpdateSearch.cancel(); + }; + }, [modelNameSearch, debouncedUpdateSearch]); + + const { data: rawModelData, isLoading: isLoadingModelsInfo } = useModelsInfo(currentPage, pageSize, debouncedSearch || undefined); const isLoading = isLoadingModelsInfo || isLoadingModelCostMap; const getProviderFromModel = (model: string) 
=> { @@ -88,10 +109,8 @@ const AllModelsTab = ({ return []; } + // Server-side search is now handled by the API, so we only filter by other criteria return modelData.data.filter((model: any) => { - const searchMatch = - modelNameSearch === "" || model.model_name.toLowerCase().includes(modelNameSearch.toLowerCase()); - const modelNameMatch = selectedModelGroup === "all" || model.model_name === selectedModelGroup || @@ -120,13 +139,13 @@ const AllModelsTab = ({ } } - return searchMatch && modelNameMatch && accessGroupMatch && teamAccessMatch; + return modelNameMatch && accessGroupMatch && teamAccessMatch; }); - }, [modelData, modelNameSearch, selectedModelGroup, selectedModelAccessGroupFilter, currentTeam, modelViewMode]); + }, [modelData, selectedModelGroup, selectedModelAccessGroupFilter, currentTeam, modelViewMode]); useEffect(() => { setPagination((prev: PaginationState) => ({ ...prev, pageIndex: 0 })); - }, [modelNameSearch, selectedModelGroup, selectedModelAccessGroupFilter, currentTeam, modelViewMode]); + }, [selectedModelGroup, selectedModelAccessGroupFilter, currentTeam, modelViewMode]); const resetFilters = () => { setModelNameSearch(""); @@ -354,8 +373,8 @@ const AllModelsTab = ({ ) : ( - {filteredData.length > 0 - ? `Showing 1 - ${filteredData.length} of ${filteredData.length} results` + {paginationMeta.total_count > 0 + ? 
`Showing ${((currentPage - 1) * pageSize) + 1} - ${Math.min(currentPage * pageSize, paginationMeta.total_count)} of ${paginationMeta.total_count} results` : "Showing 0 results"} )} diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index a4f2dfce3bd..fd55a8e067d 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -2007,17 +2007,20 @@ export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: st let ModelListerrorShown = false; let errorTimer: NodeJS.Timeout | null = null; -export const modelInfoCall = async (accessToken: string, userID: string, userRole: string, page: number = 1, size: number = 50) => { +export const modelInfoCall = async (accessToken: string, userID: string, userRole: string, page: number = 1, size: number = 50, search?: string) => { /** * Get all models on proxy */ try { - console.log("modelInfoCall:", accessToken, userID, userRole, page, size); + console.log("modelInfoCall:", accessToken, userID, userRole, page, size, search); let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`; const params = new URLSearchParams(); params.append("include_team_models", "true"); params.append("page", page.toString()); params.append("size", size.toString()); + if (search && search.trim()) { + params.append("search", search.trim()); + } if (params.toString()) { url += `?${params.toString()}`; }