Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,7 @@ async def endpoint_func( # type: ignore
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
subpath: str = "", # captures sub-paths when include_subpath=True
custom_body: Optional[dict] = None, # accepted for signature compatibility with URL-based path; not forwarded because chat_completion_pass_through_endpoint does not support it
):
return await chat_completion_pass_through_endpoint(
fastapi_response=fastapi_response,
Expand All @@ -1115,6 +1116,7 @@ async def endpoint_func( # type: ignore
fastapi_response: Response,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
subpath: str = "", # captures sub-paths when include_subpath=True
custom_body: Optional[dict] = None,
):
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
InitPassThroughEndpointHelpers,
Expand Down Expand Up @@ -1189,11 +1191,16 @@ async def endpoint_func( # type: ignore
)
if query_params:
final_query_params.update(query_params)
final_custom_body = (
custom_body_data
if isinstance(custom_body_data, dict) or custom_body_data is None
else None
)
# When a caller (e.g. bedrock_proxy_route) supplies a pre-built
# body, use it instead of the body parsed from the raw request.
if custom_body is not None:
final_custom_body = custom_body
else:
final_custom_body = (
custom_body_data
if isinstance(custom_body_data, dict) or custom_body_data is None
else None
)

return await pass_through_request( # type: ignore
request=request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1960,6 +1960,143 @@ async def test_add_litellm_data_to_request_adds_headers_to_metadata():
assert "headers" in result["proxy_server_request"]


@pytest.mark.asyncio
async def test_create_pass_through_route_custom_body_url_target():
"""
Test that the URL-based endpoint_func created by create_pass_through_route
accepts a custom_body parameter and forwards it to pass_through_request,
taking precedence over the request-parsed body.

This verifies the fix for issue #16999 where bedrock_proxy_route passes
custom_body=data to the endpoint function, which previously crashed with:
TypeError: endpoint_func() got an unexpected keyword argument 'custom_body'
"""
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)

unique_path = "/test/path/unique/custom_body_url"
endpoint_func = create_pass_through_route(
endpoint=unique_path,
target="https://bedrock-agent-runtime.us-east-1.amazonaws.com",
custom_headers={"Content-Type": "application/json"},
_forward_headers=True,
)

with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_request"
) as mock_pass_through, patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.is_registered_pass_through_route"
) as mock_is_registered, patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.get_registered_pass_through_route"
) as mock_get_registered, patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints._parse_request_data_by_content_type"
) as mock_parse_request:
mock_pass_through.return_value = MagicMock()
mock_is_registered.return_value = True
mock_get_registered.return_value = None
# Simulate the request parser returning a different body
mock_parse_request.return_value = (
{}, # query_params_data
{"parsed_from_request": True}, # custom_body_data (from request)
None, # file_data
False, # stream
)

mock_request = MagicMock(spec=Request)
mock_request.url = MagicMock()
mock_request.url.path = unique_path
mock_request.path_params = {}
mock_request.query_params = QueryParams({})

mock_user_api_key_dict = MagicMock()
mock_user_api_key_dict.api_key = "test-key"

# The caller-supplied body (e.g. from bedrock_proxy_route)
bedrock_body = {
"retrievalQuery": {"text": "What is in the knowledge base?"},
}

# Call endpoint_func with custom_body — this is the call that
# used to crash with TypeError before the fix
await endpoint_func(
request=mock_request,
fastapi_response=MagicMock(),
user_api_key_dict=mock_user_api_key_dict,
custom_body=bedrock_body,
)

mock_pass_through.assert_called_once()
call_kwargs = mock_pass_through.call_args[1]

# The critical assertion: custom_body takes precedence over
# the body parsed from the raw request
assert call_kwargs["custom_body"] == bedrock_body


@pytest.mark.asyncio
async def test_create_pass_through_route_no_custom_body_falls_back():
"""
Test that the URL-based endpoint_func falls back to the request-parsed body
when custom_body is not provided.

This ensures the default pass-through behavior is preserved — only the
Bedrock proxy route (and similar callers) supply a pre-built body.
"""
from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
create_pass_through_route,
)

unique_path = "/test/path/unique/no_custom_body"
endpoint_func = create_pass_through_route(
endpoint=unique_path,
target="http://example.com/api",
custom_headers={},
)

with patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_request"
) as mock_pass_through, patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.is_registered_pass_through_route"
) as mock_is_registered, patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.get_registered_pass_through_route"
) as mock_get_registered, patch(
"litellm.proxy.pass_through_endpoints.pass_through_endpoints._parse_request_data_by_content_type"
) as mock_parse_request:
mock_pass_through.return_value = MagicMock()
mock_is_registered.return_value = True
mock_get_registered.return_value = None
request_parsed_body = {"key": "from_request"}
mock_parse_request.return_value = (
{}, # query_params_data
request_parsed_body, # custom_body_data
None, # file_data
False, # stream
)

mock_request = MagicMock(spec=Request)
mock_request.url = MagicMock()
mock_request.url.path = unique_path
mock_request.path_params = {}
mock_request.query_params = QueryParams({})

mock_user_api_key_dict = MagicMock()
mock_user_api_key_dict.api_key = "test-key"

# Call without custom_body — should use the request-parsed body
await endpoint_func(
request=mock_request,
fastapi_response=MagicMock(),
user_api_key_dict=mock_user_api_key_dict,
)

mock_pass_through.assert_called_once()
call_kwargs = mock_pass_through.call_args[1]

# Should fall back to the body parsed from the request
assert call_kwargs["custom_body"] == request_parsed_body


def test_build_full_path_with_root_default():
"""
Test _build_full_path_with_root with default root path (/)
Expand Down
Loading