Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 129 additions & 59 deletions litellm/proxy/pass_through_endpoints/pass_through_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ async def chat_completion_pass_through_endpoint( # noqa: PLR0915
elif (
llm_router is not None
and data["model"] not in router_model_names
and (llm_router.default_deployment is not None or len(llm_router.pattern_router.patterns) > 0)
and (
llm_router.default_deployment is not None
or len(llm_router.pattern_router.patterns) > 0
)
): # check for wildcard routes or default deployment before checking deployment_names
llm_response = asyncio.create_task(llm_router.aadapter_completion(**data))
elif (
Expand Down Expand Up @@ -442,10 +445,10 @@ async def make_multipart_http_request(

for field_name, field_value in form_data.items():
if isinstance(field_value, (StarletteUploadFile, UploadFile)):
files[field_name] = (
await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file=field_value
)
files[
field_name
] = await HttpPassThroughEndpointHelpers._build_request_files_from_upload_file(
upload_file=field_value
)
else:
form_data_dict[field_name] = field_value
Expand Down Expand Up @@ -538,9 +541,9 @@ def _init_kwargs_for_pass_through_endpoint(
"passthrough_logging_payload": passthrough_logging_payload,
}

logging_obj.model_call_details["passthrough_logging_payload"] = (
passthrough_logging_payload
)
logging_obj.model_call_details[
"passthrough_logging_payload"
] = passthrough_logging_payload

return kwargs

Expand Down Expand Up @@ -677,7 +680,7 @@ async def pass_through_request( # noqa: PLR0915
user_api_key_dict=user_api_key_dict,
passthrough_guardrails_config=guardrails_config,
)

# Add guardrails to metadata if any should run
if guardrails_to_run and len(guardrails_to_run) > 0:
if _parsed_body is None:
Expand All @@ -700,10 +703,10 @@ async def pass_through_request( # noqa: PLR0915
litellm_call_id=litellm_call_id,
function_id="1245",
)

# Store passthrough guardrails config on logging_obj for field targeting
logging_obj.passthrough_guardrails_config = guardrails_config

# Store logging_obj in data so guardrails can access it
if _parsed_body is None:
_parsed_body = {}
Expand Down Expand Up @@ -738,7 +741,9 @@ async def pass_through_request( # noqa: PLR0915
# Store custom_llm_provider in kwargs and logging object if provided
if custom_llm_provider:
logging_obj.model_call_details["custom_llm_provider"] = custom_llm_provider
logging_obj.model_call_details["litellm_params"] = kwargs.get("litellm_params", {})
logging_obj.model_call_details["litellm_params"] = kwargs.get(
"litellm_params", {}
)

# done for supporting 'parallel_request_limiter.py' with pass-through endpoints
logging_obj.update_environment_variables(
Expand Down Expand Up @@ -928,12 +933,16 @@ async def pass_through_request( # noqa: PLR0915
if kwargs:
for key, value in kwargs.items():
request_payload[key] = value

if "model" not in request_payload and _parsed_body and isinstance(_parsed_body, dict):

if (
"model" not in request_payload
and _parsed_body
and isinstance(_parsed_body, dict)
):
request_payload["model"] = _parsed_body.get("model", "")
if "custom_llm_provider" not in request_payload and custom_llm_provider:
request_payload["custom_llm_provider"] = custom_llm_provider

await proxy_logging_obj.post_call_failure_hook(
user_api_key_dict=user_api_key_dict,
original_exception=e,
Expand Down Expand Up @@ -1442,9 +1451,9 @@ async def forward_client_to_upstream() -> None:
)
if extracted_model:
kwargs["model"] = extracted_model
kwargs["custom_llm_provider"] = (
"vertex_ai-language-models"
)
kwargs[
"custom_llm_provider"
] = "vertex_ai-language-models"
# Update logging object with correct model
logging_obj.model = extracted_model
logging_obj.model_call_details[
Expand Down Expand Up @@ -1510,9 +1519,9 @@ async def forward_upstream_to_client() -> None:
# Update logging object with correct model
logging_obj.model = extracted_model
logging_obj.model_call_details["model"] = extracted_model
logging_obj.model_call_details["custom_llm_provider"] = (
"vertex_ai_language_models"
)
logging_obj.model_call_details[
"custom_llm_provider"
] = "vertex_ai_language_models"
verbose_proxy_logger.debug(
f"WebSocket passthrough ({endpoint}): Successfully extracted model '{extracted_model}' and set provider to 'vertex_ai' from server setup response"
)
Expand Down Expand Up @@ -1840,10 +1849,9 @@ def add_exact_path_route(
# Check if this exact route is already registered
if route_key in _registered_pass_through_routes:
verbose_proxy_logger.debug(
"Skipping duplicate exact pass through endpoint: %s (already registered)",
"Updating duplicate exact pass through endpoint: %s (already registered)",
path,
)
return

verbose_proxy_logger.debug(
"adding exact pass through endpoint: %s, dependencies: %s",
Expand All @@ -1852,7 +1860,7 @@ def add_exact_path_route(
)

# Use SafeRouteAdder to only add route if it doesn't exist on the app
was_added = SafeRouteAdder.add_api_route_if_not_exists(
SafeRouteAdder.add_api_route_if_not_exists(
app=app,
path=path,
endpoint=create_pass_through_route( # type: ignore
Expand All @@ -1869,22 +1877,21 @@ def add_exact_path_route(
dependencies=dependencies,
)

# Register the route to prevent duplicates only if it was added
if was_added:
_registered_pass_through_routes[route_key] = {
"endpoint_id": endpoint_id,
"path": path,
"type": "exact",
"passthrough_params": {
"target": target,
"custom_headers": custom_headers,
"forward_headers": forward_headers,
"merge_query_params": merge_query_params,
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
},
}
# Always register/update the route metadata (headers, target) even if FastAPI route exists
_registered_pass_through_routes[route_key] = {
"endpoint_id": endpoint_id,
"path": path,
"type": "exact",
"passthrough_params": {
"target": target,
"custom_headers": custom_headers,
"forward_headers": forward_headers,
"merge_query_params": merge_query_params,
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
},
}

@staticmethod
def add_subpath_route(
Expand All @@ -1906,10 +1913,9 @@ def add_subpath_route(
# Check if this subpath route is already registered
if route_key in _registered_pass_through_routes:
verbose_proxy_logger.debug(
"Skipping duplicate wildcard pass through endpoint: %s (already registered)",
"Updating duplicate wildcard pass through endpoint: %s (already registered)",
wildcard_path,
)
return

verbose_proxy_logger.debug(
"adding wildcard pass through endpoint: %s, dependencies: %s",
Expand All @@ -1918,7 +1924,7 @@ def add_subpath_route(
)

# Use SafeRouteAdder to only add route if it doesn't exist on the app
was_added = SafeRouteAdder.add_api_route_if_not_exists(
SafeRouteAdder.add_api_route_if_not_exists(
app=app,
path=wildcard_path,
endpoint=create_pass_through_route( # type: ignore
Expand All @@ -1937,21 +1943,20 @@ def add_subpath_route(
)

# Always register/update the route metadata (headers, target) even if the FastAPI route already exists
if was_added:
_registered_pass_through_routes[route_key] = {
"endpoint_id": endpoint_id,
"path": path,
"type": "subpath",
"passthrough_params": {
"target": target,
"custom_headers": custom_headers,
"forward_headers": forward_headers,
"merge_query_params": merge_query_params,
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
},
}
_registered_pass_through_routes[route_key] = {
"endpoint_id": endpoint_id,
"path": path,
"type": "subpath",
"passthrough_params": {
"target": target,
"custom_headers": custom_headers,
"forward_headers": forward_headers,
"merge_query_params": merge_query_params,
"dependencies": dependencies,
"cost_per_request": cost_per_request,
"guardrails": guardrails,
},
}

@staticmethod
def remove_endpoint_routes(endpoint_id: str):
Expand Down Expand Up @@ -2128,7 +2133,7 @@ async def initialize_pass_through_endpoints(

# Get guardrails config if present
_guardrails = endpoint.get("guardrails", None)

# Add exact path route
verbose_proxy_logger.debug(
"Initializing pass through endpoint: %s (ID: %s)", _path, endpoint_id
Expand Down Expand Up @@ -2307,6 +2312,7 @@ async def get_pass_through_endpoints(
async def update_pass_through_endpoints(
endpoint_id: str,
data: PassThroughGenericEndpoint,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Expand Down Expand Up @@ -2397,6 +2403,37 @@ async def update_pass_through_endpoints(
data=updated_data, user_api_key_dict=user_api_key_dict
)

# Re-register the route with updated headers
_custom_headers: Optional[dict] = updated_endpoint.headers or {}
_custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers)

if updated_endpoint.include_subpath:
InitPassThroughEndpointHelpers.add_subpath_route(
app=request.app,
path=updated_endpoint.path,
target=updated_endpoint.target,
custom_headers=_custom_headers,
            forward_headers=None,  # NOTE(review): not exposed on the endpoint model; passing None assumes downstream applies its default — confirm
merge_query_params=None,
dependencies=None,
cost_per_request=updated_endpoint.cost_per_request,
endpoint_id=updated_endpoint.id or endpoint_id or "",
guardrails=getattr(updated_endpoint, "guardrails", None),
)
else:
InitPassThroughEndpointHelpers.add_exact_path_route(
app=request.app,
path=updated_endpoint.path,
target=updated_endpoint.target,
custom_headers=_custom_headers,
forward_headers=None,
merge_query_params=None,
dependencies=None,
cost_per_request=updated_endpoint.cost_per_request,
endpoint_id=updated_endpoint.id or endpoint_id or "",
guardrails=getattr(updated_endpoint, "guardrails", None),
)

return PassThroughEndpointResponse(
endpoints=[updated_endpoint] if updated_endpoint else []
)
Expand All @@ -2408,6 +2445,7 @@ async def update_pass_through_endpoints(
)
async def create_pass_through_endpoints(
data: PassThroughGenericEndpoint,
request: Request,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
Expand Down Expand Up @@ -2452,6 +2490,38 @@ async def create_pass_through_endpoints(

# Return the created endpoint with the generated ID
created_endpoint = PassThroughGenericEndpoint(**data_dict)

# Register the new route
_custom_headers: Optional[dict] = created_endpoint.headers or {}
_custom_headers = await set_env_variables_in_header(custom_headers=_custom_headers)

if created_endpoint.include_subpath:
InitPassThroughEndpointHelpers.add_subpath_route(
app=request.app,
path=created_endpoint.path,
target=created_endpoint.target,
custom_headers=_custom_headers,
forward_headers=None,
merge_query_params=None,
dependencies=None,
cost_per_request=created_endpoint.cost_per_request,
endpoint_id=created_endpoint.id or "",
guardrails=getattr(created_endpoint, "guardrails", None),
)
else:
InitPassThroughEndpointHelpers.add_exact_path_route(
app=request.app,
path=created_endpoint.path,
target=created_endpoint.target,
custom_headers=_custom_headers,
forward_headers=None,
merge_query_params=None,
dependencies=None,
cost_per_request=created_endpoint.cost_per_request,
endpoint_id=created_endpoint.id or "",
guardrails=getattr(created_endpoint, "guardrails", None),
)

return PassThroughEndpointResponse(endpoints=[created_endpoint])


Expand Down
Loading
Loading