Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions litellm/passthrough/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@ def forward_headers_from_request(
e.g., 'x-pass-anthropic-beta: value' becomes 'anthropic-beta: value'
"""
if forward_headers is True:
# Headers we should NOT forward
request_headers.pop("content-length", None)
request_headers.pop("host", None)
request_headers.pop("x-litellm-api-key", None)

# Combine request headers with custom headers
headers = {**request_headers, **headers}

# Always process x-pass- prefixed headers (strip prefix and forward)
Expand All @@ -52,6 +51,7 @@ def forward_headers_from_request(

return headers


class CommonUtils:
@staticmethod
def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str:
Expand All @@ -64,37 +64,36 @@ def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str:
arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile%2Fabdefg12334
so that it is treated as one part of the path.
Otherwise, the encoded endpoint will return 500 error when passed to Bedrock endpoint.

See the apis in https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html
for more details on the regex patterns of modelId which we use in the regex logic below.

Args:
endpoint (str): The original endpoint string which may contain ARNs that contain slashes.

Returns:
str: The endpoint with properly encoded ARN slashes
"""
import re

# Early exit: if no ARN detected, return unchanged
if 'arn:aws:' not in endpoint:
if "arn:aws:" not in endpoint:
return endpoint

# Handle all patterns in one go - more efficient and cleaner
patterns = [
# Custom model with 2 slashes (order matters - do this first)
(r'(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)', r'\1%2F\2%2F\3'),

(r"(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)", r"\1%2F\2%2F\3"),
# All other resource types with 1 slash
(r'(:application-inference-profile)/', r'\1%2F'),
(r'(:inference-profile)/', r'\1%2F'),
(r'(:foundation-model)/', r'\1%2F'),
(r'(:imported-model)/', r'\1%2F'),
(r'(:provisioned-model)/', r'\1%2F'),
(r'(:prompt)/', r'\1%2F'),
(r'(:endpoint)/', r'\1%2F'),
(r'(:prompt-router)/', r'\1%2F'),
(r'(:default-prompt-router)/', r'\1%2F'),
(r"(:application-inference-profile)/", r"\1%2F"),
(r"(:inference-profile)/", r"\1%2F"),
(r"(:foundation-model)/", r"\1%2F"),
(r"(:imported-model)/", r"\1%2F"),
(r"(:provisioned-model)/", r"\1%2F"),
(r"(:prompt)/", r"\1%2F"),
(r"(:endpoint)/", r"\1%2F"),
(r"(:prompt-router)/", r"\1%2F"),
(r"(:default-prompt-router)/", r"\1%2F"),
]

for pattern, replacement in patterns:
Expand All @@ -103,4 +102,4 @@ def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str:
endpoint = re.sub(pattern, replacement, endpoint)
break # Exit after first match since each ARN has only one resource type

return endpoint
return endpoint
72 changes: 42 additions & 30 deletions litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,19 +587,26 @@ async def anthropic_proxy_route(
base_target_url = os.getenv("ANTHROPIC_API_BASE") or "https://api.anthropic.com"
encoded_endpoint = httpx.URL(endpoint).path

# Ensure endpoint starts with '/' for proper URL construction
if not encoded_endpoint.startswith("/"):
encoded_endpoint = "/" + encoded_endpoint

# Construct the full target URL using httpx
base_url = httpx.URL(base_target_url)
updated_url = base_url.copy_with(path=encoded_endpoint)

# Add or update query parameters
anthropic_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="anthropic",
region_name=None,
)
x_api_key_header = request.headers.get("x-api-key", "")
auth_header = request.headers.get("authorization", "")

if x_api_key_header or auth_header:
custom_headers = {}
else:
anthropic_api_key = passthrough_endpoint_router.get_credentials(
custom_llm_provider="anthropic",
region_name=None,
)
if anthropic_api_key:
custom_headers = {"x-api-key": anthropic_api_key}
else:
custom_headers = {}
Comment on lines +596 to +609
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PR description claims this detects OAuth tokens (pattern: Bearer sk-ant-oat...), but the code doesn't check for OAuth tokens specifically. It checks if ANY x-api-key or authorization header exists, which means all client credentials bypass server credentials. If you specifically want OAuth detection as described, consider checking for the OAuth token pattern.

Prompt To Fix With AI
This is a comment left during a code review.
Path: litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
Line: 596:609

Comment:
PR description claims this detects OAuth tokens (pattern: `Bearer sk-ant-oat...`), but the code doesn't check for OAuth tokens specifically. It checks if ANY `x-api-key` or `authorization` header exists, which means all client credentials bypass server credentials. If you specifically want OAuth detection as described, consider checking for the OAuth token pattern.

How can I resolve this? If you propose a fix, please make it concise.

Comment on lines +599 to +609
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nested if-else creates empty dict in multiple paths. Consider simplifying with ternary operator for the inner condition.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Prompt To Fix With AI
This is a comment left during a code review.
Path: litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py
Line: 599:609

Comment:
Nested if-else creates empty dict in multiple paths. Consider simplifying with ternary operator for the inner condition.

<sub>Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!</sub>

How can I resolve this? If you propose a fix, please make it concise.


## check for streaming
is_streaming_request = await is_streaming_request_fn(request)
Expand All @@ -608,7 +615,7 @@ async def anthropic_proxy_route(
endpoint_func = create_pass_through_route(
endpoint=endpoint,
target=str(updated_url),
custom_headers={"x-api-key": "{}".format(anthropic_api_key)},
custom_headers=custom_headers,
_forward_headers=True,
is_streaming_request=is_streaming_request,
) # dynamically construct pass-through endpoint based on incoming path
Expand Down Expand Up @@ -829,10 +836,10 @@ async def handle_bedrock_count_tokens(

except BedrockError as e:
# Convert BedrockError to HTTPException for FastAPI
verbose_proxy_logger.error(f"BedrockError in handle_bedrock_count_tokens: {str(e)}")
raise HTTPException(
status_code=e.status_code, detail={"error": e.message}
verbose_proxy_logger.error(
f"BedrockError in handle_bedrock_count_tokens: {str(e)}"
)
raise HTTPException(status_code=e.status_code, detail={"error": e.message})
except HTTPException:
# Re-raise HTTP exceptions as-is
raise
Expand Down Expand Up @@ -1041,7 +1048,7 @@ async def bedrock_proxy_route(
target=str(prepped.url),
custom_headers=prepped.headers, # type: ignore
is_streaming_request=is_streaming_request,
_forward_headers=True
_forward_headers=True,
) # dynamically construct pass-through endpoint based on incoming path
received_value = await endpoint_func(
request,
Expand All @@ -1063,42 +1070,43 @@ def _resolve_vertex_model_from_router(
) -> Tuple[str, str, Optional[str], Optional[str]]:
"""
Resolve Vertex AI model configuration from router.

Args:
model_id: The model ID extracted from the URL (e.g., "gcp/google/gemini-2.5-flash")
llm_router: The LiteLLM router instance
encoded_endpoint: The encoded endpoint path
endpoint: The original endpoint path
vertex_project: Current vertex project (may be from URL)
vertex_location: Current vertex location (may be from URL)

Returns:
Tuple of (encoded_endpoint, endpoint, vertex_project, vertex_location)
with resolved values from router config
"""
if not llm_router:
return encoded_endpoint, endpoint, vertex_project, vertex_location

try:
deployment = llm_router.get_available_deployment_for_pass_through(model=model_id)
deployment = llm_router.get_available_deployment_for_pass_through(
model=model_id
)
if not deployment:
return encoded_endpoint, endpoint, vertex_project, vertex_location

litellm_params = deployment.get("litellm_params", {})

# Always override with router config values (they take precedence over URL values)
config_vertex_project = litellm_params.get("vertex_project")
config_vertex_location = litellm_params.get("vertex_location")
if config_vertex_project:
vertex_project = config_vertex_project
if config_vertex_location:
vertex_location = config_vertex_location

# Get the actual Vertex AI model name by stripping the provider prefix
# e.g., "vertex_ai/gemini-2.0-flash-exp" -> "gemini-2.0-flash-exp"
model_from_config = litellm_params.get("model", "")
if model_from_config:

# get_llm_provider returns (model, custom_llm_provider, dynamic_api_key, api_base)
# For "vertex_ai/gemini-2.0-flash-exp" it returns:
# model="gemini-2.0-flash-exp", custom_llm_provider="vertex_ai"
Expand Down Expand Up @@ -1127,12 +1135,12 @@ def _resolve_vertex_model_from_router(
)
encoded_endpoint = encoded_endpoint.replace(model_id, actual_model)
endpoint = endpoint.replace(model_id, actual_model)

except Exception as e:
verbose_proxy_logger.debug(
f"Error resolving vertex model from router for model {model_id}: {e}"
)

return encoded_endpoint, endpoint, vertex_project, vertex_location


Expand Down Expand Up @@ -1597,7 +1605,7 @@ async def _prepare_vertex_auth_headers(
vertex_credentials_str = None
elif vertex_credentials is not None:
# Use credentials from vertex_credentials
# When vertex_credentials are provided (including default credentials),
# When vertex_credentials are provided (including default credentials),
# use their project/location values if available
if vertex_credentials.vertex_project is not None:
vertex_project = vertex_credentials.vertex_project
Expand Down Expand Up @@ -1703,10 +1711,14 @@ async def _base_vertex_proxy_route(
# Check if model is in router config - always do this to resolve custom model names
model_id = get_vertex_model_id_from_url(endpoint)
if model_id:

if llm_router:
# Resolve model configuration from router
encoded_endpoint, endpoint, vertex_project, vertex_location = _resolve_vertex_model_from_router(
(
encoded_endpoint,
endpoint,
vertex_project,
vertex_location,
) = _resolve_vertex_model_from_router(
model_id=model_id,
llm_router=llm_router,
encoded_endpoint=encoded_endpoint,
Expand Down Expand Up @@ -1899,25 +1911,25 @@ async def openai_proxy_route(
):
"""
Pass-through endpoint for OpenAI API calls.

Available on both routes:
- /openai/{endpoint:path} - Standard OpenAI passthrough route
- /openai_passthrough/{endpoint:path} - Dedicated passthrough route (recommended for Responses API)

Use /openai_passthrough/* when you need guaranteed passthrough to OpenAI without conflicts
with LiteLLM's native implementations (e.g., for the Responses API at /v1/responses).

Examples:
Standard route:
- /openai/v1/chat/completions
- /openai/v1/assistants
- /openai/v1/threads

Dedicated passthrough (for Responses API):
- /openai_passthrough/v1/responses
- /openai_passthrough/v1/responses/{response_id}
- /openai_passthrough/v1/responses/{response_id}/input_items

[Docs](https://docs.litellm.ai/docs/pass_through/openai_passthrough)
"""
base_target_url = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fastapi import Request, Response
from fastapi.testclient import TestClient

from litellm.passthrough.utils import CommonUtils
from litellm.passthrough.utils import CommonUtils, HttpPassThroughEndpointHelpers

sys.path.insert(
0, os.path.abspath("../../../..")
Expand Down Expand Up @@ -95,4 +95,42 @@ def test_encode_bedrock_runtime_modelid_arn_edge_cases():
endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile.v1/invoke"
expected = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile%2Ftest-profile.v1/invoke"
result = CommonUtils.encode_bedrock_runtime_modelid_arn(endpoint)
assert result == expected
assert result == expected


def test_forward_headers_strips_litellm_api_key():
    """The proxy's own auth header (x-litellm-api-key) must be dropped,
    while unrelated headers pass through to the upstream provider."""
    incoming = {
        "x-litellm-api-key": "sk-litellm-secret-key",
        "content-type": "application/json",
        "x-api-key": "sk-ant-api-key",
    }

    forwarded = HttpPassThroughEndpointHelpers.forward_headers_from_request(
        request_headers=dict(incoming),
        headers={},
        forward_headers=True,
    )

    # The litellm-internal credential must never reach the provider.
    assert "x-litellm-api-key" not in forwarded
    # Ordinary request headers survive forwarding untouched.
    assert forwarded.get("content-type") == "application/json"
    assert forwarded.get("x-api-key") == "sk-ant-api-key"


def test_forward_headers_strips_host_and_content_length():
    """Hop-by-hop style headers (host, content-length) must not be
    forwarded; other headers are preserved."""
    incoming = {
        "host": "api.anthropic.com",
        "content-length": "1234",
        "content-type": "application/json",
    }

    forwarded = HttpPassThroughEndpointHelpers.forward_headers_from_request(
        request_headers=dict(incoming),
        headers={},
        forward_headers=True,
    )

    # These two would be wrong for the upstream request and are stripped.
    assert "host" not in forwarded
    assert "content-length" not in forwarded
    # Content negotiation headers pass through unchanged.
    assert forwarded.get("content-type") == "application/json"
Loading