Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions litellm/proxy/example_config_yaml/pipeline_test_guardrails.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Test guardrails for pipeline E2E testing.

- StrictFilter: blocks any message containing "bad" (case-insensitive)
- PermissiveFilter: always passes (simulates an advanced guardrail that is more lenient)
"""

from typing import Optional, Union

from fastapi import HTTPException

from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import CallTypesLiteral


class StrictFilter(CustomGuardrail):
    """Blocks any message containing the word 'bad' (case-insensitive)."""

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """Reject the request with HTTP 400 if any string message content
        contains 'bad'; otherwise pass the request data through unchanged."""
        has_bad_word = any(
            isinstance(message.get("content", ""), str)
            and "bad" in message.get("content", "").lower()
            for message in data.get("messages", [])
        )
        if has_bad_word:
            verbose_proxy_logger.info("StrictFilter: BLOCKED - found 'bad'")
            raise HTTPException(
                status_code=400,
                detail="StrictFilter: content contains forbidden word 'bad'",
            )
        verbose_proxy_logger.info("StrictFilter: PASSED")
        return data


class PermissiveFilter(CustomGuardrail):
    """Always passes - simulates a lenient advanced guardrail."""

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """Log the pass decision and return the request data untouched.

        This guardrail never raises, making it a useful escalation target in
        pipeline tests where the strict tier fails but the request should
        still be allowed.
        """
        verbose_proxy_logger.info("PermissiveFilter: PASSED (always passes)")
        return data


class AlwaysBlockFilter(CustomGuardrail):
    """Always blocks - for testing full escalation->block path."""

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: CallTypesLiteral,
    ) -> Optional[Union[Exception, str, dict]]:
        """Unconditionally reject the request with HTTP 400.

        Used as the final pipeline tier to verify the full block path when
        every step fails.
        """
        verbose_proxy_logger.info("AlwaysBlockFilter: BLOCKED")
        raise HTTPException(
            status_code=400,
            detail="AlwaysBlockFilter: all content blocked",
        )
64 changes: 64 additions & 0 deletions litellm/proxy/example_config_yaml/test_pipeline_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# E2E test config for guardrail pipelines.
# Both models point at the same fake OpenAI-compatible endpoint; they differ
# only in which policy is attached (permissive fallback vs. always-block).
model_list:
  - model_name: fake-openai-endpoint
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/
  - model_name: fake-blocked-endpoint
    litellm_params:
      model: openai/gpt-3.5-turbo
      api_key: fake-key
      api_base: https://exampleopenaiendpoint-production.up.railway.app/

# Guardrail classes are defined in pipeline_test_guardrails.py (same dir).
# All run in pre_call mode (before the LLM request is dispatched).
guardrails:
  - guardrail_name: "strict-filter"
    litellm_params:
      guardrail: pipeline_test_guardrails.StrictFilter
      mode: "pre_call"
  - guardrail_name: "permissive-filter"
    litellm_params:
      guardrail: pipeline_test_guardrails.PermissiveFilter
      mode: "pre_call"
  - guardrail_name: "always-block-filter"
    litellm_params:
      guardrail: pipeline_test_guardrails.AlwaysBlockFilter
      mode: "pre_call"

policies:
  # Pipeline: strict-filter fails -> escalate to permissive-filter
  # If strict fails but permissive passes -> allow the request
  content-safety-permissive:
    description: "Multi-tier: strict filter with permissive fallback"
    guardrails:
      add: [strict-filter, permissive-filter]
    pipeline:
      mode: "pre_call"
      steps:
        - guardrail: strict-filter
          on_fail: next      # escalate to permissive
          on_pass: allow     # clean content proceeds
        - guardrail: permissive-filter
          on_fail: block     # hard block
          on_pass: allow     # permissive says OK

  # Pipeline: strict-filter fails -> escalate to always-block
  # Both fail -> block
  content-safety-strict:
    description: "Multi-tier: strict filter with strict fallback (both block)"
    guardrails:
      add: [strict-filter, always-block-filter]
    pipeline:
      mode: "pre_call"
      steps:
        - guardrail: strict-filter
          on_fail: next
          on_pass: allow
        - guardrail: always-block-filter
          on_fail: block
          on_pass: allow

# Bind each policy to the model(s) it governs.
policy_attachments:
  - policy: content-safety-permissive
    models: [fake-openai-endpoint]
  - policy: content-safety-strict
    models: [fake-blocked-endpoint]
24 changes: 22 additions & 2 deletions litellm/proxy/litellm_pre_call_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1642,20 +1642,40 @@ def add_guardrails_from_policy_engine(
f"Policy engine: resolved guardrails: {resolved_guardrails}"
)

if not resolved_guardrails:
return
# Resolve pipelines from matching policies
pipelines = PolicyResolver.resolve_pipelines_for_context(context=context)

# Add resolved guardrails to request metadata
if metadata_variable_name not in data:
data[metadata_variable_name] = {}

# Track pipeline-managed guardrails to exclude from independent execution
pipeline_managed_guardrails: set = set()
if pipelines:
pipeline_managed_guardrails = PolicyResolver.get_pipeline_managed_guardrails(
pipelines
)
data[metadata_variable_name]["_guardrail_pipelines"] = pipelines
data[metadata_variable_name]["_pipeline_managed_guardrails"] = (
pipeline_managed_guardrails
)
verbose_proxy_logger.debug(
f"Policy engine: resolved {len(pipelines)} pipeline(s), "
f"managed guardrails: {pipeline_managed_guardrails}"
)

if not resolved_guardrails and not pipelines:
return

existing_guardrails = data[metadata_variable_name].get("guardrails", [])
if not isinstance(existing_guardrails, list):
existing_guardrails = []

# Combine existing guardrails with policy-resolved guardrails (no duplicates)
# Exclude pipeline-managed guardrails from the flat list
combined = set(existing_guardrails)
combined.update(resolved_guardrails)
combined -= pipeline_managed_guardrails
data[metadata_variable_name]["guardrails"] = list(combined)

verbose_proxy_logger.debug(
Expand Down
208 changes: 208 additions & 0 deletions litellm/proxy/policy_engine/pipeline_executor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,208 @@
"""
Pipeline Executor - Executes guardrail pipelines with conditional step logic.

