Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
208 changes: 193 additions & 15 deletions litellm/proxy/proxy_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7892,6 +7892,120 @@ def _paginate_models_response(
}


def _get_model_info_dict(model: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``model["model_info"]``, or an empty dict when it is missing or None."""
    return model.get("model_info") or {}


def _get_router_model_ids_for_team(
    team_object: LiteLLM_TeamTable,
    team_id: str,
    llm_router: Optional[Router],
) -> Set[str]:
    """Collect the ids of router deployments that the team is allowed to use."""
    model_ids: Set[str] = set()
    if not llm_router:
        return model_ids

    team_models = team_object.models
    has_all_model_access = (
        not team_models  # empty list = all model access
        or SpecialModelNames.all_proxy_models.value in team_models
    )

    if has_all_model_access:
        for deployment in llm_router.get_model_list() or []:
            info = _get_model_info_dict(deployment)
            model_id = info.get("id")
            if model_id is None:
                continue
            # A deployment pinned to a team is only visible to that team;
            # deployments without a team_id are visible to every team.
            owner_team_id = info.get("team_id")
            if owner_team_id is None or owner_team_id == team_id:
                model_ids.add(model_id)
        return model_ids

    # Team has access to specific models only.
    for model_name in team_models:
        deployments = llm_router.get_model_list(
            model_name=model_name, team_id=team_id
        )
        for deployment in deployments or []:
            model_id = _get_model_info_dict(deployment).get("id")
            if model_id is not None:
                model_ids.add(model_id)
    return model_ids


async def _get_db_model_ids_for_team(
    team_object: LiteLLM_TeamTable,
    team_id: str,
    prisma_client: PrismaClient,
) -> Set[str]:
    """Collect the ids of database-stored models whose name is in the team's model list."""
    model_ids: Set[str] = set()
    team_models = team_object.models
    if (
        not team_models
        or SpecialModelNames.all_proxy_models.value in team_models
    ):
        # The database is only queried for teams with an explicit model list.
        return model_ids

    try:
        db_models = await prisma_client.db.litellm_proxymodeltable.find_many(
            where={"model_name": {"in": team_models}}
        )
        model_ids.update(
            db_model.model_id for db_model in db_models if db_model.model_id
        )
    except Exception as e:
        # Best effort: a failed lookup only means fewer ids are collected here.
        verbose_proxy_logger.debug(
            "Error querying database models for team %s: %s", team_id, e
        )
    return model_ids


def _is_model_visible_to_team(
    model_info: Dict[str, Any],
    team_id: str,
    team_accessible_model_ids: Set[str],
) -> bool:
    """Return True when the model passes any of the team-visibility checks."""
    if model_info.get("direct_access", False):
        return True

    access_via_team_ids = model_info.get("access_via_team_ids", [])
    if isinstance(access_via_team_ids, list) and team_id in access_via_team_ids:
        return True

    # Catches models whose model_info does not list the team in
    # access_via_team_ids but whose id was found via the router/database search.
    model_id = model_info.get("id")
    return bool(model_id) and model_id in team_accessible_model_ids


async def _filter_models_by_team_id(
    all_models: List[Dict[str, Any]],
    team_id: str,
    prisma_client: PrismaClient,
    llm_router: Router,
) -> List[Dict[str, Any]]:
    """
    Filter models by team ID.

    A model is kept when any of the following holds for its ``model_info``:
    - ``direct_access`` is truthy, OR
    - ``team_id`` is in the ``access_via_team_ids`` list, OR
    - its ``id`` is among the model ids found for the team in the router
      (config) or in the database model table.

    Args:
        all_models: List of models to filter. The ``direct_access`` and
            ``access_via_team_ids`` fields are read from each model's
            ``model_info``, so they must be populated before this call.
        team_id: Team ID to filter by
        prisma_client: Prisma client for database queries
        llm_router: Router instance for config queries (may be falsy, in
            which case no router search is done)

    Returns:
        Filtered list of models, in their original order. An empty list is
        returned when there is nothing to filter, when the team does not
        exist, or when fetching the team fails.
    """
    if not all_models:
        # Nothing to filter: skip the team lookup and the router/database scans.
        return []

    # Get team from database
    try:
        team_db_object = await prisma_client.db.litellm_teamtable.find_unique(
            where={"team_id": team_id}
        )
        if team_db_object is None:
            verbose_proxy_logger.warning(
                "Team %s not found in database", team_id
            )
            # If team doesn't exist, return empty list
            return []

        team_object = LiteLLM_TeamTable(**team_db_object.model_dump())
    except Exception as e:
        verbose_proxy_logger.exception("Error fetching team %s: %s", team_id, e)
        return []

    # Model ids accessible to this team, from the router and then the database.
    team_accessible_model_ids = _get_router_model_ids_for_team(
        team_object=team_object, team_id=team_id, llm_router=llm_router
    )
    team_accessible_model_ids |= await _get_db_model_ids_for_team(
        team_object=team_object, team_id=team_id, prisma_client=prisma_client
    )

    return [
        _model
        for _model in all_models
        if _is_model_visible_to_team(
            model_info=_get_model_info_dict(_model),
            team_id=team_id,
            team_accessible_model_ids=team_accessible_model_ids,
        )
    ]


@router.get(
"/v2/model/info",
description="v2 - returns models available to the user based on their API key permissions. Shows model info from config.yaml (except api key and api base). Filter to just user-added models with ?user_models_only=true",
Expand All @@ -7916,6 +8030,12 @@ async def model_info_v2(
search: Optional[str] = fastapi.Query(
None, description="Search model names (case-insensitive partial match)"
),
modelId: Optional[str] = fastapi.Query(
None, description="Search for a specific model by its unique ID"
),
teamId: Optional[str] = fastapi.Query(
None, description="Filter models by team ID. Returns models with direct_access=True or teamId in access_via_team_ids"
),
):
"""
BETA ENDPOINT. Might change unexpectedly. Use `/v1/model/info` for now.
Expand All @@ -7940,24 +8060,65 @@ async def model_info_v2(

# Load existing config
await proxy_config.get_config()
all_models = copy.deepcopy(llm_router.model_list)

if user_model is not None:
# if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
all_models += [user_model]
# If modelId is provided, search for the specific model
if modelId is not None:
found_model = None

# First, search in config
if llm_router is not None:
found_model = llm_router.get_model_info(id=modelId)
if found_model:
found_model = copy.deepcopy(found_model)

# If not found in config, search in database
if found_model is None:
try:
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": modelId}
)
if db_model:
# Convert database model to router format
decrypted_models = proxy_config.decrypt_model_list_from_db([db_model])
if decrypted_models:
found_model = decrypted_models[0]
except Exception as e:
verbose_proxy_logger.exception(
f"Error querying database for modelId {modelId}: {str(e)}"
)

# If model found, verify search filter if provided
if found_model is not None:
if search is not None and search.strip():
search_lower = search.lower().strip()
model_name = found_model.get("model_name", "")
if search_lower not in model_name.lower():
# Model found but doesn't match search filter
found_model = None

# Set all_models to the found model or empty list
all_models = [found_model] if found_model is not None else []
search_total_count: Optional[int] = len(all_models)
else:
# Normal flow when modelId is not provided
all_models = copy.deepcopy(llm_router.model_list)

if model is not None:
all_models = [m for m in all_models if m["model_name"] == model]
if user_model is not None:
# if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
all_models += [user_model]

# Apply search filter if provided
all_models, search_total_count = await _apply_search_filter_to_models(
all_models=all_models,
search=search or "",
page=page,
size=size,
prisma_client=prisma_client,
proxy_config=proxy_config,
)
if model is not None:
all_models = [m for m in all_models if m["model_name"] == model]

# Apply search filter if provided
all_models, search_total_count = await _apply_search_filter_to_models(
all_models=all_models,
search=search or "",
page=page,
size=size,
prisma_client=prisma_client,
proxy_config=proxy_config,
)

if user_models_only:
all_models = await non_admin_all_models(
Expand All @@ -7976,10 +8137,27 @@ async def model_info_v2(
)

# Fill in model info based on config.yaml and litellm model_prices_and_context_window.json
# This must happen before teamId filtering so that direct_access and access_via_team_ids are populated
for i, _model in enumerate(all_models):
all_models[i] = _enrich_model_info_with_litellm_data(
model=_model, debug=debug if debug is not None else False, llm_router=llm_router
)

# Apply teamId filter if provided
if teamId is not None and teamId.strip():
all_models = await _filter_models_by_team_id(
all_models=all_models,
team_id=teamId.strip(),
prisma_client=prisma_client,
llm_router=llm_router,
)
# Update search_total_count after teamId filter is applied
search_total_count = len(all_models)

# If modelId was provided, update search_total_count after filters are applied
# to ensure pagination reflects the final filtered result (0 or 1)
if modelId is not None:
search_total_count = len(all_models)

verbose_proxy_logger.debug("all_models: %s", all_models)

Expand Down
Loading
Loading