diff --git a/litellm/proxy/example_config_yaml/pipeline_test_guardrails.py b/litellm/proxy/example_config_yaml/pipeline_test_guardrails.py
new file mode 100644
index 00000000000..539a520fcef
--- /dev/null
+++ b/litellm/proxy/example_config_yaml/pipeline_test_guardrails.py
@@ -0,0 +1,69 @@
+"""
+Test guardrails for pipeline E2E testing.
+
+- StrictFilter: blocks any message containing "bad" (case-insensitive)
+- PermissiveFilter: always passes (simulates an advanced guardrail that is more lenient)
+"""
+
+from typing import Optional, Union
+
+from fastapi import HTTPException
+
+from litellm._logging import verbose_proxy_logger
+from litellm.caching.caching import DualCache
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.utils import CallTypesLiteral
+
+
+class StrictFilter(CustomGuardrail):
+    """Blocks any message containing the word 'bad'."""
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: CallTypesLiteral,
+    ) -> Optional[Union[Exception, str, dict]]:
+        for msg in data.get("messages", []):
+            content = msg.get("content", "")
+            if isinstance(content, str) and "bad" in content.lower():
+                verbose_proxy_logger.info("StrictFilter: BLOCKED - found 'bad'")
+                raise HTTPException(
+                    status_code=400,
+                    detail="StrictFilter: content contains forbidden word 'bad'",
+                )
+        verbose_proxy_logger.info("StrictFilter: PASSED")
+        return data
+
+
+class PermissiveFilter(CustomGuardrail):
+    """Always passes - simulates a lenient advanced guardrail."""
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: CallTypesLiteral,
+    ) -> Optional[Union[Exception, str, dict]]:
+        verbose_proxy_logger.info("PermissiveFilter: PASSED (always passes)")
+        return data
+
+
+class AlwaysBlockFilter(CustomGuardrail):
+    """Always blocks - for testing full escalation->block path."""
+
+    async def async_pre_call_hook(
+        self,
+        user_api_key_dict: UserAPIKeyAuth,
+        cache: DualCache,
+        data: dict,
+        call_type: CallTypesLiteral,
+    ) -> Optional[Union[Exception, str, dict]]:
+        verbose_proxy_logger.info("AlwaysBlockFilter: BLOCKED")
+        raise HTTPException(
+            status_code=400,
+            detail="AlwaysBlockFilter: all content blocked",
+        )
diff --git a/litellm/proxy/example_config_yaml/test_pipeline_config.yaml b/litellm/proxy/example_config_yaml/test_pipeline_config.yaml
new file mode 100644
index 00000000000..d3a8c56b48a
--- /dev/null
+++ b/litellm/proxy/example_config_yaml/test_pipeline_config.yaml
@@ -0,0 +1,64 @@
+model_list:
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: fake-blocked-endpoint
+    litellm_params:
+      model: openai/gpt-3.5-turbo
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+
+guardrails:
+  - guardrail_name: "strict-filter"
+    litellm_params:
+      guardrail: pipeline_test_guardrails.StrictFilter
+      mode: "pre_call"
+  - guardrail_name: "permissive-filter"
+    litellm_params:
+      guardrail: pipeline_test_guardrails.PermissiveFilter
+      mode: "pre_call"
+  - guardrail_name: "always-block-filter"
+    litellm_params:
+      guardrail: pipeline_test_guardrails.AlwaysBlockFilter
+      mode: "pre_call"
+
+policies:
+  # Pipeline: strict-filter fails -> escalate to permissive-filter
+  # If strict fails but permissive passes -> allow the request
+  content-safety-permissive:
+    description: "Multi-tier: strict filter with permissive fallback"
+    guardrails:
+      add: [strict-filter, permissive-filter]
+    pipeline:
+      mode: "pre_call"
+      steps:
+        - guardrail: strict-filter
+          on_fail: next  # escalate to permissive
+          on_pass: allow  # clean content proceeds
+        - guardrail: permissive-filter
+          on_fail: block  # hard block
+          on_pass: allow  # permissive says OK
+
+  # Pipeline: strict-filter fails -> escalate to always-block
+  # Both fail -> block
+  content-safety-strict:
+    description: "Multi-tier: strict filter with strict fallback (both block)"
+    guardrails:
+      add: [strict-filter, always-block-filter]
+    pipeline:
+      mode: "pre_call"
+      steps:
+        - guardrail: strict-filter
+          on_fail: next
+          on_pass: allow
+        - guardrail: always-block-filter
+          on_fail: block
+          on_pass: allow
+
+policy_attachments:
+  - policy: content-safety-permissive
+    models: [fake-openai-endpoint]
+  - policy: content-safety-strict
+    models: [fake-blocked-endpoint]
diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py
index 49d31c1efec..fa024cc33d4 100644
--- a/litellm/proxy/litellm_pre_call_utils.py
+++ b/litellm/proxy/litellm_pre_call_utils.py
@@ -1642,20 +1642,40 @@ def add_guardrails_from_policy_engine(
         f"Policy engine: resolved guardrails: {resolved_guardrails}"
     )
 
-    if not resolved_guardrails:
-        return
+    # Resolve pipelines from matching policies
+    pipelines = PolicyResolver.resolve_pipelines_for_context(context=context)
 
     # Add resolved guardrails to request metadata
     if metadata_variable_name not in data:
         data[metadata_variable_name] = {}
 
+    # Track pipeline-managed guardrails to exclude from independent execution
+    pipeline_managed_guardrails: set = set()
+    if pipelines:
+        pipeline_managed_guardrails = PolicyResolver.get_pipeline_managed_guardrails(
+            pipelines
+        )
+        data[metadata_variable_name]["_guardrail_pipelines"] = pipelines
+        data[metadata_variable_name]["_pipeline_managed_guardrails"] = (
+            pipeline_managed_guardrails
+        )
+        verbose_proxy_logger.debug(
+            f"Policy engine: resolved {len(pipelines)} pipeline(s), "
+            f"managed guardrails: {pipeline_managed_guardrails}"
+        )
+
+    if not resolved_guardrails and not pipelines:
+        return
+
     existing_guardrails = data[metadata_variable_name].get("guardrails", [])
     if not isinstance(existing_guardrails, list):
         existing_guardrails = []
 
     # Combine existing guardrails with policy-resolved guardrails (no duplicates)