Runs guardrails sequentially per pipeline step definitions, handling
pass/fail actions (allow, block, next, modify_response) and data forwarding.
"""

import copy
import time
from typing import Any, List, Optional

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
    CustomGuardrail,
    ModifyResponseException,
)
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
    UnifiedLLMGuardrails,
)
from litellm.types.proxy.policy_engine.pipeline_types import (
    PipelineExecutionResult,
    PipelineStep,
    PipelineStepResult,
)

try:
    from fastapi.exceptions import HTTPException
except ImportError:
    HTTPException = None  # type: ignore


class PipelineExecutor:
    """Executes guardrail pipelines with ordered, conditional step logic."""

    @staticmethod
    async def execute_steps(
        steps: List[PipelineStep],
        mode: str,
        data: dict,
        user_api_key_dict: Any,
        call_type: str,
        policy_name: str,
    ) -> PipelineExecutionResult:
        """
        Execute pipeline steps sequentially with conditional actions.

        Args:
            steps: Ordered list of pipeline steps
            mode: Event hook mode (pre_call, post_call)
            data: Request data dict
            user_api_key_dict: User API key auth
            call_type: Type of call (completion, etc.)
            policy_name: Name of the owning policy (for logging)

        Returns:
            PipelineExecutionResult with terminal action and step results
        """
        step_results: List[PipelineStepResult] = []
        # Deep-copy so guardrail-side mutations reach the caller only via the
        # explicit modified_data return, never by aliasing the request dict.
        working_data = copy.deepcopy(data)

        for i, step in enumerate(steps):
            start_time = time.perf_counter()

            outcome, modified_data, error_detail = await PipelineExecutor._run_step(
                step=step,
                mode=mode,
                data=working_data,
                user_api_key_dict=user_api_key_dict,
                call_type=call_type,
            )

            duration = time.perf_counter() - start_time

            # "pass" follows on_pass; both "fail" and "error" follow on_fail.
            action = step.on_pass if outcome == "pass" else step.on_fail

            # An "error" outcome (guardrail not found, unexpected exception) is
            # NOT a deliberate guardrail intervention. When on_fail is "next"
            # it would otherwise be silently swallowed and the pipeline would
            # continue, hiding misconfiguration from the operator — surface it.
            if outcome == "error" and action == "next":
                verbose_proxy_logger.warning(
                    f"Pipeline '{policy_name}' step {i}: guardrail="
                    f"{step.guardrail} errored ({error_detail}); "
                    "continuing to next step because on_fail=next"
                )

            step_result = PipelineStepResult(
                guardrail_name=step.guardrail,
                outcome=outcome,
                action_taken=action,
                modified_data=modified_data,
                error_detail=error_detail,
                duration_seconds=round(duration, 4),
            )
            step_results.append(step_result)

            verbose_proxy_logger.debug(
                f"Pipeline '{policy_name}' step {i}: guardrail={step.guardrail}, "
                f"outcome={outcome}, action={action}"
            )

            # Forward modified data to next step if pass_data is True
            if step.pass_data and modified_data is not None:
                working_data = {**working_data, **modified_data}

            # Handle terminal actions
            if action == "allow":
                return PipelineExecutionResult(
                    terminal_action="allow",
                    step_results=step_results,
                    modified_data=working_data if working_data != data else None,
                )

            if action == "block":
                return PipelineExecutionResult(
                    terminal_action="block",
                    step_results=step_results,
                    error_message=error_detail,
                )

            if action == "modify_response":
                return PipelineExecutionResult(
                    terminal_action="modify_response",
                    step_results=step_results,
                    modify_response_message=step.modify_response_message
                    or error_detail,
                )

            # action == "next" → continue to next step

        # Ran out of steps without a terminal action → default allow
        return PipelineExecutionResult(
            terminal_action="allow",
            step_results=step_results,
            modified_data=working_data if working_data != data else None,
        )

    @staticmethod
    async def _run_step(
        step: PipelineStep,
        mode: str,
        data: dict,
        user_api_key_dict: Any,
        call_type: str,
    ) -> tuple:
        """
        Run a single pipeline step's guardrail.

        Returns:
            Tuple of (outcome, modified_data, error_detail) where:
            - outcome: "pass", "fail", or "error"
            - modified_data: dict if guardrail returned modified data, else None
            - error_detail: error message string if fail/error, else None
        """
        callback = PipelineExecutor._find_guardrail_callback(step.guardrail)
        if callback is None:
            # A missing guardrail name is a misconfiguration ("error"),
            # not a content-based intervention ("fail").
            verbose_proxy_logger.warning(
                f"Pipeline: guardrail '{step.guardrail}' not found in callbacks"
            )
            return ("error", None, f"Guardrail '{step.guardrail}' not found")

        try:
            # Use unified_guardrail path if callback implements apply_guardrail
            target = callback
            use_unified = "apply_guardrail" in type(callback).__dict__
            if use_unified:
                # NOTE(review): this writes into the shared step data dict, so
                # "guardrail_to_apply" remains in `data` after the step runs —
                # confirm downstream consumers expect/tolerate that key.
                data["guardrail_to_apply"] = callback
                target = UnifiedLLMGuardrails()

            if mode == "pre_call":
                response = await target.async_pre_call_hook(
                    user_api_key_dict=user_api_key_dict,
                    cache=None,  # type: ignore
                    data=data,
                    call_type=call_type,  # type: ignore
                )
            elif mode == "post_call":
                response = await target.async_post_call_success_hook(
                    user_api_key_dict=user_api_key_dict,
                    data=data,
                    response=data.get("response"),  # type: ignore
                )
            else:
                return ("error", None, f"Unsupported pipeline mode: {mode}")

            # Normal return means pass
            modified_data = None
            if response is not None and isinstance(response, dict):
                modified_data = response
            return ("pass", modified_data, None)

        except Exception as e:
            # Deliberate guardrail blocks are classified "fail"; anything
            # else is an unexpected runtime "error".
            if CustomGuardrail._is_guardrail_intervention(e):
                error_msg = _extract_error_message(e)
                return ("fail", None, error_msg)
            else:
                verbose_proxy_logger.error(
                    f"Pipeline: unexpected error from guardrail '{step.guardrail}': {e}"
                )
                return ("error", None, str(e))

    @staticmethod
    def _find_guardrail_callback(guardrail_name: str) -> Optional[CustomGuardrail]:
        """Look up an initialized guardrail callback by name from litellm.callbacks."""
        for callback in litellm.callbacks:
            if isinstance(callback, CustomGuardrail):
                if callback.guardrail_name == guardrail_name:
                    return callback
        return None


def _extract_error_message(e: Exception) -> str:
    """Extract a human-readable error message from a guardrail exception.

    Prefers the HTTPException ``detail`` payload (when fastapi is available
    and the detail is non-empty); ``ModifyResponseException`` and everything
    else fall back to ``str(e)``.
    """
    if isinstance(e, ModifyResponseException):
        return str(e)
    if HTTPException is not None and isinstance(e, HTTPException):
        detail = getattr(e, "detail", None)
        return str(detail) if detail else str(e)
    return str(e)
Loading
Loading