diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 95739834a9a9..f453ef28c5dc 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -676,6 +676,11 @@ class LiteLLMRoutes(enum.Enum): ) +# Pre-computed tuple for fast startswith() checks against mapped pass-through routes. +# Defined once here and imported by auth/route_checks modules. +MAPPED_PASS_THROUGH_PREFIXES = tuple(LiteLLMRoutes.mapped_pass_through_routes.value) + + class LiteLLMPromptInjectionParams(LiteLLMPydanticObjectBase): heuristics_check: bool = False vector_db_check: bool = False diff --git a/litellm/proxy/auth/route_checks.py b/litellm/proxy/auth/route_checks.py index a00401008fcb..ddf19f67f1a6 100644 --- a/litellm/proxy/auth/route_checks.py +++ b/litellm/proxy/auth/route_checks.py @@ -5,6 +5,7 @@ from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ( + MAPPED_PASS_THROUGH_PREFIXES, CommonProxyErrors, LiteLLM_UserTable, LiteLLMRoutes, @@ -346,9 +347,8 @@ def is_llm_api_route(route: str) -> bool: if RouteChecks._is_azure_openai_route(route=route): return True - for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value: - if _llm_passthrough_route in route: - return True + if route.startswith(MAPPED_PASS_THROUGH_PREFIXES): + return True return False @staticmethod diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 138f9bab2c42..11d6ebe774e0 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -78,6 +78,9 @@ user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL +# Pre-computed constants to avoid repeated enum attribute access +_PUBLIC_ROUTES = LiteLLMRoutes.public_routes.value + custom_litellm_key_header = APIKeyHeader( name=SpecialHeaders.custom_litellm_api_key.value, auto_error=False, @@ -358,15 +361,6 @@ def get_api_key( google_auth_key: str = _safe_get_request_query_params(request).get("key") or "" passed_in_key = google_auth_key api_key = google_auth_key - elif pass_through_endpoints is not None: - for endpoint in pass_through_endpoints: - if endpoint.get("path", "") == route: - headers: Optional[dict] = endpoint.get("headers", None) - if headers is not None: - header_key: str = headers.get("litellm_user_api_key", "") - if request.headers.get(header_key) is not None: - api_key = request.headers.get(header_key) or "" - passed_in_key = api_key return api_key, passed_in_key @@ -376,29 +370,22 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints( pass_through_endpoints: Optional[List[dict]], api_key: str, ) -> Union[UserAPIKeyAuth, str]: - is_mapped_pass_through_route: bool = False - for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore - if route.startswith(mapped_route): - is_mapped_pass_through_route = True - if is_mapped_pass_through_route: - if request.headers.get("litellm_user_api_key") is not None: - api_key = request.headers.get("litellm_user_api_key") or "" + # Fast path: nothing to check + is_mapped = route.startswith(MAPPED_PASS_THROUGH_PREFIXES) + if not is_mapped and pass_through_endpoints is None: + return api_key + + if is_mapped: + value = request.headers.get("litellm_user_api_key") + if value is not None: + api_key = value + if pass_through_endpoints is not None: for endpoint in pass_through_endpoints: if isinstance(endpoint, dict) and endpoint.get("path", "") == route: - ## IF AUTH DISABLED if endpoint.get("auth") is not True: return UserAPIKeyAuth() - ## IF AUTH ENABLED - ### IF CUSTOM PARSER REQUIRED - if ( - endpoint.get("custom_auth_parser") is not None - and endpoint.get("custom_auth_parser") == "langfuse" - ): - """ - - langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'} - - check the langfuse public key if it contains the litellm api key - """ + if endpoint.get("custom_auth_parser") == "langfuse": import base64 api_key = api_key.replace("Basic ", "").strip() @@ -409,11 +396,10 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints( headers = endpoint.get("headers", None) if headers is not None: header_key = headers.get("litellm_user_api_key", "") - if ( - isinstance(request.headers, dict) - and request.headers.get(key=header_key) is not None # type: ignore - ): - api_key = request.headers.get(key=header_key) # type: ignore + value = request.headers.get(header_key) + if value is not None: + api_key = value + break # found matching endpoint, stop looping return api_key @@ -426,6 +412,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 azure_apim_header: Optional[str], request_data: dict, custom_litellm_key_header: Optional[str] = None, + route: Optional[str] = None, ) -> UserAPIKeyAuth: from litellm.proxy.proxy_server import ( general_settings, @@ -444,7 +431,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 parent_otel_span: Optional[Span] = None start_time = datetime.now() - route: str = get_request_route(request=request) + if route is None: + route = get_request_route(request=request) valid_token: Optional[UserAPIKeyAuth] = None custom_auth_api_key: bool = False @@ -483,7 +471,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 parent_otel_span = ( open_telemetry_logger.create_litellm_proxy_request_started_span( start_time=start_time, - headers=dict(request.headers), + headers=_safe_get_request_headers(request), ) ) @@ -515,7 +503,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 ######## Route Checks Before Reading DB / Cache for "token" ################ if ( - route in LiteLLMRoutes.public_routes.value # type: ignore + route in _PUBLIC_ROUTES or route_in_additonal_public_routes(current_route=route) ): # check if public endpoint @@ -1371,6 +1359,7 @@ async def user_api_key_auth( azure_apim_header=azure_apim_header, request_data=request_data, custom_litellm_key_header=custom_litellm_key_header, + route=route, ) ## ENSURE DISABLE ROUTE WORKS ACROSS ALL USER AUTH FLOWS ## diff --git a/litellm/proxy/common_utils/http_parsing_utils.py b/litellm/proxy/common_utils/http_parsing_utils.py index e1bca6e905f6..8d179a9caed0 100644 --- a/litellm/proxy/common_utils/http_parsing_utils.py +++ b/litellm/proxy/common_utils/http_parsing_utils.py @@ -135,17 +135,29 @@ def _safe_set_request_parsed_body( def _safe_get_request_headers(request: Optional[Request]) -> dict: """ - [Non-Blocking] Safely get the request headers + [Non-Blocking] Safely get the request headers. + Caches the result on request.state to avoid re-creating dict(request.headers) per call. + + Warning: Callers must NOT mutate the returned dict — it is shared across + all callers within the same request via the cache. """ + if request is None: + return {} + cached = getattr(request.state, "_cached_headers", None) + if cached is not None: + return cached try: - if request is None: - return {} - return dict(request.headers) + headers = dict(request.headers) except Exception as e: verbose_proxy_logger.debug( "Unexpected error reading request headers - {}".format(e) ) - return {} + headers = {} + try: + request.state._cached_headers = headers + except Exception: + pass # request.state may not be available in all contexts + return headers def check_file_size_under_limit( diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index d1c4b8b34036..71e2cb9c42c2 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -40,9 +40,9 @@ from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.passthrough import BasePassthroughUtils from litellm.proxy._types import ( + MAPPED_PASS_THROUGH_PREFIXES, ConfigFieldInfo, ConfigFieldUpdate, - LiteLLMRoutes, PassThroughEndpointResponse, PassThroughGenericEndpoint, ProxyException, @@ -2058,9 +2058,8 @@ def is_registered_pass_through_route(route: str) -> bool: bool: True if route is a registered pass-through endpoint, False otherwise """ ## CHECK IF MAPPED PASS THROUGH ENDPOINT - for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: - if route.startswith(mapped_route): - return True + if route.startswith(MAPPED_PASS_THROUGH_PREFIXES): + return True # Fast path: check if any registered route key contains this path # Keys are in format: "{endpoint_id}:exact:{path}:{methods}" or "{endpoint_id}:subpath:{path}:{methods}" diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index 79c2ed4158b7..f90cb5922cd6 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -443,6 +443,166 @@ async def test_return_user_api_key_auth_obj_user_spend_and_budget(): assert result.user_email == "test@example.com" +# ── Regression tests for auth optimizations ────────────────────────────────── + + +@pytest.mark.asyncio +async def test_check_api_key_normal_route_returns_api_key(): + """Normal route, no pass-through config -> returns api_key unchanged.""" + from litellm.proxy.auth.user_api_key_auth import ( + check_api_key_for_custom_headers_or_pass_through_endpoints, + ) + + request = MagicMock() + request.headers = {} + result = await check_api_key_for_custom_headers_or_pass_through_endpoints( + request=request, + route="/chat/completions", + pass_through_endpoints=None, + api_key="sk-test123", + ) + assert result == "sk-test123" + + +@pytest.mark.asyncio +async def test_check_api_key_mapped_pass_through_with_header(): + """Route /anthropic/v1/messages, header litellm_user_api_key set -> extracts key.""" + from litellm.proxy.auth.user_api_key_auth import ( + check_api_key_for_custom_headers_or_pass_through_endpoints, + ) + + request = MagicMock() + request.headers = {"litellm_user_api_key": "sk-from-header"} + result = await check_api_key_for_custom_headers_or_pass_through_endpoints( + request=request, + route="/anthropic/v1/messages", + pass_through_endpoints=None, + api_key="sk-original", + ) + assert result == "sk-from-header" + + +@pytest.mark.asyncio +async def test_check_api_key_configured_endpoint_auth_disabled(): + """Pass-through endpoint with auth: false -> returns UserAPIKeyAuth().""" + from litellm.proxy._types import UserAPIKeyAuth + from litellm.proxy.auth.user_api_key_auth import ( + check_api_key_for_custom_headers_or_pass_through_endpoints, + ) + + request = MagicMock() + request.headers = {} + endpoints = [{"path": "/custom/endpoint", "auth": False}] + result = await check_api_key_for_custom_headers_or_pass_through_endpoints( + request=request, + route="/custom/endpoint", + pass_through_endpoints=endpoints, + api_key="sk-test", + ) + assert isinstance(result, UserAPIKeyAuth) + + +@pytest.mark.asyncio +async def test_check_api_key_configured_endpoint_langfuse_parser(): + """Langfuse endpoint with Base64 auth -> parses correctly.""" + import base64 + + from litellm.proxy.auth.user_api_key_auth import ( + check_api_key_for_custom_headers_or_pass_through_endpoints, + ) + + public_key = "sk-lf-public" + secret_key = "sk-lf-secret" + basic_auth = base64.b64encode(f"{public_key}:{secret_key}".encode()).decode() + + request = MagicMock() + request.headers = {} + endpoints = [ + {"path": "/langfuse/api", "auth": True, "custom_auth_parser": "langfuse"} + ] + result = await check_api_key_for_custom_headers_or_pass_through_endpoints( + request=request, + route="/langfuse/api", + pass_through_endpoints=endpoints, + api_key=f"Basic {basic_auth}", + ) + assert result == public_key + + +@pytest.mark.asyncio +async def test_check_api_key_configured_endpoint_custom_header(): + """Pass-through endpoint with custom header config -> extracts from configured header.""" + from litellm.proxy.auth.user_api_key_auth import ( + check_api_key_for_custom_headers_or_pass_through_endpoints, + ) + + request = MagicMock() + request.headers = {"x-custom-key": "sk-custom-value"} + endpoints = [ + { + "path": "/custom/endpoint", + "auth": True, + "headers": {"litellm_user_api_key": "x-custom-key"}, + } + ] + result = await check_api_key_for_custom_headers_or_pass_through_endpoints( + request=request, + route="/custom/endpoint", + pass_through_endpoints=endpoints, + api_key="sk-original", + ) + assert result == "sk-custom-value" + + +def test_get_api_key_without_pass_through_branch(): + """ + Regression test for Round 2: after removing the elif pass_through_endpoints branch + from get_api_key, the key should still be extractable via Bearer token even when + pass_through_endpoints is configured (the later check_api_key_for_... call handles it). + """ + endpoints = [ + { + "path": "/custom/endpoint", + "auth": True, + "headers": {"litellm_user_api_key": "x-custom-key"}, + } + ] + request = MagicMock() + request.headers = {"x-custom-key": "sk-custom-value"} + + api_key, passed_in_key = get_api_key( + custom_litellm_key_header=None, + api_key="Bearer sk-test-key", + azure_api_key_header=None, + anthropic_api_key_header=None, + google_ai_studio_api_key_header=None, + azure_apim_header=None, + pass_through_endpoints=endpoints, + route="/custom/endpoint", + request=request, + ) + # Bearer token should be extracted normally regardless of pass_through_endpoints + assert api_key == "sk-test-key" + assert passed_in_key == "Bearer sk-test-key" + + +def test_safe_get_request_headers_caching(): + """Call _safe_get_request_headers twice on same request, assert returns same dict object.""" + from starlette.datastructures import State + + from litellm.proxy.common_utils.http_parsing_utils import ( + _safe_get_request_headers, + ) + + request = MagicMock() + request.headers = {"content-type": "application/json", "authorization": "Bearer sk-123"} + request.state = State() # real State object that supports attribute setting + + result1 = _safe_get_request_headers(request) + result2 = _safe_get_request_headers(request) + assert result1 is result2, "Second call should return the same cached dict object" + + def test_proxy_admin_jwt_auth_includes_identity_fields(): """ Test that the proxy admin early-return path in JWT auth populates diff --git a/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py b/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py index af366b082a07..16f934d8d464 100644 --- a/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py +++ b/tests/test_litellm/proxy/common_utils/test_http_parsing_utils.py @@ -7,6 +7,7 @@ import pytest from fastapi import Request from fastapi.testclient import TestClient +from starlette.datastructures import State sys.path.insert( 0, os.path.abspath("../../../..") @@ -36,6 +37,7 @@ async def test_request_body_caching(): """ # Create a mock request with a JSON body mock_request = MagicMock() + mock_request.state = State() test_data = {"key": "value"} # Use AsyncMock for the body method mock_request.body = AsyncMock(return_value=orjson.dumps(test_data)) @@ -69,6 +71,7 @@ async def test_form_data_parsing(): """ # Create a mock request with form data mock_request = MagicMock() + mock_request.state = State() test_data = {"name": "test_user", "message": "hello world"} # Mock the form method to return the test data as an awaitable @@ -104,7 +107,8 @@ async def test_form_data_with_json_metadata(): """ # Create a mock request with form data containing JSON metadata mock_request = MagicMock() - + mock_request.state = State() + # Metadata is sent as a JSON string in form data metadata_json_string = json.dumps({ "user_id": "12345", @@ -152,7 +156,8 @@ async def test_form_data_with_invalid_json_metadata(): """ # Create a mock request with form data containing invalid JSON metadata mock_request = MagicMock() - + mock_request.state = State() + test_data = { "model": "whisper-1", "file": "audio.mp3", @@ -178,7 +183,8 @@ async def test_form_data_without_metadata(): """ # Create a mock request with form data without metadata mock_request = MagicMock() - + mock_request.state = State() + test_data = { "model": "whisper-1", "file": "audio.mp3", @@ -208,7 +214,8 @@ async def test_form_data_with_empty_metadata(): """ # Create a mock request with form data containing empty metadata mock_request = MagicMock() - + mock_request.state = State() + test_data = { "model": "whisper-1", "file": "audio.mp3", @@ -240,7 +247,8 @@ async def test_form_data_with_dict_metadata(): """ # Create a mock request with form data where metadata is already a dict mock_request = MagicMock() - + mock_request.state = State() + metadata_dict = { "user_id": "12345", "tags": ["test"] @@ -275,7 +283,8 @@ async def test_form_data_with_none_metadata(): """ # Create a mock request with form data where metadata is None mock_request = MagicMock() - + mock_request.state = State() + test_data = { "model": "whisper-1", "file": "audio.mp3", @@ -303,6 +312,7 @@ async def test_empty_request_body(): """ # Create a mock request with an empty body mock_request = MagicMock() + mock_request.state = State() mock_request.body = AsyncMock(return_value=b"") # Empty bytes as an awaitable mock_request.headers = {"content-type": "application/json"} mock_request.scope = {} @@ -327,6 +337,7 @@ async def test_circular_reference_handling(): """ # Create a mock request with initial data mock_request = MagicMock() + mock_request.state = State() initial_body = { "model": "gpt-4", "messages": [{"role": "user", "content": "Hello"}], @@ -366,6 +377,7 @@ async def test_json_parsing_error_handling(): """ # Test case 1: Trailing comma error mock_request = MagicMock() + mock_request.state = State() invalid_json_with_trailing_comma = b'''{ "model": "gpt-4o", "tools": [ @@ -394,6 +406,7 @@ async def test_json_parsing_error_handling(): # Test case 2: Unquoted property name error mock_request2 = MagicMock() + mock_request2.state = State() invalid_json_unquoted_property = b'''{ "model": "gpt-4o", "tools": [ @@ -418,6 +431,7 @@ async def test_json_parsing_error_handling(): # Test case 3: Valid JSON should work normally mock_request3 = MagicMock() + mock_request3.state = State() valid_json = b'''{ "model": "gpt-4o", "tools": [ @@ -749,6 +763,7 @@ async def test_request_body_with_html_script_tags(): } mock_request = MagicMock() + mock_request.state = State() mock_request.body = AsyncMock(return_value=orjson.dumps(test_payload)) mock_request.headers = {"content-type": "application/json"} mock_request.scope = {}