+    # Exclude pipeline-managed guardrails from the flat list
     combined = set(existing_guardrails)
     combined.update(resolved_guardrails)
+    combined -= pipeline_managed_guardrails
     data[metadata_variable_name]["guardrails"] = list(combined)
 
     verbose_proxy_logger.debug(
diff --git a/litellm/proxy/policy_engine/pipeline_executor.py b/litellm/proxy/policy_engine/pipeline_executor.py
new file mode 100644
index 00000000000..c015d755a9a
--- /dev/null
+++ b/litellm/proxy/policy_engine/pipeline_executor.py
@@ -0,0 +1,209 @@
+"""
+Pipeline Executor - Executes guardrail pipelines with conditional step logic.
+
+Runs guardrails sequentially per pipeline step definitions, handling
+pass/fail actions (allow, block, next, modify_response) and data forwarding.
+"""
+
+import copy
+import time
+from typing import Any, List, Optional
+
+import litellm
+from litellm._logging import verbose_proxy_logger
+from litellm.integrations.custom_guardrail import (
+    CustomGuardrail,
+    ModifyResponseException,
+)
+from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
+    UnifiedLLMGuardrails,
+)
+from litellm.types.proxy.policy_engine.pipeline_types import (
+    PipelineExecutionResult,
+    PipelineStep,
+    PipelineStepResult,
+)
+
+try:
+    from fastapi.exceptions import HTTPException
+except ImportError:
+    HTTPException = None  # type: ignore
+
+
+class PipelineExecutor:
+    """Executes guardrail pipelines with ordered, conditional step logic."""
+
+    @staticmethod
+    async def execute_steps(
+        steps: List[PipelineStep],
+        mode: str,
+        data: dict,
+        user_api_key_dict: Any,
+        call_type: str,
+        policy_name: str,
+    ) -> PipelineExecutionResult:
+        """
+        Execute pipeline steps sequentially with conditional actions.
+
+        Args:
+            steps: Ordered list of pipeline steps
+            mode: Event hook mode (pre_call, post_call)
+            data: Request data dict
+            user_api_key_dict: User API key auth
+            call_type: Type of call (completion, etc.)
+            policy_name: Name of the owning policy (for logging)
+
+        Returns:
+            PipelineExecutionResult with terminal action and step results
+        """
+        step_results: List[PipelineStepResult] = []
+        working_data = copy.deepcopy(data)
+
+        for i, step in enumerate(steps):
+            start_time = time.perf_counter()
+
+            outcome, modified_data, error_detail = await PipelineExecutor._run_step(
+                step=step,
+                mode=mode,
+                data=working_data,
+                user_api_key_dict=user_api_key_dict,
+                call_type=call_type,
+            )
+
+            duration = time.perf_counter() - start_time
+
+            action = step.on_pass if outcome == "pass" else step.on_fail
+
+            step_result = PipelineStepResult(
+                guardrail_name=step.guardrail,
+                outcome=outcome,
+                action_taken=action,
+                modified_data=modified_data,
+                error_detail=error_detail,
+                duration_seconds=round(duration, 4),
+            )
+            step_results.append(step_result)
+
+            verbose_proxy_logger.debug(
+                f"Pipeline '{policy_name}' step {i}: guardrail={step.guardrail}, "
+                f"outcome={outcome}, action={action}"
+            )
+
+            # Forward modified data to next step if pass_data is True
+            if step.pass_data and modified_data is not None:
+                working_data = {**working_data, **modified_data}
+
+            # Handle terminal actions
+            if action == "allow":
+                return PipelineExecutionResult(
+                    terminal_action="allow",
+                    step_results=step_results,
+                    modified_data=working_data if working_data != data else None,
+                )
+
+            if action == "block":
+                return PipelineExecutionResult(
+                    terminal_action="block",
+                    step_results=step_results,
+                    error_message=error_detail,
+                )
+
+            if action == "modify_response":
+                return PipelineExecutionResult(
+                    terminal_action="modify_response",
+                    step_results=step_results,
+                    modify_response_message=step.modify_response_message or error_detail,
+                )
+
+            # action == "next" → continue to next step
+
+        # Ran out of steps without a terminal action → default allow
+        return PipelineExecutionResult(
+            terminal_action="allow",
+            step_results=step_results,
+            modified_data=working_data if working_data != data else None,
+        )
+
+    @staticmethod
+    async def _run_step(
+        step: PipelineStep,
+        mode: str,
+        data: dict,
+        user_api_key_dict: Any,
+        call_type: str,
+    ) -> tuple:
+        """
+        Run a single pipeline step's guardrail.
+
+        Returns:
+            Tuple of (outcome, modified_data, error_detail) where:
+            - outcome: "pass", "fail", or "error"
+            - modified_data: dict if guardrail returned modified data, else None
+            - error_detail: error message string if fail/error, else None
+        """
+        callback = PipelineExecutor._find_guardrail_callback(step.guardrail)
+        if callback is None:
+            verbose_proxy_logger.warning(
+                f"Pipeline: guardrail '{step.guardrail}' not found in callbacks"
+            )
+            return ("error", None, f"Guardrail '{step.guardrail}' not found")
+
+        try:
+            # Use unified_guardrail path if callback implements apply_guardrail
+            target = callback
+            use_unified = "apply_guardrail" in type(callback).__dict__
+            if use_unified:
+                data["guardrail_to_apply"] = callback
+                target = UnifiedLLMGuardrails()
+
+            if mode == "pre_call":
+                response = await target.async_pre_call_hook(
+                    user_api_key_dict=user_api_key_dict,
+                    cache=None,  # type: ignore
+                    data=data,
+                    call_type=call_type,  # type: ignore
+                )
+            elif mode == "post_call":
+                response = await target.async_post_call_success_hook(
+                    user_api_key_dict=user_api_key_dict,
+                    data=data,
+                    response=data.get("response"),  # type: ignore
+                )
+            else:
+                return ("error", None, f"Unsupported pipeline mode: {mode}")
+
+            # Normal return means pass
+            modified_data = None
+            if response is not None and isinstance(response, dict):
+                modified_data = response
+            return ("pass", modified_data, None)
+
+        except Exception as e:
+            if CustomGuardrail._is_guardrail_intervention(e):
+                error_msg = _extract_error_message(e)
+                return ("fail", None, error_msg)
+            else:
+                verbose_proxy_logger.error(
+                    f"Pipeline: unexpected error from guardrail '{step.guardrail}': {e}"
+                )
+                return ("error", None, str(e))
+
+    @staticmethod
+    def _find_guardrail_callback(guardrail_name: str) -> Optional[CustomGuardrail]:
+        """Look up an initialized guardrail callback by name from litellm.callbacks."""
+        for callback in litellm.callbacks:
+            if isinstance(callback, CustomGuardrail):
+                if callback.guardrail_name == guardrail_name:
+                    return callback
+        return None
+
+
+def _extract_error_message(e: Exception) -> str:
+    """Extract a human-readable error message from a guardrail exception."""
+    if isinstance(e, ModifyResponseException):
+        return str(e)
+    if HTTPException is not None and isinstance(e, HTTPException):
+        detail = getattr(e, "detail", None)
+        if detail:
+            return str(detail)
+    return str(e)
diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py
index a2431977b24..377b4cd86dd 100644
--- a/litellm/proxy/policy_engine/policy_registry.py
+++ b/litellm/proxy/policy_engine/policy_registry.py
@@ -12,6 +12,8 @@
 
 from litellm._logging import verbose_proxy_logger
 from litellm.types.proxy.policy_engine import (
+    GuardrailPipeline,
+    PipelineStep,
     Policy,
     PolicyCondition,
     PolicyCreateRequest,
@@ -93,11 +95,32 @@ def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy
         if condition_data:
             condition = PolicyCondition(model=condition_data.get("model"))
 
+        # Parse pipeline (optional ordered guardrail execution)
+        pipeline = PolicyRegistry._parse_pipeline(policy_data.get("pipeline"))
+
         return Policy(
             inherit=policy_data.get("inherit"),
             description=policy_data.get("description"),
             guardrails=guardrails,
             condition=condition,
+            pipeline=pipeline,
+        )
+
+    @staticmethod
+    def _parse_pipeline(pipeline_data: Optional[Dict[str, Any]]) -> Optional[GuardrailPipeline]:
+        """Parse a pipeline configuration from raw data."""
+        if pipeline_data is None:
+            return None
+
+        steps_data = pipeline_data.get("steps", [])
+        steps = [
+            PipelineStep(**step_data) if isinstance(step_data, dict) else step_data
+            for step_data in steps_data
+        ]
+
+        return GuardrailPipeline(
+            mode=pipeline_data.get("mode", "pre_call"),
+            steps=steps,
         )
 
     def get_policy(self, policy_name: str) -> Optional[Policy]:
diff --git a/litellm/proxy/policy_engine/policy_resolver.py b/litellm/proxy/policy_engine/policy_resolver.py
index cfdedc467d8..a8ad78d6491 100644
--- a/litellm/proxy/policy_engine/policy_resolver.py
+++ b/litellm/proxy/policy_engine/policy_resolver.py
@@ -8,10 +8,11 @@
 - Combining guardrails from multiple matching policies
 """
 
-from typing import Dict, List, Optional, Set
+from typing import Dict, List, Optional, Set, Tuple
 
 from litellm._logging import verbose_proxy_logger
 from litellm.types.proxy.policy_engine import (
+    GuardrailPipeline,
     Policy,
     PolicyMatchContext,
     ResolvedPolicy,
@@ -190,6 +191,67 @@ def resolve_guardrails_for_context(
 
         return result
 
+    @staticmethod
+    def resolve_pipelines_for_context(
+        context: PolicyMatchContext,
+        policies: Optional[Dict[str, Policy]] = None,
+    ) -> List[Tuple[str, GuardrailPipeline]]:
+        """
+        Resolve pipelines from matching policies for a request context.
+
+        Returns (policy_name, pipeline) tuples for policies that have pipelines.
+        Guardrails managed by pipelines should be excluded from the flat
+        guardrails list to avoid double execution.
+
+        Args:
+            context: The request context
+            policies: Dictionary of all policies (if None, uses global registry)
+
+        Returns:
+            List of (policy_name, GuardrailPipeline) tuples
+        """
+        from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
+        from litellm.proxy.policy_engine.policy_registry import get_policy_registry
+
+        if policies is None:
+            registry = get_policy_registry()
+            if not registry.is_initialized():
+                return []
+            policies = registry.get_all_policies()
+
+        matching_policy_names = PolicyMatcher.get_matching_policies(context=context)
+        if not matching_policy_names:
+            return []
+
+        pipelines: List[Tuple[str, GuardrailPipeline]] = []
+        for policy_name in matching_policy_names:
+            policy = policies.get(policy_name)
+            if policy is None:
+                continue
+            if policy.pipeline is not None:
+                pipelines.append((policy_name, policy.pipeline))
+                verbose_proxy_logger.debug(
+                    f"Policy '{policy_name}' has pipeline with "
+                    f"{len(policy.pipeline.steps)} steps"
+                )
+
+        return pipelines
+
+    @staticmethod
+    def get_pipeline_managed_guardrails(
+        pipelines: List[Tuple[str, GuardrailPipeline]],
+    ) -> Set[str]:
+        """
+        Get the set of guardrail names managed by pipelines.
+
+        These guardrails should be excluded from normal independent execution.
+        """
+        managed: Set[str] = set()
+        for _policy_name, pipeline in pipelines:
+            for step in pipeline.steps:
+                managed.add(step.guardrail)
+        return managed
+
     @staticmethod
     def get_all_resolved_policies(
         policies: Optional[Dict[str, Policy]] = None,
diff --git a/litellm/proxy/policy_engine/policy_validator.py b/litellm/proxy/policy_engine/policy_validator.py
index 3eaa67a54d3..89c9b0e2e99 100644
--- a/litellm/proxy/policy_engine/policy_validator.py
+++ b/litellm/proxy/policy_engine/policy_validator.py
@@ -283,8 +283,14 @@ async def validate_policies(
                 )
             )
 
-        # Note: Team, key, and model validation is done via policy_attachments
-        # Policies no longer have scope - attachments define where policies apply
+        # Validate pipeline if present
+        if policy.pipeline is not None:
+            pipeline_errors = PolicyValidator._validate_pipeline(
+                policy_name=policy_name,
+                policy=policy,
+                available_guardrails=available_guardrails,
+            )
+            errors.extend(pipeline_errors)
 
         # Validate inheritance
         inheritance_errors = self._validate_inheritance_chain(
@@ -298,6 +304,53 @@ async def validate_policies(
             warnings=warnings,
         )
 
+    @staticmethod
+    def _validate_pipeline(
+        policy_name: str,
+        policy: Policy,
+        available_guardrails: Set[str],
+    ) -> List[PolicyValidationError]:
+        """Validate a policy's pipeline configuration."""
+        errors: List[PolicyValidationError] = []
+        pipeline = policy.pipeline
+        if pipeline is None:
+            return errors
+
+        guardrails_add = set(policy.guardrails.get_add())
+
+        for i, step in enumerate(pipeline.steps):
+            # Check guardrail is in policy's guardrails.add
+            if step.guardrail not in guardrails_add:
+                errors.append(
+                    PolicyValidationError(
+                        policy_name=policy_name,
+                        error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
+                        message=(
+                            f"Pipeline step {i} guardrail '{step.guardrail}' "
+                            f"is not in the policy's guardrails.add list"
+                        ),
+                        field="pipeline.steps",
+                        value=step.guardrail,
+                    )
+                )
+
+            # Check guardrail exists in registry
+            if available_guardrails and step.guardrail not in available_guardrails:
+                errors.append(
+                    PolicyValidationError(
+                        policy_name=policy_name,
+                        error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
+                        message=(
+                            f"Pipeline step {i} guardrail '{step.guardrail}' "
+                            f"not found in guardrail registry"
+                        ),
+                        field="pipeline.steps",
+                        value=step.guardrail,
+                    )
+                )
+
+        return errors
+
     async def validate_policy_config(
         self,
         policy_config: Dict[str, Any],
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index d977751004c..a441b2ae7d1 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -77,7 +77,10 @@
 from litellm._service_logger import ServiceLogging, ServiceTypes
 from litellm.caching.caching import DualCache, RedisCache
 from litellm.exceptions import RejectedRequestError
-from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.integrations.custom_guardrail import (
+    CustomGuardrail,
+    ModifyResponseException,
+)
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
 from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
@@ -110,6 +113,7 @@
     _PROXY_MaxParallelRequestsHandler,
 )
 from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup
+from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor
 from litellm.secret_managers.main import str_to_bool
 from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
 from litellm.types.mcp import (
@@ -117,6 +121,7 @@
     MCPPreCallRequestObject,
     MCPPreCallResponseObject,
 )
+from litellm.types.proxy.policy_engine.pipeline_types import PipelineExecutionResult
 from litellm.types.utils import LLMResponseTypes, LoggedLiteLLMParams
 
 if TYPE_CHECKING:
@@ -1141,6 +1146,101 @@ def _process_guardrail_metadata(self, data: dict) -> None:
             request_data=data, guardrail_name=guardrail_name
         )
 
