Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions litellm/proxy/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,11 @@ class LiteLLMRoutes(enum.Enum):
)


# Pre-computed tuple for fast startswith() checks against mapped pass-through routes.
# Defined once here and imported by auth/route_checks modules.
MAPPED_PASS_THROUGH_PREFIXES = tuple(LiteLLMRoutes.mapped_pass_through_routes.value)


class LiteLLMPromptInjectionParams(LiteLLMPydanticObjectBase):
heuristics_check: bool = False
vector_db_check: bool = False
Expand Down
6 changes: 3 additions & 3 deletions litellm/proxy/auth/route_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import (
MAPPED_PASS_THROUGH_PREFIXES,
CommonProxyErrors,
LiteLLM_UserTable,
LiteLLMRoutes,
Expand Down Expand Up @@ -346,9 +347,8 @@ def is_llm_api_route(route: str) -> bool:
if RouteChecks._is_azure_openai_route(route=route):
return True

for _llm_passthrough_route in LiteLLMRoutes.mapped_pass_through_routes.value:
if _llm_passthrough_route in route:
return True
if route.startswith(MAPPED_PASS_THROUGH_PREFIXES):
return True
Comment on lines +350 to +351
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subtle behavioral change: startswith vs in

The old code used if _llm_passthrough_route in route (substring match), but the new code uses route.startswith(MAPPED_PASS_THROUGH_PREFIXES) (prefix match). This is a semantic change: the old logic would match routes like /some/path/openai/endpoint where the prefix appears as a substring, while the new logic only matches routes that begin with the prefix.

In practice this is likely the correct behavior (all pass-through routes are defined with these as path prefixes), and the in check was overly broad. However, it should be intentional — if any routes exist where /openai, /anthropic, etc. appear as a non-prefix substring, they would no longer be classified as LLM API routes by _is_llm_api_route.

return False

@staticmethod
Expand Down
59 changes: 24 additions & 35 deletions litellm/proxy/auth/user_api_key_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@

user_api_key_service_logger_obj = ServiceLogging() # used for tracking latency on OTEL

# Pre-computed constants to avoid repeated enum attribute access
_PUBLIC_ROUTES = LiteLLMRoutes.public_routes.value

custom_litellm_key_header = APIKeyHeader(
name=SpecialHeaders.custom_litellm_api_key.value,
auto_error=False,
Expand Down Expand Up @@ -358,15 +361,6 @@ def get_api_key(
google_auth_key: str = _safe_get_request_query_params(request).get("key") or ""
passed_in_key = google_auth_key
api_key = google_auth_key
elif pass_through_endpoints is not None:
for endpoint in pass_through_endpoints:
if endpoint.get("path", "") == route:
headers: Optional[dict] = endpoint.get("headers", None)
if headers is not None:
header_key: str = headers.get("litellm_user_api_key", "")
if request.headers.get(header_key) is not None:
api_key = request.headers.get(header_key) or ""
passed_in_key = api_key
return api_key, passed_in_key


Expand All @@ -376,29 +370,22 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints(
pass_through_endpoints: Optional[List[dict]],
api_key: str,
) -> Union[UserAPIKeyAuth, str]:
is_mapped_pass_through_route: bool = False
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value: # type: ignore
if route.startswith(mapped_route):
is_mapped_pass_through_route = True
if is_mapped_pass_through_route:
if request.headers.get("litellm_user_api_key") is not None:
api_key = request.headers.get("litellm_user_api_key") or ""
# Fast path: nothing to check
is_mapped = route.startswith(MAPPED_PASS_THROUGH_PREFIXES)
if not is_mapped and pass_through_endpoints is None:
return api_key

if is_mapped:
value = request.headers.get("litellm_user_api_key")
if value is not None:
api_key = value

if pass_through_endpoints is not None:
for endpoint in pass_through_endpoints:
if isinstance(endpoint, dict) and endpoint.get("path", "") == route:
## IF AUTH DISABLED
if endpoint.get("auth") is not True:
return UserAPIKeyAuth()
Comment on lines 386 to 387
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

auth: False check may miss falsy values

The condition endpoint.get("auth") is not True will pass for endpoints where auth is explicitly set to False, but it will also pass for endpoints that don't have an auth key at all (returns None), or where auth is any other truthy value like "yes".

Note: this is not a regression - the original code had the same check (endpoint.get("auth") is not True). Just flagging that if someone configures "auth": "enabled" on a pass-through endpoint, it would skip auth entirely since "enabled" is not True evaluates to True.

## IF AUTH ENABLED
### IF CUSTOM PARSER REQUIRED
if (
endpoint.get("custom_auth_parser") is not None
and endpoint.get("custom_auth_parser") == "langfuse"
):
"""
- langfuse returns {'Authorization': 'Basic YW55dGhpbmc6YW55dGhpbmc'}
- check the langfuse public key if it contains the litellm api key
"""
if endpoint.get("custom_auth_parser") == "langfuse":
import base64

api_key = api_key.replace("Basic ", "").strip()
Expand All @@ -409,11 +396,10 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints(
headers = endpoint.get("headers", None)
if headers is not None:
header_key = headers.get("litellm_user_api_key", "")
if (
isinstance(request.headers, dict)
and request.headers.get(key=header_key) is not None # type: ignore
):
api_key = request.headers.get(key=header_key) # type: ignore
value = request.headers.get(header_key)
if value is not None:
api_key = value
break # found matching endpoint, stop looping
return api_key


Expand All @@ -426,6 +412,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
azure_apim_header: Optional[str],
request_data: dict,
custom_litellm_key_header: Optional[str] = None,
route: Optional[str] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New `route` parameter: consider documenting the optimization

The `route` parameter defaults to `None` and is correctly annotated as `Optional[str]`. However, there's no docstring update to explain that callers can now pass in a pre-computed route to avoid redundant get_request_route() calls. A brief inline comment would help future maintainers understand this optimization.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
general_settings,
Expand All @@ -444,7 +431,8 @@ async def _user_api_key_auth_builder( # noqa: PLR0915

parent_otel_span: Optional[Span] = None
start_time = datetime.now()
route: str = get_request_route(request=request)
if route is None:
route = get_request_route(request=request)
valid_token: Optional[UserAPIKeyAuth] = None
custom_auth_api_key: bool = False

Expand Down Expand Up @@ -483,7 +471,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
parent_otel_span = (
open_telemetry_logger.create_litellm_proxy_request_started_span(
start_time=start_time,
headers=dict(request.headers),
headers=_safe_get_request_headers(request),
)
)

Expand Down Expand Up @@ -515,7 +503,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915

######## Route Checks Before Reading DB / Cache for "token" ################
if (
route in LiteLLMRoutes.public_routes.value # type: ignore
route in _PUBLIC_ROUTES
or route_in_additonal_public_routes(current_route=route)
):
# check if public endpoint
Expand Down Expand Up @@ -1371,6 +1359,7 @@ async def user_api_key_auth(
azure_apim_header=azure_apim_header,
request_data=request_data,
custom_litellm_key_header=custom_litellm_key_header,
route=route,
)

## ENSURE DISABLE ROUTE WORKS ACROSS ALL USER AUTH FLOWS ##
Expand Down
22 changes: 17 additions & 5 deletions litellm/proxy/common_utils/http_parsing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,29 @@ def _safe_set_request_parsed_body(

def _safe_get_request_headers(request: Optional[Request]) -> dict:
"""
[Non-Blocking] Safely get the request headers
[Non-Blocking] Safely get the request headers.
Caches the result on request.state to avoid re-creating dict(request.headers) per call.

Warning: Callers must NOT mutate the returned dict — it is shared across
all callers within the same request via the cache.
"""
if request is None:
return {}
cached = getattr(request.state, "_cached_headers", None)
if cached is not None:
return cached
try:
if request is None:
return {}
return dict(request.headers)
headers = dict(request.headers)
except Exception as e:
verbose_proxy_logger.debug(
"Unexpected error reading request headers - {}".format(e)
)
return {}
headers = {}
try:
request.state._cached_headers = headers
except Exception:
pass # request.state may not be available in all contexts
return headers


def check_file_size_under_limit(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.passthrough import BasePassthroughUtils
from litellm.proxy._types import (
MAPPED_PASS_THROUGH_PREFIXES,
ConfigFieldInfo,
ConfigFieldUpdate,
LiteLLMRoutes,
PassThroughEndpointResponse,
PassThroughGenericEndpoint,
ProxyException,
Expand Down Expand Up @@ -2058,9 +2058,8 @@ def is_registered_pass_through_route(route: str) -> bool:
bool: True if route is a registered pass-through endpoint, False otherwise
"""
## CHECK IF MAPPED PASS THROUGH ENDPOINT
for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
if route.startswith(mapped_route):
return True
if route.startswith(MAPPED_PASS_THROUGH_PREFIXES):
return True

# Fast path: check if any registered route key contains this path
# Keys are in format: "{endpoint_id}:exact:{path}:{methods}" or "{endpoint_id}:subpath:{path}:{methods}"
Expand Down
160 changes: 160 additions & 0 deletions tests/test_litellm/proxy/auth/test_user_api_key_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,166 @@ async def test_return_user_api_key_auth_obj_user_spend_and_budget():
assert result.user_email == "test@example.com"


# ── Regression tests for auth optimizations ──────────────────────────────────


@pytest.mark.asyncio
async def test_check_api_key_normal_route_returns_api_key():
    """A plain LLM route with no pass-through config leaves the api_key untouched."""
    from litellm.proxy.auth.user_api_key_auth import (
        check_api_key_for_custom_headers_or_pass_through_endpoints,
    )

    mock_request = MagicMock()
    mock_request.headers = {}

    returned_key = await check_api_key_for_custom_headers_or_pass_through_endpoints(
        request=mock_request,
        route="/chat/completions",
        pass_through_endpoints=None,
        api_key="sk-test123",
    )

    assert returned_key == "sk-test123"


@pytest.mark.asyncio
async def test_check_api_key_mapped_pass_through_with_header():
    """On a mapped pass-through route, the litellm_user_api_key header overrides the key."""
    from litellm.proxy.auth.user_api_key_auth import (
        check_api_key_for_custom_headers_or_pass_through_endpoints,
    )

    mock_request = MagicMock()
    mock_request.headers = {"litellm_user_api_key": "sk-from-header"}

    returned_key = await check_api_key_for_custom_headers_or_pass_through_endpoints(
        request=mock_request,
        route="/anthropic/v1/messages",
        pass_through_endpoints=None,
        api_key="sk-original",
    )

    assert returned_key == "sk-from-header"


@pytest.mark.asyncio
async def test_check_api_key_configured_endpoint_auth_disabled():
    """A configured endpoint with auth disabled short-circuits to an empty UserAPIKeyAuth."""
    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.user_api_key_auth import (
        check_api_key_for_custom_headers_or_pass_through_endpoints,
    )

    mock_request = MagicMock()
    mock_request.headers = {}
    configured = [{"path": "/custom/endpoint", "auth": False}]

    outcome = await check_api_key_for_custom_headers_or_pass_through_endpoints(
        request=mock_request,
        route="/custom/endpoint",
        pass_through_endpoints=configured,
        api_key="sk-test",
    )

    assert isinstance(outcome, UserAPIKeyAuth)


@pytest.mark.asyncio
async def test_check_api_key_configured_endpoint_langfuse_parser():
    """The langfuse custom parser decodes Basic auth and yields the public key."""
    import base64

    from litellm.proxy.auth.user_api_key_auth import (
        check_api_key_for_custom_headers_or_pass_through_endpoints,
    )

    public_key = "sk-lf-public"
    secret_key = "sk-lf-secret"
    credentials = f"{public_key}:{secret_key}"
    basic_auth = base64.b64encode(credentials.encode()).decode()

    mock_request = MagicMock()
    mock_request.headers = {}
    configured = [
        {"path": "/langfuse/api", "auth": True, "custom_auth_parser": "langfuse"}
    ]

    outcome = await check_api_key_for_custom_headers_or_pass_through_endpoints(
        request=mock_request,
        route="/langfuse/api",
        pass_through_endpoints=configured,
        api_key=f"Basic {basic_auth}",
    )

    assert outcome == public_key


@pytest.mark.asyncio
async def test_check_api_key_configured_endpoint_custom_header():
    """A configured endpoint can redirect key extraction to a custom request header."""
    from litellm.proxy.auth.user_api_key_auth import (
        check_api_key_for_custom_headers_or_pass_through_endpoints,
    )

    mock_request = MagicMock()
    mock_request.headers = {"x-custom-key": "sk-custom-value"}
    configured = [
        {
            "path": "/custom/endpoint",
            "auth": True,
            "headers": {"litellm_user_api_key": "x-custom-key"},
        }
    ]

    outcome = await check_api_key_for_custom_headers_or_pass_through_endpoints(
        request=mock_request,
        route="/custom/endpoint",
        pass_through_endpoints=configured,
        api_key="sk-original",
    )

    assert outcome == "sk-custom-value"


def test_get_api_key_without_pass_through_branch():
    """
    Regression test for Round 2: after removing the elif pass_through_endpoints branch
    from get_api_key, the key should still be extractable via Bearer token even when
    pass_through_endpoints is configured (the later check_api_key_for_... call handles it).
    """
    mock_request = MagicMock()
    mock_request.headers = {"x-custom-key": "sk-custom-value"}

    configured = [
        {
            "path": "/custom/endpoint",
            "auth": True,
            "headers": {"litellm_user_api_key": "x-custom-key"},
        }
    ]

    extracted_key, raw_key = get_api_key(
        custom_litellm_key_header=None,
        api_key="Bearer sk-test-key",
        azure_api_key_header=None,
        anthropic_api_key_header=None,
        google_ai_studio_api_key_header=None,
        azure_apim_header=None,
        pass_through_endpoints=configured,
        route="/custom/endpoint",
        request=mock_request,
    )

    # Bearer token should be extracted normally regardless of pass_through_endpoints
    assert extracted_key == "sk-test-key"
    assert raw_key == "Bearer sk-test-key"


def test_safe_get_request_headers_caching():
    """
    _safe_get_request_headers should (a) return an accurate dict copy of the
    request headers and (b) cache that dict on request.state so a second call
    returns the very same object instead of rebuilding it.

    The original version of this test only checked identity (result1 is result2),
    which a cache that stored an empty/wrong dict would still pass; asserting the
    content as well pins the correctness of what gets cached.
    """
    from starlette.datastructures import State

    from litellm.proxy.common_utils.http_parsing_utils import (
        _safe_get_request_headers,
    )

    headers = {"content-type": "application/json", "authorization": "Bearer sk-123"}
    request = MagicMock()
    request.headers = headers
    request.state = State()  # real State object that supports attribute setting

    result1 = _safe_get_request_headers(request)
    result2 = _safe_get_request_headers(request)

    # The returned dict must faithfully mirror the underlying headers.
    assert result1 == headers
    # The second call must hit the request.state cache, not rebuild the dict.
    assert result1 is result2, "Second call should return the same cached dict object"


def test_proxy_admin_jwt_auth_includes_identity_fields():
"""
Test that the proxy admin early-return path in JWT auth populates
Expand Down
Loading
Loading