diff --git a/litellm/passthrough/utils.py b/litellm/passthrough/utils.py index fbbf9cd2581..805a20e073c 100644 --- a/litellm/passthrough/utils.py +++ b/litellm/passthrough/utils.py @@ -36,11 +36,10 @@ def forward_headers_from_request( e.g., 'x-pass-anthropic-beta: value' becomes 'anthropic-beta: value' """ if forward_headers is True: - # Header We Should NOT forward request_headers.pop("content-length", None) request_headers.pop("host", None) + request_headers.pop("x-litellm-api-key", None) - # Combine request headers with custom headers headers = {**request_headers, **headers} # Always process x-pass- prefixed headers (strip prefix and forward) @@ -52,6 +51,7 @@ def forward_headers_from_request( return headers + class CommonUtils: @staticmethod def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str: @@ -64,37 +64,36 @@ def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str: arn:aws:bedrock:ap-southeast-1:123456789012:application-inference-profile%2Fabdefg12334 so that it is treated as one part of the path. Otherwise, the encoded endpoint will return 500 error when passed to Bedrock endpoint. - + See the apis in https://docs.aws.amazon.com/bedrock/latest/APIReference/API_Operations_Amazon_Bedrock_Runtime.html for more details on the regex patterns of modelId which we use in the regex logic below. - + Args: endpoint (str): The original endpoint string which may contain ARNs that contain slashes. - + Returns: str: The endpoint with properly encoded ARN slashes """ import re # Early exit: if no ARN detected, return unchanged - if 'arn:aws:' not in endpoint: + if "arn:aws:" not in endpoint: return endpoint # Handle all patterns in one go - more efficient and cleaner patterns = [ # Custom model with 2 slashes (order matters - do this first) - (r'(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)', r'\1%2F\2%2F\3'), - + (r"(custom-model)/([a-z0-9.-]+)/([a-z0-9]+)", r"\1%2F\2%2F\3"), # All other resource types with 1 slash - (r'(:application-inference-profile)/', r'\1%2F'), - (r'(:inference-profile)/', r'\1%2F'), - (r'(:foundation-model)/', r'\1%2F'), - (r'(:imported-model)/', r'\1%2F'), - (r'(:provisioned-model)/', r'\1%2F'), - (r'(:prompt)/', r'\1%2F'), - (r'(:endpoint)/', r'\1%2F'), - (r'(:prompt-router)/', r'\1%2F'), - (r'(:default-prompt-router)/', r'\1%2F'), + (r"(:application-inference-profile)/", r"\1%2F"), + (r"(:inference-profile)/", r"\1%2F"), + (r"(:foundation-model)/", r"\1%2F"), + (r"(:imported-model)/", r"\1%2F"), + (r"(:provisioned-model)/", r"\1%2F"), + (r"(:prompt)/", r"\1%2F"), + (r"(:endpoint)/", r"\1%2F"), + (r"(:prompt-router)/", r"\1%2F"), + (r"(:default-prompt-router)/", r"\1%2F"), ] for pattern, replacement in patterns: @@ -103,4 +102,4 @@ def encode_bedrock_runtime_modelid_arn(endpoint: str) -> str: endpoint = re.sub(pattern, replacement, endpoint) break # Exit after first match since each ARN has only one resource type - return endpoint \ No newline at end of file + return endpoint diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index 3dab6ea14f8..1d9c1370700 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -587,19 +587,26 @@ async def anthropic_proxy_route( base_target_url = os.getenv("ANTHROPIC_API_BASE") or "https://api.anthropic.com" encoded_endpoint = httpx.URL(endpoint).path - # Ensure endpoint starts with '/' for proper URL construction if not encoded_endpoint.startswith("/"): encoded_endpoint = "/" + encoded_endpoint - # Construct the full target URL using httpx base_url = httpx.URL(base_target_url) updated_url = base_url.copy_with(path=encoded_endpoint) - # Add or update query parameters - anthropic_api_key = passthrough_endpoint_router.get_credentials( - custom_llm_provider="anthropic", - region_name=None, - ) + x_api_key_header = request.headers.get("x-api-key", "") + auth_header = request.headers.get("authorization", "") + + if x_api_key_header or auth_header: + custom_headers = {} + else: + anthropic_api_key = passthrough_endpoint_router.get_credentials( + custom_llm_provider="anthropic", + region_name=None, + ) + if anthropic_api_key: + custom_headers = {"x-api-key": anthropic_api_key} + else: + custom_headers = {} ## check for streaming is_streaming_request = await is_streaming_request_fn(request) @@ -608,7 +615,7 @@ async def anthropic_proxy_route( endpoint_func = create_pass_through_route( endpoint=endpoint, target=str(updated_url), - custom_headers={"x-api-key": "{}".format(anthropic_api_key)}, + custom_headers=custom_headers, _forward_headers=True, is_streaming_request=is_streaming_request, ) # dynamically construct pass-through endpoint based on incoming path @@ -829,10 +836,10 @@ async def handle_bedrock_count_tokens( except BedrockError as e: # Convert BedrockError to HTTPException for FastAPI - verbose_proxy_logger.error(f"BedrockError in handle_bedrock_count_tokens: {str(e)}") - raise HTTPException( - status_code=e.status_code, detail={"error": e.message} + verbose_proxy_logger.error( + f"BedrockError in handle_bedrock_count_tokens: {str(e)}" ) + raise HTTPException(status_code=e.status_code, detail={"error": e.message}) except HTTPException: # Re-raise HTTP exceptions as-is raise @@ -1041,7 +1048,7 @@ async def bedrock_proxy_route( target=str(prepped.url), custom_headers=prepped.headers, # type: ignore is_streaming_request=is_streaming_request, - _forward_headers=True + _forward_headers=True, ) # dynamically construct pass-through endpoint based on incoming path received_value = await endpoint_func( request, @@ -1063,7 +1070,7 @@ def _resolve_vertex_model_from_router( ) -> Tuple[str, str, Optional[str], Optional[str]]: """ Resolve Vertex AI model configuration from router. - + Args: model_id: The model ID extracted from the URL (e.g., "gcp/google/gemini-2.5-flash") llm_router: The LiteLLM router instance @@ -1071,21 +1078,23 @@ def _resolve_vertex_model_from_router( endpoint: The original endpoint path vertex_project: Current vertex project (may be from URL) vertex_location: Current vertex location (may be from URL) - + Returns: Tuple of (encoded_endpoint, endpoint, vertex_project, vertex_location) with resolved values from router config """ if not llm_router: return encoded_endpoint, endpoint, vertex_project, vertex_location - + try: - deployment = llm_router.get_available_deployment_for_pass_through(model=model_id) + deployment = llm_router.get_available_deployment_for_pass_through( + model=model_id + ) if not deployment: return encoded_endpoint, endpoint, vertex_project, vertex_location - + litellm_params = deployment.get("litellm_params", {}) - + # Always override with router config values (they take precedence over URL values) config_vertex_project = litellm_params.get("vertex_project") config_vertex_location = litellm_params.get("vertex_location") @@ -1093,12 +1102,11 @@ def _resolve_vertex_model_from_router( vertex_project = config_vertex_project if config_vertex_location: vertex_location = config_vertex_location - + # Get the actual Vertex AI model name by stripping the provider prefix # e.g., "vertex_ai/gemini-2.0-flash-exp" -> "gemini-2.0-flash-exp" model_from_config = litellm_params.get("model", "") if model_from_config: - # get_llm_provider returns (model, custom_llm_provider, dynamic_api_key, api_base) # For "vertex_ai/gemini-2.0-flash-exp" it returns: # model="gemini-2.0-flash-exp", custom_llm_provider="vertex_ai" @@ -1127,12 +1135,12 @@ def _resolve_vertex_model_from_router( ) encoded_endpoint = encoded_endpoint.replace(model_id, actual_model) endpoint = endpoint.replace(model_id, actual_model) - + except Exception as e: verbose_proxy_logger.debug( f"Error resolving vertex model from router for model {model_id}: {e}" ) - + return encoded_endpoint, endpoint, vertex_project, vertex_location @@ -1597,7 +1605,7 @@ async def _prepare_vertex_auth_headers( vertex_credentials_str = None elif vertex_credentials is not None: # Use credentials from vertex_credentials - # When vertex_credentials are provided (including default credentials), + # When vertex_credentials are provided (including default credentials), # use their project/location values if available if vertex_credentials.vertex_project is not None: vertex_project = vertex_credentials.vertex_project @@ -1703,10 +1711,14 @@ async def _base_vertex_proxy_route( # Check if model is in router config - always do this to resolve custom model names model_id = get_vertex_model_id_from_url(endpoint) if model_id: - if llm_router: # Resolve model configuration from router - encoded_endpoint, endpoint, vertex_project, vertex_location = _resolve_vertex_model_from_router( + ( + encoded_endpoint, + endpoint, + vertex_project, + vertex_location, + ) = _resolve_vertex_model_from_router( model_id=model_id, llm_router=llm_router, encoded_endpoint=encoded_endpoint, @@ -1899,25 +1911,25 @@ async def openai_proxy_route( ): """ Pass-through endpoint for OpenAI API calls. - + Available on both routes: - /openai/{endpoint:path} - Standard OpenAI passthrough route - /openai_passthrough/{endpoint:path} - Dedicated passthrough route (recommended for Responses API) - + Use /openai_passthrough/* when you need guaranteed passthrough to OpenAI without conflicts with LiteLLM's native implementations (e.g., for the Responses API at /v1/responses). - + Examples: Standard route: - /openai/v1/chat/completions - /openai/v1/assistants - /openai/v1/threads - + Dedicated passthrough (for Responses API): - /openai_passthrough/v1/responses - /openai_passthrough/v1/responses/{response_id} - /openai_passthrough/v1/responses/{response_id}/input_items - + [Docs](https://docs.litellm.ai/docs/pass_through/openai_passthrough) """ base_target_url = os.getenv("OPENAI_API_BASE") or "https://api.openai.com/" diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py b/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py index 97ef05100de..ff3e1c43c63 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_endpoints_common_utils.py @@ -10,7 +10,7 @@ from fastapi import Request, Response from fastapi.testclient import TestClient -from litellm.passthrough.utils import CommonUtils +from litellm.passthrough.utils import CommonUtils, HttpPassThroughEndpointHelpers sys.path.insert( 0, os.path.abspath("../../../..") @@ -95,4 +95,42 @@ def test_encode_bedrock_runtime_modelid_arn_edge_cases(): endpoint = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/test-profile.v1/invoke" expected = "model/arn:aws:bedrock:us-east-1:123456789012:application-inference-profile%2Ftest-profile.v1/invoke" result = CommonUtils.encode_bedrock_runtime_modelid_arn(endpoint) - assert result == expected \ No newline at end of file + assert result == expected + + +def test_forward_headers_strips_litellm_api_key(): + """x-litellm-api-key should not be forwarded to upstream providers.""" + request_headers = { + "x-litellm-api-key": "sk-litellm-secret-key", + "content-type": "application/json", + "x-api-key": "sk-ant-api-key", + } + + result = HttpPassThroughEndpointHelpers.forward_headers_from_request( + request_headers=request_headers.copy(), + headers={}, + forward_headers=True, + ) + + assert "x-litellm-api-key" not in result + assert result.get("content-type") == "application/json" + assert result.get("x-api-key") == "sk-ant-api-key" + + +def test_forward_headers_strips_host_and_content_length(): + """host and content-length should not be forwarded.""" + request_headers = { + "host": "api.anthropic.com", + "content-length": "1234", + "content-type": "application/json", + } + + result = HttpPassThroughEndpointHelpers.forward_headers_from_request( + request_headers=request_headers.copy(), + headers={}, + forward_headers=True, + ) + + assert "host" not in result + assert "content-length" not in result + assert result.get("content-type") == "application/json" \ No newline at end of file