From 4f24533039728dd31cc33391e88e01624eb69ab7 Mon Sep 17 00:00:00 2001 From: hzt <3061613175@qq.com> Date: Fri, 20 Feb 2026 13:52:52 +0800 Subject: [PATCH] fix(proxy): add model_ids param to access group endpoints for precise deployment tagging POST /access_group/new and PUT /access_group/{name}/update now accept an optional model_ids list that targets specific deployments by their unique model_id, instead of tagging every deployment that shares a model_name. When model_ids is provided it takes priority over model_names, giving API callers the same single-deployment precision that the UI already has via PATCH /model/{model_id}/update. Backward compatible: model_names continues to work as before. Closes #21544 --- ...model_access_group_management_endpoints.py | 153 +++++++++---- .../model_management_endpoints.py | 9 +- .../test_access_group_management.py | 210 ++++++++++++++++++ 3 files changed, 329 insertions(+), 43 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py b/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py index 0c820f6b7891..dc80d6eb6eea 100644 --- a/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_access_group_management_endpoints.py @@ -6,7 +6,7 @@ """ import json -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from fastapi import APIRouter, Depends, HTTPException @@ -141,6 +141,50 @@ async def update_deployments_with_access_group( return models_updated +async def update_specific_deployments_with_access_group( + model_ids: List[str], + access_group: str, + prisma_client: PrismaClient, +) -> int: + """ + Update specific deployments (by model_id) to include the access group. + + Unlike update_deployments_with_access_group which tags ALL deployments sharing + a model_name, this function only tags the specific deployments identified by + their unique model_id. + """ + models_updated = 0 + for model_id in model_ids: + verbose_proxy_logger.debug( + f"Updating specific deployment model_id: {model_id}" + ) + deployment = await prisma_client.db.litellm_proxymodeltable.find_unique( + where={"model_id": model_id} + ) + if deployment is None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Deployment with model_id '{model_id}' not found in Database." + }, + ) + model_info = deployment.model_info or {} + updated_model_info, was_modified = add_access_group_to_deployment( + model_info=model_info, + access_group=access_group, + ) + if was_modified: + await prisma_client.db.litellm_proxymodeltable.update( + where={"model_id": model_id}, + data={"model_info": json.dumps(updated_model_info)}, + ) + models_updated += 1 + verbose_proxy_logger.debug( + f"Updated deployment {model_id} with access group: {access_group}" + ) + return models_updated + + def remove_access_group_from_deployment( model_info: Dict[str, Any], access_group: str ) -> Tuple[Dict[str, Any], bool]: @@ -263,24 +307,31 @@ async def create_model_group( detail={"error": "access_group is required and cannot be empty"}, ) - # Validation: Check if model_names list is provided and not empty - if not data.model_names or len(data.model_names) == 0: + # Validation: Check that at least one of model_names or model_ids is provided + has_model_names = data.model_names and len(data.model_names) > 0 + has_model_ids = data.model_ids and len(data.model_ids) > 0 + + if not has_model_names and not has_model_ids: raise HTTPException( status_code=400, - detail={"error": "model_names list is required and cannot be empty"}, + detail={"error": "Either model_names or model_ids must be provided and non-empty"}, ) - - # Validation: Check if all models exist in the router - all_valid, missing_models = validate_models_exist( - model_names=data.model_names, - llm_router=llm_router, - ) - - if not all_valid: - raise HTTPException( - status_code=400, - detail={"error": f"Model(s) not found: {', '.join(missing_models)}"}, + + # If model_ids is provided, use it (more precise targeting) + use_model_ids = has_model_ids + + # Validate model_names exist in router (only if using model_names path) + if not use_model_ids and has_model_names: + all_valid, missing_models = validate_models_exist( + model_names=data.model_names, + llm_router=llm_router, ) + + if not all_valid: + raise HTTPException( + status_code=400, + detail={"error": f"Model(s) not found: {', '.join(missing_models)}"}, + ) # Check if database is connected if prisma_client is None: @@ -301,12 +352,19 @@ async def create_model_group( detail={"error": f"Access group '{data.access_group}' already exists. Use PUT /access_group/{data.access_group}/update to modify it."}, ) - # Update deployments using helper function - models_updated = await update_deployments_with_access_group( - model_names=data.model_names, - access_group=data.access_group, - prisma_client=prisma_client, - ) + # Update deployments using the appropriate method + if use_model_ids: + models_updated = await update_specific_deployments_with_access_group( + model_ids=data.model_ids, + access_group=data.access_group, + prisma_client=prisma_client, + ) + else: + models_updated = await update_deployments_with_access_group( + model_names=data.model_names, + access_group=data.access_group, + prisma_client=prisma_client, + ) await clear_cache() @@ -317,6 +375,7 @@ async def create_model_group( return NewModelGroupResponse( access_group=data.access_group, model_names=data.model_names, + model_ids=data.model_ids, models_updated=models_updated, ) @@ -496,12 +555,17 @@ async def update_access_group( f"Updating access group: {access_group} with models: {data.model_names}" ) - # Validation: Check if model_names list is provided and not empty - if not data.model_names or len(data.model_names) == 0: + # Validation: Check that at least one of model_names or model_ids is provided + has_model_names = data.model_names and len(data.model_names) > 0 + has_model_ids = data.model_ids and len(data.model_ids) > 0 + + if not has_model_names and not has_model_ids: raise HTTPException( status_code=400, - detail={"error": "model_names list is required and cannot be empty"}, + detail={"error": "Either model_names or model_ids must be provided and non-empty"}, ) + + use_model_ids = has_model_ids # Validation: Check if access group exists try: @@ -521,17 +585,18 @@ async def update_access_group( detail={"error": f"Failed to check access group existence: {str(e)}"}, ) - # Validation: Check if all new models exist - all_valid, missing_models = validate_models_exist( - model_names=data.model_names, - llm_router=llm_router, - ) - - if not all_valid: - raise HTTPException( - status_code=400, - detail={"error": f"Model(s) not found: {', '.join(missing_models)}"}, + # Validation: Check if all new models exist (only if using model_names path) + if not use_model_ids and has_model_names: + all_valid, missing_models = validate_models_exist( + model_names=data.model_names, + llm_router=llm_router, ) + + if not all_valid: + raise HTTPException( + status_code=400, + detail={"error": f"Model(s) not found: {', '.join(missing_models)}"}, + ) try: # Step 1: Remove access group from ALL DB deployments (skip config models) @@ -552,12 +617,19 @@ async def update_access_group( data={"model_info": json.dumps(updated_model_info)}, ) - # Step 2: Add access group to new model_names - models_updated = await update_deployments_with_access_group( - model_names=data.model_names, - access_group=access_group, - prisma_client=prisma_client, - ) + # Step 2: Add access group using the appropriate method + if use_model_ids: + models_updated = await update_specific_deployments_with_access_group( + model_ids=data.model_ids, + access_group=access_group, + prisma_client=prisma_client, + ) + else: + models_updated = await update_deployments_with_access_group( + model_names=data.model_names, + access_group=access_group, + prisma_client=prisma_client, + ) # Clear cache and reload models to pick up the access group changes await clear_cache() @@ -569,6 +641,7 @@ async def update_access_group( return NewModelGroupResponse( access_group=access_group, model_names=data.model_names, + model_ids=data.model_ids, models_updated=models_updated, ) diff --git a/litellm/types/proxy/management_endpoints/model_management_endpoints.py b/litellm/types/proxy/management_endpoints/model_management_endpoints.py index c488c46ecc2c..6f07e5c6de08 100644 --- a/litellm/types/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/types/proxy/management_endpoints/model_management_endpoints.py @@ -21,17 +21,20 @@ class UpdateUsefulLinksRequest(BaseModel): class NewModelGroupRequest(BaseModel): access_group: str # The access group name (e.g., "production-models") - model_names: List[str] # Existing model groups to include (e.g., ["gpt-4", "claude-3"]) + model_names: Optional[List[str]] = None # Existing model groups to include - tags ALL deployments for each name + model_ids: Optional[List[str]] = None # Specific deployment IDs to tag (more precise than model_names) class NewModelGroupResponse(BaseModel): access_group: str - model_names: List[str] + model_names: Optional[List[str]] = None + model_ids: Optional[List[str]] = None models_updated: int # Number of models updated class UpdateModelGroupRequest(BaseModel): - model_names: List[str] # Updated list of model groups to include + model_names: Optional[List[str]] = None # Updated list of model groups to include - tags ALL deployments for each name + model_ids: Optional[List[str]] = None # Specific deployment IDs to tag (more precise than model_names) class DeleteModelGroupResponse(BaseModel): diff --git a/tests/test_litellm/proxy/management_endpoints/test_access_group_management.py b/tests/test_litellm/proxy/management_endpoints/test_access_group_management.py index 1846ffaeb662..18dcb2b0b2d3 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_access_group_management.py +++ b/tests/test_litellm/proxy/management_endpoints/test_access_group_management.py @@ -78,3 +78,213 @@ async def test_create_duplicate_access_group_fails(): assert exc_info.value.status_code == 409 assert "already exists" in str(exc_info.value.detail) +@pytest.mark.asyncio +async def test_create_access_group_with_model_ids_tags_only_specific_deployments(): + """ + Test that using model_ids only tags the specific deployments, not all + deployments sharing the same model_name. + + Fixes: https://github.com/BerriAI/litellm/issues/21544 + """ + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.management_endpoints.model_access_group_management_endpoints import ( + create_model_group, + ) + from litellm.types.proxy.management_endpoints.model_management_endpoints import ( + NewModelGroupRequest, + ) + + deploy_a = MagicMock(model_id="deploy-A", model_name="gpt-4o", model_info={}) + + mock_prisma = MagicMock() + mock_prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[]) + mock_prisma.db.litellm_proxymodeltable.find_unique = AsyncMock(return_value=deploy_a) + mock_prisma.db.litellm_proxymodeltable.update = AsyncMock() + + mock_user = UserAPIKeyAuth( + user_id="test_admin", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + request_data = NewModelGroupRequest( + access_group="production-models", + model_ids=["deploy-A"], + ) + + with patch("litellm.proxy.proxy_server.llm_router", MagicMock()), \ + patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), \ + patch( + "litellm.proxy.management_endpoints.model_access_group_management_endpoints.clear_cache", + new_callable=AsyncMock, + ): + response = await create_model_group(data=request_data, user_api_key_dict=mock_user) + + assert response.models_updated == 1 + assert response.model_ids == ["deploy-A"] + mock_prisma.db.litellm_proxymodeltable.find_unique.assert_called_once_with( + where={"model_id": "deploy-A"} + ) + assert mock_prisma.db.litellm_proxymodeltable.update.call_count == 1 + update_call = mock_prisma.db.litellm_proxymodeltable.update.call_args + assert update_call.kwargs["where"] == {"model_id": "deploy-A"} + + +@pytest.mark.asyncio +async def test_create_access_group_with_model_names_tags_all_deployments(): + """ + Test backward compat: model_names still tags ALL deployments sharing that model_name. + """ + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.management_endpoints.model_access_group_management_endpoints import ( + create_model_group, + ) + from litellm.types.proxy.management_endpoints.model_management_endpoints import ( + NewModelGroupRequest, + ) + + deploy_a = MagicMock(model_id="deploy-A", model_name="gpt-4o", model_info={}) + deploy_b = MagicMock(model_id="deploy-B", model_name="gpt-4o", model_info={}) + deploy_c = MagicMock(model_id="deploy-C", model_name="gpt-4o", model_info={}) + + mock_router = Router( + model_list=[{"model_name": "gpt-4o", "litellm_params": {"model": "gpt-4o", "api_key": "fake-key"}}] + ) + + mock_prisma = MagicMock() + mock_prisma.db.litellm_proxymodeltable.find_many = AsyncMock( + side_effect=[[], [deploy_a, deploy_b, deploy_c]] + ) + mock_prisma.db.litellm_proxymodeltable.update = AsyncMock() + + mock_user = UserAPIKeyAuth( + user_id="test_admin", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + request_data = NewModelGroupRequest(access_group="production-models", model_names=["gpt-4o"]) + + with patch("litellm.proxy.proxy_server.llm_router", mock_router), \ + patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), \ + patch( + "litellm.proxy.management_endpoints.model_access_group_management_endpoints.clear_cache", + new_callable=AsyncMock, + ): + response = await create_model_group(data=request_data, user_api_key_dict=mock_user) + + assert response.models_updated == 3 + assert response.model_names == ["gpt-4o"] + assert mock_prisma.db.litellm_proxymodeltable.update.call_count == 3 + + +@pytest.mark.asyncio +async def test_create_access_group_model_ids_takes_priority_over_model_names(): + """ + Test that when both model_ids and model_names are provided, model_ids is used. + """ + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.management_endpoints.model_access_group_management_endpoints import ( + create_model_group, + ) + from litellm.types.proxy.management_endpoints.model_management_endpoints import ( + NewModelGroupRequest, + ) + + deploy_a = MagicMock(model_id="deploy-A", model_name="gpt-4o", model_info={}) + + mock_prisma = MagicMock() + mock_prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[]) + mock_prisma.db.litellm_proxymodeltable.find_unique = AsyncMock(return_value=deploy_a) + mock_prisma.db.litellm_proxymodeltable.update = AsyncMock() + + mock_user = UserAPIKeyAuth( + user_id="test_admin", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + request_data = NewModelGroupRequest( + access_group="production-models", + model_names=["gpt-4o"], + model_ids=["deploy-A"], + ) + + with patch("litellm.proxy.proxy_server.llm_router", MagicMock()), \ + patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), \ + patch( + "litellm.proxy.management_endpoints.model_access_group_management_endpoints.clear_cache", + new_callable=AsyncMock, + ): + response = await create_model_group(data=request_data, user_api_key_dict=mock_user) + + assert response.models_updated == 1 + mock_prisma.db.litellm_proxymodeltable.find_unique.assert_called_once_with( + where={"model_id": "deploy-A"} + ) + + +@pytest.mark.asyncio +async def test_create_access_group_requires_model_names_or_model_ids(): + """ + Test that creating an access group without model_names or model_ids fails. + """ + from fastapi import HTTPException + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.management_endpoints.model_access_group_management_endpoints import ( + create_model_group, + ) + from litellm.types.proxy.management_endpoints.model_management_endpoints import ( + NewModelGroupRequest, + ) + + mock_user = UserAPIKeyAuth( + user_id="test_admin", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + request_data = NewModelGroupRequest(access_group="production-models") + + with patch("litellm.proxy.proxy_server.llm_router", MagicMock()), \ + patch("litellm.proxy.proxy_server.prisma_client", MagicMock()): + with pytest.raises(HTTPException) as exc_info: + await create_model_group(data=request_data, user_api_key_dict=mock_user) + assert exc_info.value.status_code == 400 + assert "model_names or model_ids" in str(exc_info.value.detail) + + +@pytest.mark.asyncio +async def test_create_access_group_invalid_model_id_returns_400(): + """ + Test that passing a non-existent model_id returns 400 error. + """ + from fastapi import HTTPException + from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth + from litellm.proxy.management_endpoints.model_access_group_management_endpoints import ( + create_model_group, + ) + from litellm.types.proxy.management_endpoints.model_management_endpoints import ( + NewModelGroupRequest, + ) + + mock_prisma = MagicMock() + mock_prisma.db.litellm_proxymodeltable.find_many = AsyncMock(return_value=[]) + mock_prisma.db.litellm_proxymodeltable.find_unique = AsyncMock(return_value=None) + + mock_user = UserAPIKeyAuth( + user_id="test_admin", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + request_data = NewModelGroupRequest( + access_group="production-models", + model_ids=["non-existent-id"], + ) + + with patch("litellm.proxy.proxy_server.llm_router", MagicMock()), \ + patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), \ + patch( + "litellm.proxy.management_endpoints.model_access_group_management_endpoints.clear_cache", + new_callable=AsyncMock, + ): + with pytest.raises(HTTPException) as exc_info: + await create_model_group(data=request_data, user_api_key_dict=mock_user) + assert exc_info.value.status_code == 400 + assert "non-existent-id" in str(exc_info.value.detail)