diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f74a89054c4..b1439712d7f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -550,9 +550,9 @@ def generate_feedback_box(): server_root_path = get_server_root_path() _license_check = LicenseCheck() premium_user: bool = _license_check.is_premium() -premium_user_data: Optional["EnterpriseLicenseData"] = ( - _license_check.airgapped_license_data -) +premium_user_data: Optional[ + "EnterpriseLicenseData" +] = _license_check.airgapped_license_data global_max_parallel_request_retries_env: Optional[str] = os.getenv( "LITELLM_GLOBAL_MAX_PARALLEL_REQUEST_RETRIES" ) @@ -825,6 +825,7 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915 title=_title, description=_description, version=version, + root_path=server_root_path, lifespan=proxy_startup_event, # type: ignore[reportGeneralTypeIssues] ) @@ -1209,9 +1210,9 @@ async def root_redirect(): config_agents: Optional[List[AgentConfig]] = None otel_logging = False prisma_client: Optional[PrismaClient] = None -shared_aiohttp_session: Optional["ClientSession"] = ( - None # Global shared session for connection reuse -) +shared_aiohttp_session: Optional[ + "ClientSession" +] = None # Global shared session for connection reuse user_api_key_cache = DualCache( default_in_memory_ttl=UserAPIKeyCacheTTLEnum.in_memory_cache_ttl.value ) @@ -1219,9 +1220,9 @@ async def root_redirect(): dual_cache=user_api_key_cache ) litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter) -redis_usage_cache: Optional[RedisCache] = ( - None # redis cache used for tracking spend, tpm/rpm limits -) +redis_usage_cache: Optional[ + RedisCache +] = None # redis cache used for tracking spend, tpm/rpm limits polling_via_cache_enabled: Union[Literal["all"], List[str], bool] = False polling_cache_ttl: int = 3600 # Default 1 hour TTL for polling cache user_custom_auth = None @@ -1560,9 +1561,9 @@ async def _update_team_cache(): _id = "team_id:{}".format(team_id) try: # Fetch the existing cost for the given user - existing_spend_obj: Optional[LiteLLM_TeamTable] = ( - await user_api_key_cache.async_get_cache(key=_id) - ) + existing_spend_obj: Optional[ + LiteLLM_TeamTable + ] = await user_api_key_cache.async_get_cache(key=_id) if existing_spend_obj is None: # do nothing if team not in api key cache return @@ -2856,6 +2857,7 @@ async def _init_policy_engine( from litellm.proxy.policy_engine.init_policies import init_policies from litellm.proxy.policy_engine.policy_validator import PolicyValidator + if config is None: verbose_proxy_logger.debug("Policy engine: config is None, skipping") return @@ -2867,7 +2869,9 @@ async def _init_policy_engine( policy_attachments_config = config.get("policy_attachments", None) - verbose_proxy_logger.info(f"Policy engine: found {len(policies_config)} policies in config") + verbose_proxy_logger.info( + f"Policy engine: found {len(policies_config)} policies in config" + ) # Initialize policies await init_policies( @@ -4009,10 +4013,10 @@ async def _init_guardrails_in_db(self, prisma_client: PrismaClient): ) try: - guardrails_in_db: List[Guardrail] = ( - await GuardrailRegistry.get_all_guardrails_from_db( - prisma_client=prisma_client - ) + guardrails_in_db: List[ + Guardrail + ] = await GuardrailRegistry.get_all_guardrails_from_db( + prisma_client=prisma_client ) verbose_proxy_logger.debug( "guardrails from the DB %s", str(guardrails_in_db) @@ -4046,7 +4050,9 @@ async def _init_policies_in_db(self, prisma_client: PrismaClient): await policy_registry.sync_policies_from_db(prisma_client=prisma_client) # Sync attachments from DB to in-memory registry - await attachment_registry.sync_attachments_from_db(prisma_client=prisma_client) + await attachment_registry.sync_attachments_from_db( + prisma_client=prisma_client + ) verbose_proxy_logger.debug( "Successfully synced policies and attachments from DB" @@ -4369,9 +4375,9 @@ async def initialize( # noqa: PLR0915 user_api_base = api_base dynamic_config[user_model]["api_base"] = api_base if api_version: - os.environ["AZURE_API_VERSION"] = ( - api_version # set this for azure - litellm can read this from the env - ) + os.environ[ + "AZURE_API_VERSION" + ] = api_version # set this for azure - litellm can read this from the env if max_tokens: # model-specific param dynamic_config[user_model]["max_tokens"] = max_tokens if temperature: # model-specific param @@ -5217,7 +5223,9 @@ async def model_list( # Include model access groups if requested if include_model_access_groups: - proxy_model_list = list(set(proxy_model_list + list(model_access_groups.keys()))) + proxy_model_list = list( + set(proxy_model_list + list(model_access_groups.keys())) + ) # Get complete model list including wildcard routes if requested from litellm.proxy.auth.model_checks import get_complete_model_list @@ -7674,12 +7682,12 @@ def _enrich_model_info_with_litellm_data( """ Enrich a model dictionary with litellm model info (pricing, context window, etc.) and remove sensitive information. - + Args: model: Model dictionary to enrich debug: Whether to include debug information like openai_client llm_router: Optional router instance for debug info - + Returns: Enriched model dictionary with sensitive info removed """ @@ -7689,9 +7697,7 @@ def _enrich_model_info_with_litellm_data( _openai_client = "None" if llm_router is not None: _openai_client = ( - llm_router._get_client( - deployment=model, kwargs={}, client_type="async" - ) + llm_router._get_client(deployment=model, kwargs={}, client_type="async") or "None" ) else: @@ -7749,7 +7755,7 @@ async def _apply_search_filter_to_models( ) -> Tuple[List[Dict[str, Any]], Optional[int]]: """ Apply search filter to models, querying database for additional matching models. - + Args: all_models: List of models to filter search: Search term (case-insensitive) @@ -7757,44 +7763,43 @@ async def _apply_search_filter_to_models( size: Page size prisma_client: Prisma client for database queries proxy_config: Proxy config for decrypting models - + Returns: Tuple of (filtered_models, total_count). total_count is None if not searching. """ if not search or not search.strip(): return all_models, None - + search_lower = search.lower().strip() - + # Filter models in router by search term filtered_router_models = [ - m for m in all_models - if search_lower in m.get("model_name", "").lower() + m for m in all_models if search_lower in m.get("model_name", "").lower() ] - + # Separate filtered models into config vs db models, and track db model IDs filtered_config_models = [] db_model_ids_in_router = set() - + for m in filtered_router_models: model_info = m.get("model_info", {}) is_db_model = model_info.get("db_model", False) model_id = model_info.get("id") - + if is_db_model and model_id: db_model_ids_in_router.add(model_id) else: filtered_config_models.append(m) - + config_models_count = len(filtered_config_models) db_models_in_router_count = len(db_model_ids_in_router) router_models_count = config_models_count + db_models_in_router_count - + # Query database for additional models with search term db_models = [] db_models_total_count = 0 models_needed_for_page = size * page - + # Only query database if prisma_client is available if prisma_client is not None: try: @@ -7810,31 +7815,36 @@ async def _apply_search_filter_to_models( db_where_condition["model_id"] = { "not": {"in": list(db_model_ids_in_router)} } - + # Get total count of matching database models - db_models_total_count = await prisma_client.db.litellm_proxymodeltable.count( - where=db_where_condition + db_models_total_count = ( + await prisma_client.db.litellm_proxymodeltable.count( + where=db_where_condition + ) ) - + # Calculate total count for search results search_total_count = router_models_count + db_models_total_count - + # Fetch database models if we need more for the current page if router_models_count < models_needed_for_page: models_to_fetch = min( - models_needed_for_page - router_models_count, - db_models_total_count + models_needed_for_page - router_models_count, db_models_total_count ) - + if models_to_fetch > 0: - db_models_raw = await prisma_client.db.litellm_proxymodeltable.find_many( - where=db_where_condition, - take=models_to_fetch, + db_models_raw = ( + await prisma_client.db.litellm_proxymodeltable.find_many( + where=db_where_condition, + take=models_to_fetch, + ) ) - + # Convert database models to router format for db_model in db_models_raw: - decrypted_models = proxy_config.decrypt_model_list_from_db([db_model]) + decrypted_models = proxy_config.decrypt_model_list_from_db( + [db_model] + ) if decrypted_models: db_models.extend(decrypted_models) except Exception as e: @@ -7846,7 +7856,7 @@ async def _apply_search_filter_to_models( else: # If no prisma_client, only use router models search_total_count = router_models_count - + # Combine all models filtered_models = filtered_router_models + db_models return filtered_models, search_total_count @@ -7861,28 +7871,28 @@ def _paginate_models_response( ) -> Dict[str, Any]: """ Paginate models and return response dictionary. - + Args: all_models: List of all models page: Current page number size: Page size total_count: Total count (if None, uses len(all_models)) search: Search term (for logging) - + Returns: Paginated response dictionary """ if total_count is None: total_count = len(all_models) - + skip = (page - 1) * size total_pages = -(-total_count // size) if total_count > 0 else 0 paginated_models = all_models[skip : skip + size] - + verbose_proxy_logger.debug( f"Pagination: skip={skip}, take={size}, total_count={total_count}, total_pages={total_pages}, search={search}" ) - + return { "data": paginated_models, "total_count": total_count, @@ -7902,15 +7912,15 @@ async def _filter_models_by_team_id( Filter models by team ID. Returns models where: - direct_access is True, OR - team_id is in access_via_team_ids - + Also searches config and database for models accessible to the team. - + Args: all_models: List of models to filter team_id: Team ID to filter by prisma_client: Prisma client for database queries llm_router: Router instance for config queries - + Returns: Filtered list of models """ @@ -7923,15 +7933,15 @@ async def _filter_models_by_team_id( verbose_proxy_logger.warning(f"Team {team_id} not found in database") # If team doesn't exist, return empty list return [] - + team_object = LiteLLM_TeamTable(**team_db_object.model_dump()) except Exception as e: verbose_proxy_logger.exception(f"Error fetching team {team_id}: {str(e)}") return [] - + # Get models accessible to this team (similar to _add_team_models_to_all_models) team_accessible_model_ids: Set[str] = set() - + if ( len(team_object.models) == 0 # empty list = all model access or SpecialModelNames.all_proxy_models.value in team_object.models @@ -7950,25 +7960,30 @@ async def _filter_models_by_team_id( can_add_model = True elif team_model_id == team_id: can_add_model = True - + if can_add_model: team_accessible_model_ids.add(model_id) else: # Team has access to specific models for model_name in team_object.models: - _models = llm_router.get_model_list( - model_name=model_name, team_id=team_id - ) if llm_router else [] + _models = ( + llm_router.get_model_list(model_name=model_name, team_id=team_id) + if llm_router + else [] + ) if _models is not None: for model in _models: model_id = model.get("model_info", {}).get("id", None) if model_id is not None: team_accessible_model_ids.add(model_id) - + # Also search database for models accessible to this team # This complements the config search done above try: - if team_object.models and SpecialModelNames.all_proxy_models.value not in team_object.models: + if ( + team_object.models + and SpecialModelNames.all_proxy_models.value not in team_object.models + ): # Team has specific models - check database for those model names db_models = await prisma_client.db.litellm_proxymodeltable.find_many( where={"model_name": {"in": team_object.models}} @@ -7978,31 +7993,33 @@ async def _filter_models_by_team_id( if model_id: team_accessible_model_ids.add(model_id) except Exception as e: - verbose_proxy_logger.debug(f"Error querying database models for team {team_id}: {str(e)}") - + verbose_proxy_logger.debug( + f"Error querying database models for team {team_id}: {str(e)}" + ) + # Filter models based on direct_access or access_via_team_ids # Models are already enriched with these fields before this function is called filtered_models = [] for _model in all_models: model_info = _model.get("model_info", {}) model_id = model_info.get("id", None) - + # Include if direct_access is True if model_info.get("direct_access", False): filtered_models.append(_model) continue - + # Include if team_id is in access_via_team_ids access_via_team_ids = model_info.get("access_via_team_ids", []) if isinstance(access_via_team_ids, list) and team_id in access_via_team_ids: filtered_models.append(_model) continue - + # Also include if model_id is in team_accessible_model_ids (from config/db search) # This catches models that might not have been enriched with access_via_team_ids yet if model_id and model_id in team_accessible_model_ids: filtered_models.append(_model) - + return filtered_models @@ -8034,7 +8051,8 @@ async def model_info_v2( None, description="Search for a specific model by its unique ID" ), teamId: Optional[str] = fastapi.Query( - None, description="Filter models by team ID. Returns models with direct_access=True or teamId in access_via_team_ids" + None, + description="Filter models by team ID. Returns models with direct_access=True or teamId in access_via_team_ids", ), ): """ @@ -8064,13 +8082,13 @@ async def model_info_v2( # If modelId is provided, search for the specific model if modelId is not None: found_model = None - + # First, search in config if llm_router is not None: found_model = llm_router.get_model_info(id=modelId) if found_model: found_model = copy.deepcopy(found_model) - + # If not found in config, search in database if found_model is None: try: @@ -8079,14 +8097,16 @@ async def model_info_v2( ) if db_model: # Convert database model to router format - decrypted_models = proxy_config.decrypt_model_list_from_db([db_model]) + decrypted_models = proxy_config.decrypt_model_list_from_db( + [db_model] + ) if decrypted_models: found_model = decrypted_models[0] except Exception as e: verbose_proxy_logger.exception( f"Error querying database for modelId {modelId}: {str(e)}" ) - + # If model found, verify search filter if provided if found_model is not None: if search is not None and search.strip(): @@ -8095,7 +8115,7 @@ async def model_info_v2( if search_lower not in model_name.lower(): # Model found but doesn't match search filter found_model = None - + # Set all_models to the found model or empty list all_models = [found_model] if found_model is not None else [] search_total_count: Optional[int] = len(all_models) @@ -8135,14 +8155,16 @@ async def model_info_v2( llm_router=llm_router, all_models=all_models, ) - + # Fill in model info based on config.yaml and litellm model_prices_and_context_window.json # This must happen before teamId filtering so that direct_access and access_via_team_ids are populated for i, _model in enumerate(all_models): all_models[i] = _enrich_model_info_with_litellm_data( - model=_model, debug=debug if debug is not None else False, llm_router=llm_router + model=_model, + debug=debug if debug is not None else False, + llm_router=llm_router, ) - + # Apply teamId filter if provided if teamId is not None and teamId.strip(): all_models = await _filter_models_by_team_id( @@ -8153,14 +8175,14 @@ async def model_info_v2( ) # Update search_total_count after teamId filter is applied search_total_count = len(all_models) - + # If modelId was provided, update search_total_count after filters are applied # to ensure pagination reflects the final filtered result (0 or 1) if modelId is not None: search_total_count = len(all_models) verbose_proxy_logger.debug("all_models: %s", all_models) - + return _paginate_models_response( all_models=all_models, page=page, @@ -10278,9 +10300,9 @@ async def get_config_list( hasattr(sub_field_info, "description") and sub_field_info.description is not None ): - nested_fields[idx].field_description = ( - sub_field_info.description - ) + nested_fields[ + idx + ].field_description = sub_field_info.description idx += 1 _stored_in_db = None diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 5dcb0339739..2d62bde22e5 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -982,7 +982,9 @@ async def _process_guardrail_callback( try: # Check if load balancing should be used - if guardrail_name and self._should_use_guardrail_load_balancing(guardrail_name): + if guardrail_name and self._should_use_guardrail_load_balancing( + guardrail_name + ): response = await self._execute_guardrail_with_load_balancing( guardrail_name=guardrail_name, hook_type="pre_call", @@ -1017,7 +1019,11 @@ async def _process_guardrail_callback( latency_seconds = guardrail_end_time - guardrail_start_time # Get guardrail name for metrics (fallback if not set) - metrics_guardrail_name = guardrail_name or getattr(callback, "guardrail_name", callback.__class__.__name__) or "unknown" + metrics_guardrail_name = ( + guardrail_name + or getattr(callback, "guardrail_name", callback.__class__.__name__) + or "unknown" + ) # Find PrometheusLogger in callbacks and record metrics for prom_callback in litellm.callbacks: @@ -2218,17 +2224,17 @@ async def _query_first_with_cached_plan_fallback( ) -> Optional[dict]: """ Execute a query with automatic fallback for PostgreSQL cached plan errors. - + This handles the "cached plan must not change result type" error that occurs during rolling deployments when schema changes are applied while old pods still have cached query plans expecting the old schema. - + Args: sql_query: SQL query string to execute - + Returns: Query result or None - + Raises: Original exception if not a cached plan error """ @@ -2241,7 +2247,7 @@ async def _query_first_with_cached_plan_fallback( # Add a unique comment to make the query different sql_query_retry = sql_query.replace( "SELECT", - f"SELECT /* cache_invalidated_{int(time.time() * 1000)} */" + f"SELECT /* cache_invalidated_{int(time.time() * 1000)} */", ) verbose_proxy_logger.warning( "PostgreSQL cached plan error detected for token lookup, " @@ -2583,7 +2589,9 @@ async def get_data( # noqa: PLR0915 WHERE v.token = '{token}' """ - response = await self._query_first_with_cached_plan_fallback(sql_query) + response = await self._query_first_with_cached_plan_fallback( + sql_query + ) if response is not None: if response["team_models"] is None: @@ -4227,7 +4235,7 @@ def get_server_root_path() -> str: - If SERVER_ROOT_PATH is set, return it. - Otherwise, default to "/". """ - return os.getenv("SERVER_ROOT_PATH", "/") + return os.getenv("SERVER_ROOT_PATH", "") def get_prisma_client_or_throw(message: str): diff --git a/tests/proxy_unit_tests/test_server_root_path.py b/tests/proxy_unit_tests/test_server_root_path.py new file mode 100644 index 00000000000..4b39558e15a --- /dev/null +++ b/tests/proxy_unit_tests/test_server_root_path.py @@ -0,0 +1,64 @@ +import os +from unittest import mock +from litellm.proxy import utils + + +# Test the utility function logic +def test_get_server_root_path_unset(): + """ + Test that get_server_root_path returns empty string when SERVER_ROOT_PATH is unset + """ + with mock.patch.dict(os.environ, {}, clear=True): + # We need to make sure SERVER_ROOT_PATH is not in env + if "SERVER_ROOT_PATH" in os.environ: + del os.environ["SERVER_ROOT_PATH"] + + root_path = utils.get_server_root_path() + assert ( + root_path == "" + ), "Should return empty string when unset to allow X-Forwarded-Prefix" + + +def test_get_server_root_path_set(): + """ + Test that get_server_root_path returns the value when SERVER_ROOT_PATH is set + """ + with mock.patch.dict(os.environ, {"SERVER_ROOT_PATH": "/my-path"}, clear=True): + root_path = utils.get_server_root_path() + assert root_path == "/my-path", "Should return the set value" + + +def test_get_server_root_path_empty_string(): + """ + Test that get_server_root_path returns empty string when SERVER_ROOT_PATH is explicitly empty + """ + with mock.patch.dict(os.environ, {"SERVER_ROOT_PATH": ""}, clear=True): + root_path = utils.get_server_root_path() + assert ( + root_path == "" + ), "Should return empty string when explicitly set to empty" + + +# Integration test simulation for FastAPI app initialization +def test_fastapi_app_initialization_mock(): + """ + Simulate how proxy_server.py initializes FastAPI app with the root_path. + We don't import proxy_server because it has global side effects/singletons. + Instead we verify the logic flow. + """ + from fastapi import FastAPI + + # CASE 1: Proxy Mode (Unset) + with mock.patch.dict(os.environ, {}, clear=True): + if "SERVER_ROOT_PATH" in os.environ: + del os.environ["SERVER_ROOT_PATH"] + + server_root_path = utils.get_server_root_path() + app = FastAPI(root_path=server_root_path) + assert app.root_path == "" + + # CASE 2: Direct Mode (Set) + with mock.patch.dict(os.environ, {"SERVER_ROOT_PATH": "/custom-root"}, clear=True): + server_root_path = utils.get_server_root_path() + app = FastAPI(root_path=server_root_path) + assert app.root_path == "/custom-root"