+    async def _maybe_execute_pipelines(
+        self,
+        data: dict,
+        user_api_key_dict: UserAPIKeyAuth,
+        call_type: str,
+        event_hook: str,
+    ) -> dict:
+        """
+        Execute guardrail pipelines if any are configured for this request.
+
+        Checks metadata for pipelines resolved by the policy engine
+        and executes them. Handles the result (allow/block/modify_response).
+
+        Returns the (possibly modified) data dict.
+        """
+        metadata = data.get("metadata", data.get("litellm_metadata", {})) or {}
+        pipelines = metadata.get("_guardrail_pipelines")
+        if not pipelines:
+            return data
+
+        for policy_name, pipeline in pipelines:
+            if pipeline.mode != event_hook:
+                continue
+
+            result: PipelineExecutionResult = await PipelineExecutor.execute_steps(
+                steps=pipeline.steps,
+                mode=pipeline.mode,
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type=call_type,
+                policy_name=policy_name,
+            )
+
+            data = self._handle_pipeline_result(
+                result=result,
+                data=data,
+                policy_name=policy_name,
+            )
+
+        return data
+
+    @staticmethod
+    def _handle_pipeline_result(
+        result: Any,
+        data: dict,
+        policy_name: str,
+    ) -> dict:
+        """
+        Handle a PipelineExecutionResult — allow, block, or modify_response.
+
+        Returns data dict if allowed, raises on block/modify_response.
+        """
+        if result.terminal_action == "allow":
+            if result.modified_data is not None:
+                data.update(result.modified_data)
+            return data
+
+        if result.terminal_action == "block":
+            step_results_serializable = [
+                {
+                    "guardrail": sr.guardrail_name,
+                    "outcome": sr.outcome,
+                    "action": sr.action_taken,
+                }
+                for sr in result.step_results
+            ]
+            error_detail = {
+                "error": {
+                    "message": f"Content blocked by guardrail pipeline '{policy_name}'",
+                    "type": "guardrail_pipeline_error",
+                    "pipeline_context": {
+                        "policy": policy_name,
+                        "step_results": step_results_serializable,
+                    },
+                }
+            }
+            if HTTPException is not None:
+                raise HTTPException(status_code=400, detail=error_detail)
+            else:
+                raise Exception(str(error_detail))
+
+        if result.terminal_action == "modify_response":
+            raise ModifyResponseException(
+                message=result.modify_response_message or "Response modified by pipeline",
+                model=data.get("model", "unknown"),
+                request_data=data,
+                guardrail_name=f"pipeline:{policy_name}",
+                detection_info=None,
+            )
+
+        verbose_proxy_logger.warning(
+            f"Pipeline '{policy_name}': unrecognized terminal_action '{result.terminal_action}', defaulting to allow"
+        )
+        return data
+
     # The actual implementation of the function
     @overload
     async def pre_call_hook(
@@ -1203,6 +1303,18 @@ async def pre_call_hook(
         )
 
         try:
+            # Execute guardrail pipelines before the normal callback loop
+            data = await self._maybe_execute_pipelines(
+                data=data,
+                user_api_key_dict=user_api_key_dict,
+                call_type=call_type,
+                event_hook="pre_call",
+            )
+
+            # Get pipeline-managed guardrails to skip in normal loop
+            metadata = data.get("metadata", data.get("litellm_metadata", {})) or {}
+            pipeline_managed: set = metadata.get("_pipeline_managed_guardrails", set())
+
             for callback in litellm.callbacks:
                 start_time = time.time()
                 _callback = None
@@ -1217,6 +1329,10 @@ async def pre_call_hook(
                     and isinstance(_callback, CustomGuardrail)
                     and data is not None
                 ):
+                    # Skip guardrails managed by a pipeline
+                    if _callback.guardrail_name and _callback.guardrail_name in pipeline_managed:
+                        continue
+
                     result = await self._process_guardrail_callback(
                         callback=_callback,
                         data=data,  # type: ignore
diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py
index 42490c2eddc..6f1a8d27d34 100644
--- a/litellm/types/proxy/policy_engine/__init__.py
+++ b/litellm/types/proxy/policy_engine/__init__.py
@@ -10,6 +10,12 @@
 - `policy_attachments`: Define WHERE policies apply (teams, keys, models)
 """
 
+from litellm.types.proxy.policy_engine.pipeline_types import (
+    GuardrailPipeline,
+    PipelineExecutionResult,
+    PipelineStep,
+    PipelineStepResult,
+)
 from litellm.types.proxy.policy_engine.policy_types import (
     Policy,
     PolicyAttachment,
@@ -48,6 +54,11 @@
 )
 
 __all__ = [
+    # Pipeline types
+    "GuardrailPipeline",
+    "PipelineStep",
+    "PipelineStepResult",
+    "PipelineExecutionResult",
     # Policy types
     "Policy",
     "PolicyConfig",
diff --git a/litellm/types/proxy/policy_engine/pipeline_types.py b/litellm/types/proxy/policy_engine/pipeline_types.py
new file mode 100644
index 00000000000..29d2e576000
--- /dev/null
+++ b/litellm/types/proxy/policy_engine/pipeline_types.py
@@ -0,0 +1,98 @@
+"""
+Pipeline type definitions for guardrail pipelines.
+
+Pipelines define ordered, conditional execution of guardrails within a policy.
+When a policy has a `pipeline`, its guardrails run in the defined step order
+with configurable actions on pass/fail, rather than independently.
+"""
+
+from typing import Any, Dict, List, Literal, Optional
+
+from pydantic import BaseModel, ConfigDict, Field, field_validator
+
+VALID_PIPELINE_ACTIONS = {"allow", "block", "next", "modify_response"}
+VALID_PIPELINE_MODES = {"pre_call", "post_call"}
+
+
+class PipelineStep(BaseModel):
+    """
+    A single step in a guardrail pipeline.
+
+    Each step runs a guardrail and takes an action based on pass/fail.
+    """
+
+    guardrail: str = Field(description="Name of the guardrail to run.")
+    on_fail: str = Field(
+        default="block",
+        description="Action when guardrail rejects: next | block | allow | modify_response",
+    )
+    on_pass: str = Field(
+        default="allow",
+        description="Action when guardrail passes: next | block | allow | modify_response",
+    )
+    pass_data: bool = Field(
+        default=False,
+        description="Forward modified request data (e.g., PII-masked) to next step.",
+    )
+    modify_response_message: Optional[str] = Field(
+        default=None,
+        description="Custom message for modify_response action.",
+    )
+
+    model_config = ConfigDict(extra="forbid")
+
+    @field_validator("on_fail", "on_pass")
+    @classmethod
+    def validate_action(cls, v: str) -> str:
+        if v not in VALID_PIPELINE_ACTIONS:
+            raise ValueError(
+                f"Invalid action '{v}'. Must be one of: {sorted(VALID_PIPELINE_ACTIONS)}"
+            )
+        return v
+
+
+class GuardrailPipeline(BaseModel):
+    """
+    Defines ordered execution of guardrails with conditional actions.
+
+    When present on a policy, the guardrails in `steps` are executed
+    sequentially instead of independently.
+    """
+
+    mode: str = Field(description="Event hook: pre_call | post_call")
+    steps: List[PipelineStep] = Field(
+        description="Ordered list of pipeline steps. Must have at least 1 step.",
+        min_length=1,
+    )
+
+    model_config = ConfigDict(extra="forbid")
+
+    @field_validator("mode")
+    @classmethod
+    def validate_mode(cls, v: str) -> str:
+        if v not in VALID_PIPELINE_MODES:
+            raise ValueError(
+                f"Invalid mode '{v}'. Must be one of: {sorted(VALID_PIPELINE_MODES)}"
+            )
+        return v
+
+
+class PipelineStepResult(BaseModel):
+    """Result of executing a single pipeline step."""
+
+    guardrail_name: str
+    outcome: Literal["pass", "fail", "error"]
+    action_taken: str
+    modified_data: Optional[Dict[str, Any]] = None
+    error_detail: Optional[str] = None
+    duration_seconds: Optional[float] = None
+
+
+class PipelineExecutionResult(BaseModel):
+    """Result of executing an entire pipeline."""
+
+    terminal_action: str  # block | allow | modify_response
+    step_results: List[PipelineStepResult]
+    modified_data: Optional[Dict[str, Any]] = None
+    error_message: Optional[str] = None
+    modify_response_message: Optional[str] = None
diff --git a/litellm/types/proxy/policy_engine/policy_types.py b/litellm/types/proxy/policy_engine/policy_types.py
index f221ba7e038..53a74ca6fd8 100644
--- a/litellm/types/proxy/policy_engine/policy_types.py
+++ b/litellm/types/proxy/policy_engine/policy_types.py
@@ -29,10 +29,12 @@
 - `condition`: Optional model condition for when guardrails apply
 """
 
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, List, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field
 
+from litellm.types.proxy.policy_engine.pipeline_types import GuardrailPipeline
+
 # ─────────────────────────────────────────────────────────────────────────────
 # Policy Condition
 # ─────────────────────────────────────────────────────────────────────────────
@@ -231,6 +233,10 @@ class Policy(BaseModel):
         default=None,
         description="Optional condition for when this policy's guardrails apply.",
     )
+    pipeline: Optional[GuardrailPipeline] = Field(
+        default=None,
+        description="Optional pipeline for ordered, conditional guardrail execution.",
+    )
 
     model_config = ConfigDict(extra="forbid")
diff --git a/tests/test_litellm/proxy/policy_engine/test_pipeline_executor.py b/tests/test_litellm/proxy/policy_engine/test_pipeline_executor.py
new file mode 100644
index 00000000000..226e88bea3e
--- /dev/null
+++ b/tests/test_litellm/proxy/policy_engine/test_pipeline_executor.py
@@ -0,0 +1,484 @@
+"""
+Tests for the pipeline executor.
+
+Uses mock guardrails to validate pipeline execution without external services.
+"""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+import litellm
+from litellm.integrations.custom_guardrail import CustomGuardrail
+from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor
+from litellm.types.proxy.policy_engine.pipeline_types import (
+    GuardrailPipeline,
+    PipelineStep,
+)
+
+try:
+    from fastapi.exceptions import HTTPException
+except ImportError:
+    HTTPException = None
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Mock Guardrails
+# ─────────────────────────────────────────────────────────────────────────────
+
+
+class AlwaysFailGuardrail(CustomGuardrail):
+    """Mock guardrail that always raises HTTPException(400)."""
+
+    def __init__(self, guardrail_name: str):
+        super().__init__(
+            guardrail_name=guardrail_name,
+            event_hook="pre_call",
+            default_on=True,
+        )
+        self.calls = 0
+
+    def should_run_guardrail(self, data, event_type) -> bool:
+        return True
+
+    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
+        self.calls += 1
+        raise HTTPException(status_code=400, detail="Content policy violation")
+
+
+class AlwaysPassGuardrail(CustomGuardrail):
+    """Mock guardrail that always passes."""
+
+    def __init__(self, guardrail_name: str):
+        super().__init__(
+            guardrail_name=guardrail_name,
+            event_hook="pre_call",
+            default_on=True,
+        )
+        self.calls = 0
+
+    def should_run_guardrail(self, data, event_type) -> bool:
+        return True
+
+    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
+        self.calls += 1
+        return None
+
+
+class PiiMaskingGuardrail(CustomGuardrail):
+    """Mock guardrail that masks PII in messages and returns modified data."""
+
+    def __init__(self, guardrail_name: str):
+        super().__init__(
+            guardrail_name=guardrail_name,
+            event_hook="pre_call",
+            default_on=True,
+        )
+        self.calls = 0
+        self.received_messages = None
+
+    def should_run_guardrail(self, data, event_type) -> bool:
+        return True
+
+    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
+        self.calls += 1
+        self.received_messages = data.get("messages", [])
+        masked_messages = []
+        for msg in data.get("messages", []):
+            masked_msg = dict(msg)
+            masked_msg["content"] = msg["content"].replace(
+                "John Smith", "[REDACTED]"
+            )
+            masked_messages.append(masked_msg)
+        return {"messages": masked_messages}
+
+
+class ContentCheckGuardrail(CustomGuardrail):
+    """Mock guardrail that records what messages it received."""
+
+    def __init__(self, guardrail_name: str):
+        super().__init__(
+            guardrail_name=guardrail_name,
+            event_hook="pre_call",
+            default_on=True,
+        )
+        self.calls = 0
+        self.received_messages = None
+
+    def should_run_guardrail(self, data, event_type) -> bool:
+        return True
+
+    async def async_pre_call_hook(self, user_api_key_dict, cache, data, call_type):
+        self.calls += 1
+        self.received_messages = data.get("messages", [])
+        return None
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# Tests
+# ─────────────────────────────────────────────────────────────────────────────
+
+
+@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed")
+@pytest.mark.asyncio
+async def test_escalation_step1_fails_step2_blocks():
+    """
+    Pipeline: simple-filter (on_fail: next) -> advanced-filter (on_fail: block)
+    Input: request that fails simple-filter
+    Expected: simple-filter fails -> escalate -> advanced-filter fails -> block
+    """
+    simple_guard = AlwaysFailGuardrail(guardrail_name="simple-filter")
+    advanced_guard = AlwaysFailGuardrail(guardrail_name="advanced-filter")
+
+    pipeline = GuardrailPipeline(
+        mode="pre_call",
+        steps=[
+            PipelineStep(
+                guardrail="simple-filter", on_fail="next", on_pass="allow"
+            ),
+            PipelineStep(
+                guardrail="advanced-filter", on_fail="block", on_pass="allow"
+            ),
+        ],
+    )
+
+    original_callbacks = litellm.callbacks.copy()
+    litellm.callbacks = [simple_guard, advanced_guard]
+
+    try:
+        result = await PipelineExecutor.execute_steps(
+            steps=pipeline.steps,
+            mode=pipeline.mode,
+            data={"messages": [{"role": "user", "content": "bad content"}]},
+            user_api_key_dict=MagicMock(),
+            call_type="completion",
+            policy_name="content-safety",
+        )
+
+        assert simple_guard.calls == 1
+        assert advanced_guard.calls == 1
+        assert result.terminal_action == "block"
+        assert len(result.step_results) == 2
+        assert result.step_results[0].guardrail_name == "simple-filter"
+        assert result.step_results[0].outcome == "fail"
+        assert result.step_results[0].action_taken == "next"
+        assert result.step_results[1].guardrail_name == "advanced-filter"
+        assert result.step_results[1].outcome == "fail"
+        assert result.step_results[1].action_taken == "block"
+    finally:
+        litellm.callbacks = original_callbacks
+
+
+@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed")
+@pytest.mark.asyncio
+async def test_early_allow_step1_passes_step2_skipped():
+    """
+    Pipeline: simple-filter (on_pass: allow) -> advanced-filter
+    Input: clean request that passes simple-filter
+    Expected: simple-filter passes -> allow (advanced-filter never called)
+    """
+    simple_guard = AlwaysPassGuardrail(guardrail_name="simple-filter")
+    advanced_guard = AlwaysFailGuardrail(guardrail_name="advanced-filter")
+
+    pipeline = GuardrailPipeline(
+        mode="pre_call",
+        steps=[
+            PipelineStep(
+                guardrail="simple-filter", on_fail="next", on_pass="allow"
+            ),
+            PipelineStep(
+                guardrail="advanced-filter", on_fail="block", on_pass="allow"
+            ),
+        ],
+    )
+
+    original_callbacks = litellm.callbacks.copy()
+    litellm.callbacks = [simple_guard, advanced_guard]
+
+    try:
+        result = await PipelineExecutor.execute_steps(
+            steps=pipeline.steps,
+            mode=pipeline.mode,
+            data={"messages": [{"role": "user", "content": "clean content"}]},
+            user_api_key_dict=MagicMock(),
+            call_type="completion",
+            policy_name="content-safety",
+        )
+
+        assert simple_guard.calls == 1
+        assert advanced_guard.calls == 0
+        assert result.terminal_action == "allow"
+        assert len(result.step_results) == 1
+        assert result.step_results[0].outcome == "pass"
+        assert result.step_results[0].action_taken == "allow"
+    finally:
+        litellm.callbacks = original_callbacks
+
+
+@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed")
+@pytest.mark.asyncio
+async def test_escalation_step1_fails_step2_passes():
+    """
+    Pipeline: simple-filter (on_fail: next) -> advanced-filter (on_pass: allow)
+    Input: request that fails simple but passes advanced
+    Expected: simple-filter fails -> escalate -> advanced-filter passes -> allow
+    """
+    simple_guard = AlwaysFailGuardrail(guardrail_name="simple-filter")
+    advanced_guard = AlwaysPassGuardrail(guardrail_name="advanced-filter")
+
+    pipeline = GuardrailPipeline(
+        mode="pre_call",
+        steps=[
+            PipelineStep(
+                guardrail="simple-filter", on_fail="next", on_pass="allow"
+            ),
+            PipelineStep(
+                guardrail="advanced-filter", on_fail="block", on_pass="allow"
+            ),
+        ],
+    )
+
+    original_callbacks = litellm.callbacks.copy()
+    litellm.callbacks = [simple_guard, advanced_guard]
+
+    try:
+        result = await PipelineExecutor.execute_steps(
+            steps=pipeline.steps,
+            mode=pipeline.mode,
+            data={"messages": [{"role": "user", "content": "borderline content"}]},
+            user_api_key_dict=MagicMock(),
+            call_type="completion",
+            policy_name="content-safety",
+        )
+
+        assert simple_guard.calls == 1
+        assert advanced_guard.calls == 1
+        assert result.terminal_action == "allow"
+        assert len(result.step_results) == 2
+        assert result.step_results[0].outcome == "fail"
+        assert result.step_results[0].action_taken == "next"
+        assert result.step_results[1].outcome == "pass"
+        assert result.step_results[1].action_taken == "allow"
+    finally:
+        litellm.callbacks = original_callbacks
+
+
+@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def test_data_forwarding_pii_masking(): + """ + Pipeline: pii-masker (pass_data: true, on_pass: next) -> content-check (on_pass: allow) + Input: "Hello John Smith" + Expected: pii-masker masks -> content-check receives "[REDACTED]" -> allow + """ + pii_guard = PiiMaskingGuardrail(guardrail_name="pii-masker") + content_guard = ContentCheckGuardrail(guardrail_name="content-check") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="pii-masker", + on_fail="block", + on_pass="next", + pass_data=True, + ), + PipelineStep( + guardrail="content-check", on_fail="block", on_pass="allow" + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [pii_guard, content_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={ + "messages": [{"role": "user", "content": "Hello John Smith"}] + }, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="pii-then-safety", + ) + + assert pii_guard.calls == 1 + assert content_guard.calls == 1 + assert content_guard.received_messages[0]["content"] == "Hello [REDACTED]" + assert result.terminal_action == "allow" + assert result.modified_data is not None + assert result.modified_data["messages"][0]["content"] == "Hello [REDACTED]" + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_guardrail_not_found_uses_on_fail(): + """ + If a guardrail is not found, treat as error and use on_fail action. 
+ """ + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="nonexistent-guard", + on_fail="block", + on_pass="allow", + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test-policy", + ) + + assert result.terminal_action == "block" + assert result.step_results[0].outcome == "error" + assert "not found" in result.step_results[0].error_detail + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_guardrail_not_found_with_next_continues(): + """ + If a guardrail is not found and on_fail is 'next', continue to next step. + """ + pass_guard = AlwaysPassGuardrail(guardrail_name="fallback-guard") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep( + guardrail="nonexistent-guard", + on_fail="next", + on_pass="allow", + ), + PipelineStep( + guardrail="fallback-guard", + on_fail="block", + on_pass="allow", + ), + ], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [pass_guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test-policy", + ) + + assert result.terminal_action == "allow" + assert len(result.step_results) == 2 + assert result.step_results[0].outcome == "error" + assert result.step_results[0].action_taken == "next" + assert result.step_results[1].outcome == "pass" + assert pass_guard.calls == 1 + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.skipif(HTTPException is None, reason="fastapi not installed") +@pytest.mark.asyncio +async def 
test_single_step_pipeline_block(): + """Single step pipeline that blocks.""" + guard = AlwaysFailGuardrail(guardrail_name="blocker") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="blocker", on_fail="block")], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test", + ) + + assert result.terminal_action == "block" + assert guard.calls == 1 + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_single_step_pipeline_allow(): + """Single step pipeline that allows.""" + guard = AlwaysPassGuardrail(guardrail_name="passer") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="passer", on_pass="allow")], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test", + ) + + assert result.terminal_action == "allow" + assert guard.calls == 1 + finally: + litellm.callbacks = original_callbacks + + +@pytest.mark.asyncio +async def test_step_results_include_duration(): + """Step results should include timing information.""" + guard = AlwaysPassGuardrail(guardrail_name="timed") + + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="timed")], + ) + + original_callbacks = litellm.callbacks.copy() + litellm.callbacks = [guard] + + try: + result = await PipelineExecutor.execute_steps( + steps=pipeline.steps, + mode=pipeline.mode, + data={"messages": [{"role": "user", "content": "test"}]}, + 
user_api_key_dict=MagicMock(), + call_type="completion", + policy_name="test", + ) + + assert result.step_results[0].duration_seconds is not None + assert result.step_results[0].duration_seconds >= 0 + finally: + litellm.callbacks = original_callbacks diff --git a/tests/test_litellm/types/__init__.py b/tests/test_litellm/types/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/types/proxy/__init__.py b/tests/test_litellm/types/proxy/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/types/proxy/policy_engine/__init__.py b/tests/test_litellm/types/proxy/policy_engine/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/types/proxy/policy_engine/test_pipeline_types.py b/tests/test_litellm/types/proxy/policy_engine/test_pipeline_types.py new file mode 100644 index 00000000000..21fecc015a3 --- /dev/null +++ b/tests/test_litellm/types/proxy/policy_engine/test_pipeline_types.py @@ -0,0 +1,152 @@ +""" +Tests for pipeline type definitions. 
+""" + +import pytest +from pydantic import ValidationError + +from litellm.types.proxy.policy_engine.pipeline_types import ( + GuardrailPipeline, + PipelineExecutionResult, + PipelineStep, + PipelineStepResult, +) +from litellm.types.proxy.policy_engine.policy_types import ( + Policy, + PolicyGuardrails, +) + + +def test_pipeline_step_defaults(): + step = PipelineStep(guardrail="my-guard") + assert step.on_fail == "block" + assert step.on_pass == "allow" + assert step.pass_data is False + assert step.modify_response_message is None + + +def test_pipeline_step_valid_actions(): + step = PipelineStep(guardrail="my-guard", on_fail="next", on_pass="next") + assert step.on_fail == "next" + assert step.on_pass == "next" + + +def test_pipeline_step_all_action_types(): + for action in ("allow", "block", "next", "modify_response"): + step = PipelineStep(guardrail="g", on_fail=action, on_pass=action) + assert step.on_fail == action + assert step.on_pass == action + + +def test_pipeline_step_invalid_action_rejected(): + with pytest.raises(ValidationError): + PipelineStep(guardrail="my-guard", on_fail="invalid_action") + + +def test_pipeline_step_invalid_on_pass_rejected(): + with pytest.raises(ValidationError): + PipelineStep(guardrail="my-guard", on_pass="skip") + + +def test_pipeline_requires_at_least_one_step(): + with pytest.raises(ValidationError): + GuardrailPipeline(mode="pre_call", steps=[]) + + +def test_pipeline_invalid_mode_rejected(): + with pytest.raises(ValidationError): + GuardrailPipeline( + mode="during_call", + steps=[PipelineStep(guardrail="g")], + ) + + +def test_pipeline_valid_modes(): + for mode in ("pre_call", "post_call"): + pipeline = GuardrailPipeline( + mode=mode, + steps=[PipelineStep(guardrail="g")], + ) + assert pipeline.mode == mode + + +def test_pipeline_with_multiple_steps(): + pipeline = GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep(guardrail="g1", on_fail="next", on_pass="allow"), + PipelineStep(guardrail="g2", 
on_fail="block", on_pass="allow"), + ], + ) + assert len(pipeline.steps) == 2 + assert pipeline.steps[0].guardrail == "g1" + assert pipeline.steps[1].guardrail == "g2" + + +def test_policy_with_pipeline_parses(): + policy = Policy( + guardrails=PolicyGuardrails(add=["g1", "g2"]), + pipeline=GuardrailPipeline( + mode="pre_call", + steps=[ + PipelineStep(guardrail="g1", on_fail="next"), + PipelineStep(guardrail="g2"), + ], + ), + ) + assert policy.pipeline is not None + assert len(policy.pipeline.steps) == 2 + + +def test_policy_without_pipeline(): + policy = Policy( + guardrails=PolicyGuardrails(add=["g1"]), + ) + assert policy.pipeline is None + + +def test_pipeline_step_result(): + result = PipelineStepResult( + guardrail_name="g1", + outcome="fail", + action_taken="next", + error_detail="Content policy violation", + duration_seconds=0.05, + ) + assert result.outcome == "fail" + assert result.action_taken == "next" + + +def test_pipeline_execution_result(): + result = PipelineExecutionResult( + terminal_action="block", + step_results=[ + PipelineStepResult( + guardrail_name="g1", + outcome="fail", + action_taken="next", + ), + PipelineStepResult( + guardrail_name="g2", + outcome="fail", + action_taken="block", + ), + ], + error_message="Content blocked", + ) + assert result.terminal_action == "block" + assert len(result.step_results) == 2 + + +def test_pipeline_step_extra_fields_rejected(): + with pytest.raises(ValidationError): + PipelineStep(guardrail="g", unknown_field="value") + + +def test_pipeline_extra_fields_rejected(): + with pytest.raises(ValidationError): + GuardrailPipeline( + mode="pre_call", + steps=[PipelineStep(guardrail="g")], + unknown="value", + )