diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 75d8253a904..b6a39f397ef 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -233,7 +233,10 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915 elif ( llm_router is not None and data["model"] not in router_model_names - and (llm_router.default_deployment is not None or len(llm_router.pattern_router.patterns) > 0) + and ( + llm_router.default_deployment is not None + or len(llm_router.pattern_router.patterns) > 0 + ) ): # check for wildcard routes or default deployment before checking deployment_names llm_response = asyncio.create_task(llm_router.aadapter_completion(**data)) elif ( @@ -442,10 +445,10 @@ async def make_multipart_http_request( for field_name, field_value in form_data.items(): if isinstance(field_value, (StarletteUploadFile, UploadFile)): - files[field_name] = ( - await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( - upload_file=field_value - ) + files[ + field_name + ] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file( + upload_file=field_value ) else: form_data_dict[field_name] = field_value @@ -538,9 +541,9 @@ def _init_kwargs_for_pass_through_endpoint( "passthrough_logging_payload": passthrough_logging_payload, } - logging_obj.model_call_details["passthrough_logging_payload"] = ( - passthrough_logging_payload - ) + logging_obj.model_call_details[ + "passthrough_logging_payload" + ] = passthrough_logging_payload return kwargs @@ -677,7 +680,7 @@ async def pass_through_request( # noqa: PLR0915 user_api_key_dict=user_api_key_dict, passthrough_guardrails_config=guardrails_config, ) - + # Add guardrails to metadata if any should run if guardrails_to_run and len(guardrails_to_run) > 0: if _parsed_body is None: @@ -700,10 +703,10 @@ async def pass_through_request( # noqa: PLR0915 litellm_call_id=litellm_call_id, function_id="1245", ) - + # Store passthrough guardrails config on logging_obj for field targeting logging_obj.passthrough_guardrails_config = guardrails_config - + # Store logging_obj in data so guardrails can access it if _parsed_body is None: _parsed_body = {} @@ -738,7 +741,9 @@ async def pass_through_request( # noqa: PLR0915 # Store custom_llm_provider in kwargs and logging object if provided if custom_llm_provider: logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider - logging_obj.model_call_details["litellm_params"] = kwargs.get("litellm_params", {}) + logging_obj.model_call_details["litellm_params"] = kwargs.get( + "litellm_params", {} + ) # done for supporting 'parallel_request_limiter.py' with pass-through endpoints logging_obj.update_environment_variables( @@ -928,12 +933,16 @@ async def pass_through_request( # noqa: PLR0915 if kwargs: for key, value in kwargs.items(): request_payload[key] = value - - if "model" not in request_payload and _parsed_body and isinstance(_parsed_body, dict): + + if ( + "model" not in request_payload + and _parsed_body + and isinstance(_parsed_body, dict) + ): request_payload["model"] = _parsed_body.get("model", "") if "custom_llm_provider" not in request_payload and custom_llm_provider: request_payload["custom_llm_provider"] = custom_llm_provider - + await proxy_logging_obj.post_call_failure_hook( user_api_key_dict=user_api_key_dict, original_exception=e, @@ -1442,9 +1451,9 @@ async def forward_client_to_upstream() -> None: ) if extracted_model: kwargs["model"] = extracted_model - kwargs["custom_llm_provider"] = ( - "vertex_ai-language-models" - ) + kwargs[ + "custom_llm_provider" + ] = "vertex_ai-language-models" # Update logging object with correct model logging_obj.model = extracted_model logging_obj.model_call_details[ @@ -1510,9 +1519,9 @@ async def forward_upstream_to_client() -> None: # Update logging object with correct model logging_obj.model = extracted_model logging_obj.model_call_details["model"] = extracted_model - logging_obj.model_call_details["custom_llm_provider"] = ( - "vertex_ai_language_models" - ) + logging_obj.model_call_details[ + "custom_llm_provider" + ] = "vertex_ai_language_models" verbose_proxy_logger.debug( f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response" ) @@ -1840,10 +1849,9 @@ def add_exact_path_route( # Check if this exact route is already registered if route_key in _registered_pass_through_routes: verbose_proxy_logger.debug( - "Skipping duplicate exact pass through endpoint: %s (already registered)", + "Updating duplicate exact pass through endpoint: %s (already registered)", path, ) - return verbose_proxy_logger.debug( "adding exact pass through endpoint: %s, dependencies: %s", @@ -1852,7 +1860,7 @@ def add_exact_path_route( ) # Use SafeRouteAdder to only add route if it doesn't exist on the app - was_added = SafeRouteAdder.add_api_route_if_not_exists( + SafeRouteAdder.add_api_route_if_not_exists( app=app, path=path, endpoint=create_pass_through_route( # type: ignore @@ -1869,22 +1877,21 @@ def add_exact_path_route( dependencies=dependencies, ) - # Register the route to prevent duplicates only if it was added - if was_added: - _registered_pass_through_routes[route_key] = { - "endpoint_id": endpoint_id, - "path": path, - "type": "exact", - "passthrough_params": { - "target": target, - "custom_headers": custom_headers, - "forward_headers": forward_headers, - "merge_query_params": merge_query_params, - "dependencies": dependencies, - "cost_per_request": cost_per_request, - "guardrails": guardrails, - }, - } + # Always register/update the route metadata (headers, target) even if FastAPI route exists + _registered_pass_through_routes[route_key] = { + "endpoint_id": endpoint_id, + "path": path, + "type": "exact", + "passthrough_params": { + "target": target, + "custom_headers": custom_headers, + "forward_headers": forward_headers, + "merge_query_params": merge_query_params, + "dependencies": dependencies, + "cost_per_request": cost_per_request, + "guardrails": guardrails, + }, + } @staticmethod def add_subpath_route( @@ -1906,10 +1913,9 @@ def add_subpath_route( # Check if this subpath route is already registered if route_key in _registered_pass_through_routes: verbose_proxy_logger.debug( - "Skipping duplicate wildcard pass through endpoint: %s (already registered)", + "Updating duplicate wildcard pass through endpoint: %s (already registered)", wildcard_path, ) - return verbose_proxy_logger.debug( "adding wildcard pass through endpoint: %s, dependencies: %s", @@ -1918,7 +1924,7 @@ def add_subpath_route( ) # Use SafeRouteAdder to only add route if it doesn't exist on the app - was_added = SafeRouteAdder.add_api_route_if_not_exists( + SafeRouteAdder.add_api_route_if_not_exists( app=app, path=wildcard_path, endpoint=create_pass_through_route( # type: ignore @@ -1937,21 +1943,20 @@ def add_subpath_route( ) # Register the route to prevent duplicates only if it was added - if was_added: - _registered_pass_through_routes[route_key] = { - "endpoint_id": endpoint_id, - "path": path, - "type": "subpath", - "passthrough_params": { - "target": target, - "custom_headers": custom_headers, - "forward_headers": forward_headers, - "merge_query_params": merge_query_params, - "dependencies": dependencies, - "cost_per_request": cost_per_request, - "guardrails": guardrails, - }, - } + _registered_pass_through_routes[route_key] = { + "endpoint_id": endpoint_id, + "path": path, + "type": "subpath", + "passthrough_params": { + "target": target, + "custom_headers": custom_headers, + "forward_headers": forward_headers, + "merge_query_params": merge_query_params, + "dependencies": dependencies, + "cost_per_request": cost_per_request, + "guardrails": guardrails, + }, + } @staticmethod def remove_endpoint_routes(endpoint_id: str): @@ -2128,7 +2133,7 @@ async def initialize_pass_through_endpoints( # Get guardrails config if present _guardrails = endpoint.get("guardrails", None) - + # Add exact path route verbose_proxy_logger.debug( "Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id @@ -2307,6 +2312,7 @@ async def get_pass_through_endpoints( async def update_pass_through_endpoints( endpoint_id: str, data: PassThroughGenericEndpoint, + request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -2397,6 +2403,37 @@ async def update_pass_through_endpoints( data=updated_data, user_api_key_dict=user_api_key_dict ) + # Re-register the route with updated headers + _custom_headers: Optional[dict] = updated_endpoint.headers or {} + _custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers) + + if updated_endpoint.include_subpath: + InitPassThroughEndpointHelpers.add_subpath_route( + app=request.app, + path=updated_endpoint.path, + target=updated_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, # Defaults not available in model? assuming None logic handles it + merge_query_params=None, + dependencies=None, + cost_per_request=updated_endpoint.cost_per_request, + endpoint_id=updated_endpoint.id or endpoint_id or "", + guardrails=getattr(updated_endpoint, "guardrails", None), + ) + else: + InitPassThroughEndpointHelpers.add_exact_path_route( + app=request.app, + path=updated_endpoint.path, + target=updated_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, + merge_query_params=None, + dependencies=None, + cost_per_request=updated_endpoint.cost_per_request, + endpoint_id=updated_endpoint.id or endpoint_id or "", + guardrails=getattr(updated_endpoint, "guardrails", None), + ) + return PassThroughEndpointResponse( endpoints=[updated_endpoint] if updated_endpoint else [] ) @@ -2408,6 +2445,7 @@ async def update_pass_through_endpoints( ) async def create_pass_through_endpoints( data: PassThroughGenericEndpoint, + request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -2452,6 +2490,38 @@ async def create_pass_through_endpoints( # Return the created endpoint with the generated ID created_endpoint = PassThroughGenericEndpoint(**data_dict) + + # Register the new route + _custom_headers: Optional[dict] = created_endpoint.headers or {} + _custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers) + + if created_endpoint.include_subpath: + InitPassThroughEndpointHelpers.add_subpath_route( + app=request.app, + path=created_endpoint.path, + target=created_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, + merge_query_params=None, + dependencies=None, + cost_per_request=created_endpoint.cost_per_request, + endpoint_id=created_endpoint.id or "", + guardrails=getattr(created_endpoint, "guardrails", None), + ) + else: + InitPassThroughEndpointHelpers.add_exact_path_route( + app=request.app, + path=created_endpoint.path, + target=created_endpoint.target, + custom_headers=_custom_headers, + forward_headers=None, + merge_query_params=None, + dependencies=None, + cost_per_request=created_endpoint.cost_per_request, + endpoint_id=created_endpoint.id or "", + guardrails=getattr(created_endpoint, "guardrails", None), + ) + return PassThroughEndpointResponse(endpoints=[created_endpoint]) diff --git a/tests/pass_through_unit_tests/test_passthrough_registry_updates.py b/tests/pass_through_unit_tests/test_passthrough_registry_updates.py new file mode 100644 index 00000000000..125ffdb6fa0 --- /dev/null +++ b/tests/pass_through_unit_tests/test_passthrough_registry_updates.py @@ -0,0 +1,145 @@ +from unittest.mock import MagicMock +import asyncio + +# Import the specific components we need to test +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + InitPassThroughEndpointHelpers, + _registered_pass_through_routes, +) + + +def test_update_pass_through_route_updates_registry(): + """ + REGRESSION TEST: Verify that calling add_exact_path_route (or add_subpath_route) + on an EXISTING route correctly updates the in-memory registry. + """ + + async def _async_test(): + # Setup - Unique IDs to avoid collision with other tests + endpoint_id = "regression-test-endpoint" + path = "/regression-test-path" + route_key = f"{endpoint_id}:exact:{path}" + target = "http://example.com" + + # Cleanup: Ensure clean state before test + if route_key in _registered_pass_through_routes: + del _registered_pass_through_routes[route_key] + + try: + # 1. First Registration (Initial State) + InitPassThroughEndpointHelpers.add_exact_path_route( + app=MagicMock(), + path=path, + target=target, + custom_headers={"Authorization": "Bearer INITIAL_TOKEN"}, + forward_headers=False, + merge_query_params=False, + dependencies=[], + cost_per_request=0, + endpoint_id=endpoint_id, + ) + + # Verify Initial State + assert route_key in _registered_pass_through_routes + initial_headers = _registered_pass_through_routes[route_key][ + "passthrough_params" + ]["custom_headers"] + assert initial_headers["Authorization"] == "Bearer INITIAL_TOKEN" + + # 2. Perform Update (Simulate API Update) + # This call should overwrite the existing entry + InitPassThroughEndpointHelpers.add_exact_path_route( + app=MagicMock(), + path=path, + target=target, + custom_headers={ + "Authorization": "Bearer NEW_UPDATED_TOKEN" + }, # Changed Header + forward_headers=False, + merge_query_params=False, + dependencies=[], + cost_per_request=0, + endpoint_id=endpoint_id, + ) + + # 3. Verify Update Occurred + updated_headers = _registered_pass_through_routes[route_key][ + "passthrough_params" + ]["custom_headers"] + + # This assertion protects against the regression + assert ( + updated_headers["Authorization"] == "Bearer NEW_UPDATED_TOKEN" + ), "Registry failed to update! Old headers persisted despite update call." + + finally: + # Cleanup: Remove test entry + if route_key in _registered_pass_through_routes: + del _registered_pass_through_routes[route_key] + + asyncio.run(_async_test()) + + +def test_update_subpath_route_updates_registry(): + """ + REGRESSION TEST: Verify that calling add_subpath_route + on an EXISTING route correctly updates the in-memory registry. + """ + + async def _async_test(): + # Setup + endpoint_id = "regression-test-subpath" + path = "/regression-test-wildcard" + route_key = f"{endpoint_id}:subpath:{path}" + target = "http://example.com" + + if route_key in _registered_pass_through_routes: + del _registered_pass_through_routes[route_key] + + try: + # 1. First Registration + InitPassThroughEndpointHelpers.add_subpath_route( + app=MagicMock(), + path=path, + target=target, + custom_headers={"Authorization": "Bearer INITIAL_SUBPATH_TOKEN"}, + forward_headers=False, + merge_query_params=False, + dependencies=[], + cost_per_request=0, + endpoint_id=endpoint_id, + ) + + assert ( + _registered_pass_through_routes[route_key]["passthrough_params"][ + "custom_headers" + ]["Authorization"] + == "Bearer INITIAL_SUBPATH_TOKEN" + ) + + # 2. Update + InitPassThroughEndpointHelpers.add_subpath_route( + app=MagicMock(), + path=path, + target=target, + custom_headers={"Authorization": "Bearer NEW_SUBPATH_TOKEN"}, + forward_headers=False, + merge_query_params=False, + dependencies=[], + cost_per_request=0, + endpoint_id=endpoint_id, + ) + + # 3. Verify + updated_headers = _registered_pass_through_routes[route_key][ + "passthrough_params" + ]["custom_headers"] + assert ( + updated_headers["Authorization"] == "Bearer NEW_SUBPATH_TOKEN" + ), "Subpath registry failed to update!" + + finally: + if route_key in _registered_pass_through_routes: + del _registered_pass_through_routes[route_key] + + asyncio.run(_async_test()) diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py index c585089c7be..4b5a1f9ec1a 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_pass_through_endpoints.py @@ -7,7 +7,6 @@ import httpx import pytest from fastapi import Request, UploadFile -from fastapi.testclient import TestClient from starlette.datastructures import Headers, QueryParams from starlette.datastructures import UploadFile as StarletteUploadFile @@ -201,7 +200,6 @@ async def test_pass_through_request_failure_handler(): Critical Test: When a users pass through endpoint request fails, we must log the failure code, exception in litellm spend logs. """ - print("running test_pass_through_request_failure_handler") with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging: with patch( "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" @@ -266,27 +264,27 @@ def test_is_langfuse_route(): # Test positive cases assert ( handler.is_langfuse_route("http://localhost:4000/langfuse/api/public/traces") - == True + is True ) assert ( handler.is_langfuse_route( "https://proxy.example.com/langfuse/api/public/sessions" ) - == True + is True ) - assert handler.is_langfuse_route("/langfuse/api/public/ingestion") == True - assert handler.is_langfuse_route("http://localhost:4000/langfuse/") == True + assert handler.is_langfuse_route("/langfuse/api/public/ingestion") is True + assert handler.is_langfuse_route("http://localhost:4000/langfuse/") is True # Test negative cases assert ( - handler.is_langfuse_route("https://api.openai.com/v1/chat/completions") == False + handler.is_langfuse_route("https://api.openai.com/v1/chat/completions") is False ) assert ( handler.is_langfuse_route("http://localhost:4000/anthropic/v1/messages") - == False + is False ) - assert handler.is_langfuse_route("https://example.com/other") == False - assert handler.is_langfuse_route("") == False + assert handler.is_langfuse_route("https://example.com/other") is False + assert handler.is_langfuse_route("") is False @pytest.mark.asyncio @@ -576,7 +574,6 @@ def test_set_cost_per_request(): """ Test that _set_cost_per_request correctly sets the cost in logging object and kwargs """ - from datetime import datetime from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.types.passthrough_endpoints.pass_through_endpoints import ( @@ -687,7 +684,7 @@ async def test_pass_through_success_handler_with_cost_per_request(): end_time = datetime.now() # Call the success handler - result = await handler.pass_through_async_success_handler( + await handler.pass_through_async_success_handler( httpx_response=mock_response, response_body={"status": "success", "data": "test"}, logging_obj=mock_logging_obj, @@ -719,8 +716,9 @@ async def test_create_pass_through_route_with_cost_per_request(): ) # Create the endpoint function with cost_per_request + unique_path = "/test/path/unique/cost_per_request" endpoint_func = create_pass_through_route( - endpoint="/test/path", + endpoint=unique_path, target="http://example.com", custom_headers={}, _forward_headers=True, @@ -732,11 +730,19 @@ async def test_create_pass_through_route_with_cost_per_request(): # Mock the pass_through_request function to capture its call with patch( "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_request" - ) as mock_pass_through: + ) as mock_pass_through, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.is_registered_pass_through_route" + ) as mock_is_registered, patch( + "litellm.proxy.pass_through_endpoints.pass_through_endpoints.InitPassThroughEndpointHelpers.get_registered_pass_through_route" + ) as mock_get_registered: mock_pass_through.return_value = MagicMock() + mock_is_registered.return_value = True + mock_get_registered.return_value = None # Create mock request mock_request = MagicMock(spec=Request) + mock_request.url = MagicMock() + mock_request.url.path = unique_path mock_request.path_params = {} mock_request.query_params = QueryParams({}) @@ -817,7 +823,7 @@ def test_initialize_pass_through_endpoints_with_cost_per_request(): @pytest.mark.asyncio -async def test_pass_through_request_contains_proxy_server_request_in_kwargs(): +async def test_pass_through_request_contains_proxy_server_request_in_kwargs(): # noqa: PLR0915 """ Test that pass_through_request (parent method) correctly includes proxy_server_request in kwargs passed to the success handler. @@ -825,8 +831,6 @@ async def test_pass_through_request_contains_proxy_server_request_in_kwargs(): Critical Test: Ensures that when pass_through_request is called, the kwargs passed to downstream methods contain the proxy server request details (url, method, body). """ - print("running test_pass_through_request_contains_proxy_server_request_in_kwargs") - with patch("litellm.proxy.proxy_server.proxy_logging_obj") as mock_proxy_logging: with patch( "litellm.proxy.pass_through_endpoints.pass_through_endpoints.HttpPassThroughEndpointHelpers.non_streaming_http_request_handler" @@ -891,7 +895,7 @@ async def test_pass_through_request_contains_proxy_server_request_in_kwargs(): mock_user_api_key_dict.request_route = "/api/endpoint" # Call pass_through_request (the parent method) - result = await pass_through_request( + await pass_through_request( request=mock_request, target="http://target-api.com/endpoint", custom_headers={"X-Custom": "header"}, @@ -951,7 +955,6 @@ async def test_create_pass_through_endpoint(): """ from litellm.proxy._types import ( ConfigFieldInfo, - ConfigFieldUpdate, PassThroughEndpointResponse, PassThroughGenericEndpoint, UserAPIKeyAuth, @@ -986,7 +989,9 @@ async def test_create_pass_through_endpoint(): # Call the create function result = await create_pass_through_endpoints( - data=test_endpoint, user_api_key_dict=mock_user_api_key_dict + data=test_endpoint, + request=MagicMock(spec=Request), + user_api_key_dict=mock_user_api_key_dict, ) # Verify the result @@ -1029,7 +1034,6 @@ async def test_update_pass_through_endpoint(): """ from litellm.proxy._types import ( ConfigFieldInfo, - ConfigFieldUpdate, PassThroughEndpointResponse, PassThroughGenericEndpoint, UserAPIKeyAuth, @@ -1082,6 +1086,7 @@ async def test_update_pass_through_endpoint(): result = await update_pass_through_endpoints( endpoint_id=existing_endpoint_id, data=update_data, + request=MagicMock(spec=Request), user_api_key_dict=mock_user_api_key_dict, ) @@ -1165,6 +1170,7 @@ async def test_update_pass_through_endpoint_not_found(): await update_pass_through_endpoints( endpoint_id="non-existent-endpoint-123", data=update_data, + request=MagicMock(spec=Request), user_api_key_dict=mock_user_api_key_dict, ) @@ -1185,7 +1191,6 @@ async def test_delete_pass_through_endpoint(): """ from litellm.proxy._types import ( ConfigFieldInfo, - ConfigFieldUpdate, PassThroughEndpointResponse, UserAPIKeyAuth, ) @@ -1421,7 +1426,7 @@ async def test_pass_through_request_query_params_forwarding(): mock_user_api_key_dict.api_key = "sk-1234" # Call pass_through_request - result = await pass_through_request( + await pass_through_request( request=mock_request, target="https://krris-m2f9a9i7-eastus2.openai.azure.com/openai/assistants", custom_headers={"Authorization": "Bearer azure_token"}, @@ -1498,7 +1503,6 @@ async def mock_body(): # httpbin.org/get returns JSON with info about the request assert '"url": "https://httpbin.org/get"' in response_content - print("GOT A Response from HTTPBIN=", response_content) except Exception as e: # If httpbin.org is not accessible, skip the test import pytest