Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

import json
from typing import Any, Dict, List, Tuple
from typing import Any, Dict, List, Optional, Tuple

from fastapi import APIRouter, Depends, HTTPException

Expand Down Expand Up @@ -141,6 +141,50 @@ async def update_deployments_with_access_group(
return models_updated


async def update_specific_deployments_with_access_group(
model_ids: List[str],
access_group: str,
prisma_client: PrismaClient,
) -> int:
"""
Update specific deployments (by model_id) to include the access group.

Unlike update_deployments_with_access_group which tags ALL deployments sharing
a model_name, this function only tags the specific deployments identified by
their unique model_id.
"""
models_updated = 0
for model_id in model_ids:
verbose_proxy_logger.debug(
f"Updating specific deployment model_id: {model_id}"
)
deployment = await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": model_id}
)
if deployment is None:
raise HTTPException(
status_code=400,
detail={
"error": f"Deployment with model_id '{model_id}' not found in Database."
},
)
model_info = deployment.model_info or {}
updated_model_info, was_modified = add_access_group_to_deployment(
model_info=model_info,
access_group=access_group,
)
if was_modified:
await prisma_client.db.litellm_proxymodeltable.update(
where={"model_id": model_id},
data={"model_info": json.dumps(updated_model_info)},
)
models_updated += 1
verbose_proxy_logger.debug(
f"Updated deployment {model_id} with access group: {access_group}"
)
return models_updated


def remove_access_group_from_deployment(
model_info: Dict[str, Any], access_group: str
) -> Tuple[Dict[str, Any], bool]:
Expand Down Expand Up @@ -263,24 +307,31 @@ async def create_model_group(
detail={"error": "access_group is required and cannot be empty"},
)

# Validation: Check if model_names list is provided and not empty
if not data.model_names or len(data.model_names) == 0:
# Validation: Check that at least one of model_names or model_ids is provided
has_model_names = data.model_names and len(data.model_names) > 0
has_model_ids = data.model_ids and len(data.model_ids) > 0

if not has_model_names and not has_model_ids:
raise HTTPException(
status_code=400,
detail={"error": "model_names list is required and cannot be empty"},
detail={"error": "Either model_names or model_ids must be provided and non-empty"},
)

# Validation: Check if all models exist in the router
all_valid, missing_models = validate_models_exist(
model_names=data.model_names,
llm_router=llm_router,
)

if not all_valid:
raise HTTPException(
status_code=400,
detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},

# If model_ids is provided, use it (more precise targeting)
use_model_ids = has_model_ids

# Validate model_names exist in router (only if using model_names path)
if not use_model_ids and has_model_names:
all_valid, missing_models = validate_models_exist(
model_names=data.model_names,
llm_router=llm_router,
)

if not all_valid:
raise HTTPException(
status_code=400,
detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
)

# Check if database is connected
if prisma_client is None:
Expand All @@ -301,12 +352,19 @@ async def create_model_group(
detail={"error": f"Access group '{data.access_group}' already exists. Use PUT /access_group/{data.access_group}/update to modify it."},
)

# Update deployments using helper function
models_updated = await update_deployments_with_access_group(
model_names=data.model_names,
access_group=data.access_group,
prisma_client=prisma_client,
)
# Update deployments using the appropriate method
if use_model_ids:
models_updated = await update_specific_deployments_with_access_group(
model_ids=data.model_ids,
access_group=data.access_group,
prisma_client=prisma_client,
)
else:
models_updated = await update_deployments_with_access_group(
model_names=data.model_names,
access_group=data.access_group,
prisma_client=prisma_client,
)

await clear_cache()

Expand All @@ -317,6 +375,7 @@ async def create_model_group(
return NewModelGroupResponse(
access_group=data.access_group,
model_names=data.model_names,
model_ids=data.model_ids,
models_updated=models_updated,
)

Expand Down Expand Up @@ -496,12 +555,17 @@ async def update_access_group(
f"Updating access group: {access_group} with models: {data.model_names}"
)

# Validation: Check if model_names list is provided and not empty
if not data.model_names or len(data.model_names) == 0:
# Validation: Check that at least one of model_names or model_ids is provided
has_model_names = data.model_names and len(data.model_names) > 0
has_model_ids = data.model_ids and len(data.model_ids) > 0

if not has_model_names and not has_model_ids:
raise HTTPException(
status_code=400,
detail={"error": "model_names list is required and cannot be empty"},
detail={"error": "Either model_names or model_ids must be provided and non-empty"},
)

use_model_ids = has_model_ids

# Validation: Check if access group exists
try:
Expand All @@ -521,17 +585,18 @@ async def update_access_group(
detail={"error": f"Failed to check access group existence: {str(e)}"},
)

# Validation: Check if all new models exist
all_valid, missing_models = validate_models_exist(
model_names=data.model_names,
llm_router=llm_router,
)

if not all_valid:
raise HTTPException(
status_code=400,
detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
# Validation: Check if all new models exist (only if using model_names path)
if not use_model_ids and has_model_names:
all_valid, missing_models = validate_models_exist(
model_names=data.model_names,
llm_router=llm_router,
)

if not all_valid:
raise HTTPException(
status_code=400,
detail={"error": f"Model(s) not found: {', '.join(missing_models)}"},
)

try:
# Step 1: Remove access group from ALL DB deployments (skip config models)
Expand All @@ -552,12 +617,19 @@ async def update_access_group(
data={"model_info": json.dumps(updated_model_info)},
)

# Step 2: Add access group to new model_names
models_updated = await update_deployments_with_access_group(
model_names=data.model_names,
access_group=access_group,
prisma_client=prisma_client,
)
# Step 2: Add access group using the appropriate method
if use_model_ids:
models_updated = await update_specific_deployments_with_access_group(
model_ids=data.model_ids,
access_group=access_group,
prisma_client=prisma_client,
)
else:
models_updated = await update_deployments_with_access_group(
model_names=data.model_names,
access_group=access_group,
prisma_client=prisma_client,
)

# Clear cache and reload models to pick up the access group changes
await clear_cache()
Expand All @@ -569,6 +641,7 @@ async def update_access_group(
return NewModelGroupResponse(
access_group=access_group,
model_names=data.model_names,
model_ids=data.model_ids,
models_updated=models_updated,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,20 @@ class UpdateUsefulLinksRequest(BaseModel):

class NewModelGroupRequest(BaseModel):
access_group: str # The access group name (e.g., "production-models")
model_names: List[str] # Existing model groups to include (e.g., ["gpt-4", "claude-3"])
model_names: Optional[List[str]] = None # Existing model groups to include - tags ALL deployments for each name
model_ids: Optional[List[str]] = None # Specific deployment IDs to tag (more precise than model_names)


class NewModelGroupResponse(BaseModel):
access_group: str
model_names: List[str]
model_names: Optional[List[str]] = None
model_ids: Optional[List[str]] = None
models_updated: int # Number of models updated


class UpdateModelGroupRequest(BaseModel):
model_names: List[str] # Updated list of model groups to include
model_names: Optional[List[str]] = None # Updated list of model groups to include - tags ALL deployments for each name
model_ids: Optional[List[str]] = None # Specific deployment IDs to tag (more precise than model_names)


class DeleteModelGroupResponse(BaseModel):
Expand Down
Loading
Loading