Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ litellm_settings:
mode: pre_call # or post_call, during_call
api_base: https://your-guardrail-api.com
api_key: os.environ/YOUR_GUARDRAIL_API_KEY # optional
unreachable_fallback: fail_closed # default: fail_closed. Set to fail_open to proceed if the guardrail endpoint is unreachable (network errors, or HTTP 502/503/504 from an upstream proxy/LB).
additional_provider_specific_params:
# your custom parameters
threshold: 0.8
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"
additional_provider_specific_params=getattr(
litellm_params, "additional_provider_specific_params", {}
),
unreachable_fallback=getattr(
litellm_params, "unreachable_fallback", "fail_closed"
),
guardrail_name=guardrail.get("guardrail_name", ""),
event_hook=litellm_params.mode,
default_on=litellm_params.default_on,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ litellm_settings:
mode: pre_call # Options: pre_call, post_call, during_call, [pre_call, post_call]
api_key: os.environ/GENERIC_GUARDRAIL_API_KEY # Optional if using Bearer auth
api_base: http://localhost:8080 # Required. Endpoint /beta/litellm_basic_guardrail_api is automatically appended
unreachable_fallback: fail_closed # Options: fail_closed (default, raise), fail_open (proceed if endpoint unreachable or upstream returns 502/503/504)
default_on: false # Set to true to apply to all requests by default
additional_provider_specific_params:
# Any additional parameters your guardrail API needs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import os
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional

import httpx

from litellm._logging import verbose_proxy_logger
from litellm._version import version as litellm_version
from litellm.exceptions import GuardrailRaisedException
from litellm.exceptions import GuardrailRaisedException, Timeout
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
Expand All @@ -34,17 +36,19 @@
GUARDRAIL_NAME = "generic_guardrail_api"

# Headers whose values are forwarded as-is (case-insensitive). Glob patterns supported (e.g. x-stainless-*, x-litellm*).
_HEADER_VALUE_ALLOWLIST = frozenset({
"host",
"accept-encoding",
"connection",
"accept",
"content-type",
"user-agent",
"x-stainless-*",
"x-litellm-*",
"content-length",
})
_HEADER_VALUE_ALLOWLIST = frozenset(
{
"host",
"accept-encoding",
"connection",
"accept",
"content-type",
"user-agent",
"x-stainless-*",
"x-litellm-*",
"content-length",
}
)

# Placeholder for headers that exist but are not on the allowlist (we don't expose their value).
_HEADER_PRESENT_PLACEHOLDER = "[present]"
Expand Down Expand Up @@ -166,6 +170,7 @@ def __init__(
api_base: Optional[str] = None,
api_key: Optional[str] = None,
additional_provider_specific_params: Optional[Dict[str, Any]] = None,
unreachable_fallback: Literal["fail_closed", "fail_open"] = "fail_closed",
**kwargs,
):
self.async_handler = get_async_httpx_client(
Expand Down Expand Up @@ -196,6 +201,10 @@ def __init__(
additional_provider_specific_params or {}
)

self.unreachable_fallback: Literal["fail_closed", "fail_open"] = (
unreachable_fallback
)

# Set supported event hooks
if "supported_event_hooks" not in kwargs:
kwargs["supported_event_hooks"] = [
Expand Down Expand Up @@ -259,6 +268,54 @@ def _extract_user_api_key_metadata(

return result_metadata

def _fail_open_passthrough(
    self,
    *,
    inputs: GenericGuardrailAPIInputs,
    input_type: Literal["request", "response"],
    logging_obj: Optional["LiteLLMLoggingObj"],
    error: Exception,
    http_status_code: Optional[int] = None,
) -> GenericGuardrailAPIInputs:
    """Log a fail-open event and hand the inputs back unmodified.

    Emits a critical-level log record describing the unreachable guardrail
    (with ``error`` attached as ``exc_info``), then returns a shallow copy
    of ``inputs`` so the caller can continue without the guardrail result.

    Args:
        inputs: The inputs that were about to be sent to the guardrail.
        input_type: Whether the guardrail was evaluating a request or a response.
        logging_obj: Optional logging object; its ``litellm_call_id`` and
            ``litellm_trace_id`` attributes are included in the log when present.
        error: The exception that triggered the fail-open path.
        http_status_code: Optional HTTP status to include in the log message.

    Returns:
        A new dict holding the same top-level entries as ``inputs``.
    """
    # The status is only appended to the message when a truthy code was given.
    if http_status_code:
        status_suffix = f" http_status_code={http_status_code}"
    else:
        status_suffix = ""

    # getattr with a default keeps this safe on partially-initialised instances.
    guardrail_name = getattr(self, "guardrail_name", None)
    api_base = getattr(self, "api_base", None)

    if logging_obj:
        call_id = getattr(logging_obj, "litellm_call_id", None)
        trace_id = getattr(logging_obj, "litellm_trace_id", None)
    else:
        call_id = None
        trace_id = None

    verbose_proxy_logger.critical(
        "Generic Guardrail API unreachable (fail-open). Proceeding without guardrail.%s "
        "guardrail_name=%s api_base=%s input_type=%s litellm_call_id=%s litellm_trace_id=%s",
        status_suffix,
        guardrail_name,
        api_base,
        input_type,
        call_id,
        trace_id,
        exc_info=error,
    )

    # Equivalent to the guardrail answering with no modifications: copy the
    # top-level entries so the caller does not share the dict it passed in.
    passthrough: GenericGuardrailAPIInputs = {}
    passthrough.update(inputs)
    return passthrough

def _build_guardrail_return_inputs(
    self,
    *,
    texts: list,
    images: Any,
    tools: Any,
    guardrail_response: GenericGuardrailAPIResponse,
) -> GenericGuardrailAPIInputs:
    """Combine the guardrail's response with the original inputs.

    For each of ``texts``, ``images`` and ``tools``, a non-empty value on
    ``guardrail_response`` takes precedence; otherwise the original value is
    used. ``texts`` is always present in the result, while ``images`` and
    ``tools`` are only set when the chosen value is truthy.

    Args:
        texts: The original texts submitted to the guardrail.
        images: The original images, if any.
        tools: The original tools, if any.
        guardrail_response: The parsed response from the guardrail endpoint.

    Returns:
        The inputs to continue processing with.
    """
    merged = GenericGuardrailAPIInputs(texts=guardrail_response.texts or texts)

    # Prefer what the guardrail sent back; fall back to the caller's value.
    selected_images = guardrail_response.images or images
    if selected_images:
        merged["images"] = selected_images

    selected_tools = guardrail_response.tools or tools
    if selected_tools:
        merged["tools"] = selected_tools

    return merged

@log_guardrail_information
async def apply_guardrail(
self,
Expand Down Expand Up @@ -313,7 +370,9 @@ async def apply_guardrail(

# Extract user API key metadata
user_metadata = self._extract_user_api_key_metadata(request_data)
inbound_headers = _extract_inbound_headers(request_data=request_data, logging_obj=logging_obj)
inbound_headers = _extract_inbound_headers(
request_data=request_data, logging_obj=logging_obj
)

# Create request payload
guardrail_request = GenericGuardrailAPIRequest(
Expand Down Expand Up @@ -370,23 +429,64 @@ async def apply_guardrail(
should_wrap_with_default_message=False,
)

# Action is NONE or no modifications needed
return_inputs = GenericGuardrailAPIInputs(texts=texts)
if guardrail_response.texts:
return_inputs["texts"] = guardrail_response.texts
if guardrail_response.images:
return_inputs["images"] = guardrail_response.images
elif images:
return_inputs["images"] = images
if guardrail_response.tools:
return_inputs["tools"] = guardrail_response.tools
elif tools:
return_inputs["tools"] = tools
return return_inputs
return self._build_guardrail_return_inputs(
texts=texts,
images=images,
tools=tools,
guardrail_response=guardrail_response,
)

except GuardrailRaisedException:
# Re-raise guardrail exceptions as-is
raise
except Timeout as e:
# AsyncHTTPHandler wraps httpx.TimeoutException into litellm.Timeout
if self.unreachable_fallback == "fail_open":
return self._fail_open_passthrough(
inputs=inputs,
input_type=input_type,
logging_obj=logging_obj,
error=e,
)

verbose_proxy_logger.error(
"Generic Guardrail API: failed to make request: %s", str(e)
)
raise Exception(f"Generic Guardrail API failed: {str(e)}")
except httpx.HTTPStatusError as e:
# Common reverse-proxy/LB failures can present as HTTP errors even when the backend is unreachable.
status_code = getattr(getattr(e, "response", None), "status_code", None)
if self.unreachable_fallback == "fail_open" and status_code in (
502,
503,
504,
):
return self._fail_open_passthrough(
inputs=inputs,
input_type=input_type,
logging_obj=logging_obj,
error=e,
http_status_code=status_code,
)

verbose_proxy_logger.error(
"Generic Guardrail API: failed to make request: %s", str(e)
)
raise Exception(f"Generic Guardrail API failed: {str(e)}")
except httpx.RequestError as e:
# Guardrail endpoint is unreachable (DNS/connect/timeout/etc)
if self.unreachable_fallback == "fail_open":
return self._fail_open_passthrough(
inputs=inputs,
input_type=input_type,
logging_obj=logging_obj,
error=e,
)

verbose_proxy_logger.error(
"Generic Guardrail API: failed to make request: %s", str(e)
)
raise Exception(f"Generic Guardrail API failed: {str(e)}")
except Exception as e:
verbose_proxy_logger.error(
"Generic Guardrail API: failed to make request: %s", str(e)
Expand Down
10 changes: 10 additions & 0 deletions litellm/types/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,15 @@ class BaseLitellmParams(
description="Additional provider-specific parameters for generic guardrail APIs",
)

unreachable_fallback: Literal["fail_closed", "fail_open"] = Field(
default="fail_closed",
description=(
"Behavior when a guardrail endpoint is unreachable due to network errors. "
"NOTE: This is currently only implemented by guardrail='generic_guardrail_api'. "
"'fail_closed' raises an error (default). 'fail_open' logs a critical error and allows the request to proceed."
),
)

# Custom code guardrail params
custom_code: Optional[str] = Field(
default=None,
Expand Down Expand Up @@ -692,6 +701,7 @@ class LitellmParams(
"mode",
"default_action",
"on_disallowed_action",
"unreachable_fallback",
mode="before",
check_fields=False,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,14 @@ class GenericGuardrailAPIOptionalParams(BaseModel):
description="Additional provider-specific parameters to send with the guardrail request",
)

unreachable_fallback: Optional[Literal["fail_closed", "fail_open"]] = Field(
default="fail_closed",
description=(
"Behavior when the guardrail endpoint is unreachable due to network errors. "
"'fail_closed' raises an error (default). 'fail_open' logs a critical error and allows the request to proceed."
),
)


class GenericGuardrailAPIConfigModel(
GuardrailConfigModel[GenericGuardrailAPIOptionalParams],
Expand All @@ -52,9 +60,9 @@ class GenericGuardrailAPIRequest(BaseModel):

input_type: Literal["request", "response"]
litellm_call_id: Optional[str] = None # the call id of the individual LLM call
litellm_trace_id: Optional[
str
] = None # the trace id of the LLM call - useful if there are multiple LLM calls for the same conversation
litellm_trace_id: Optional[str] = (
None # the trace id of the LLM call - useful if there are multiple LLM calls for the same conversation
)
structured_messages: Optional[List[AllMessageValues]] = None
images: Optional[List[str]] = None
tools: Optional[List[ChatCompletionToolParam]] = None
Expand Down
Loading
Loading