diff --git a/docs/my-website/docs/proxy/guardrails/guardrail_policies.md b/docs/my-website/docs/proxy/guardrails/guardrail_policies.md
new file mode 100644
index 00000000000..56be11c85a7
--- /dev/null
+++ b/docs/my-website/docs/proxy/guardrails/guardrail_policies.md
@@ -0,0 +1,283 @@

# [Beta] Guardrail Policies

Use policies to group guardrails and control which ones run for specific teams, keys, or models.

## Why use policies?

- Enable/disable specific guardrails for teams, keys, or models
- Group guardrails into a single policy
- Inherit from existing policies and override what you need

## Quick Start

```yaml showLineNumbers title="config.yaml"
model_list:
  - model_name: gpt-4
    litellm_params:
      model: openai/gpt-4

# 1. Define your guardrails
guardrails:
  - guardrail_name: pii_masking
    litellm_params:
      guardrail: presidio
      mode: pre_call

  - guardrail_name: prompt_injection
    litellm_params:
      guardrail: lakera
      mode: pre_call
      api_key: os.environ/LAKERA_API_KEY

# 2. Create a policy
policies:
  my-policy:
    guardrails:
      add:
        - pii_masking
        - prompt_injection

# 3. Attach the policy
policy_attachments:
  - policy: my-policy
    scope: "*" # apply to all requests
```

Response headers show what ran:

```
x-litellm-applied-policies: my-policy
x-litellm-applied-guardrails: pii_masking,prompt_injection
```

## Add guardrails for a specific team

:::info
✨ Enterprise-only feature for team/key-based policy attachments. [Get a free trial](https://www.litellm.ai/enterprise#trial)
:::

You have a global baseline, but want to add extra guardrails for a specific team.
```yaml showLineNumbers title="config.yaml"
policies:
  global-baseline:
    guardrails:
      add:
        - pii_masking

  finance-team-policy:
    inherit: global-baseline
    guardrails:
      add:
        - strict_compliance_check
        - audit_logger

policy_attachments:
  - policy: global-baseline
    scope: "*"

  - policy: finance-team-policy
    teams:
      - finance # team alias from /team/new
```

Now the `finance` team gets `pii_masking` + `strict_compliance_check` + `audit_logger`, while everyone else just gets `pii_masking`.

## Remove guardrails for a specific team

:::info
✨ Enterprise-only feature for team/key-based policy attachments. [Get a free trial](https://www.litellm.ai/enterprise#trial)
:::

You have guardrails running globally, but want to disable some for a specific team (e.g., internal testing).

```yaml showLineNumbers title="config.yaml"
policies:
  global-baseline:
    guardrails:
      add:
        - pii_masking
        - prompt_injection

  internal-team-policy:
    inherit: global-baseline
    guardrails:
      remove:
        - pii_masking # don't need PII masking for internal testing

policy_attachments:
  - policy: global-baseline
    scope: "*"

  - policy: internal-team-policy
    teams:
      - internal-testing # team alias from /team/new
```

Now the `internal-testing` team only gets `prompt_injection`, while everyone else gets both guardrails.
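To make the `add`/`remove` semantics concrete, here is a small Python sketch of the resolution behavior these examples describe: walk the inheritance chain, apply the parent policy's `add` list first, then each child's `add`/`remove`. The function name and dict shapes are illustrative only, not LiteLLM's internal implementation:

```python
# Illustrative model of policy guardrail resolution -- NOT LiteLLM internals.
from typing import Dict, List, Optional


def resolve_guardrails(policy_name: str, policies: Dict[str, dict]) -> List[str]:
    # Walk the inheritance chain from the named policy up to its root.
    chain = []
    name: Optional[str] = policy_name
    while name is not None:
        chain.append(policies[name])
        name = policies[name].get("inherit")

    # Apply the root policy first, the most-derived policy last,
    # so a child's `remove` can drop a guardrail the parent added.
    resolved: List[str] = []
    for policy in reversed(chain):
        rules = policy.get("guardrails", {})
        for g in rules.get("add", []):
            if g not in resolved:
                resolved.append(g)
        for g in rules.get("remove", []):
            if g in resolved:
                resolved.remove(g)
    return resolved


policies = {
    "global-baseline": {"guardrails": {"add": ["pii_masking", "prompt_injection"]}},
    "internal-team-policy": {
        "inherit": "global-baseline",
        "guardrails": {"remove": ["pii_masking"]},
    },
}

print(resolve_guardrails("internal-team-policy", policies))  # ['prompt_injection']
```

Running this against the internal-testing config above resolves to `prompt_injection` only, matching the documented behavior.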
## Inheritance

Start with a base policy and build on it:

```yaml showLineNumbers title="config.yaml"
policies:
  base:
    guardrails:
      add:
        - pii_masking
        - toxicity_filter

  strict:
    inherit: base
    guardrails:
      add:
        - prompt_injection

  relaxed:
    inherit: base
    guardrails:
      remove:
        - toxicity_filter
```

What you get:
- `base` → `[pii_masking, toxicity_filter]`
- `strict` → `[pii_masking, toxicity_filter, prompt_injection]`
- `relaxed` → `[pii_masking]`

## Model Conditions

Run guardrails only for specific models:

```yaml showLineNumbers title="config.yaml"
policies:
  gpt4-safety:
    guardrails:
      add:
        - strict_content_filter
    condition:
      model: "gpt-4.*" # regex - matches gpt-4, gpt-4-turbo, gpt-4o

  bedrock-compliance:
    guardrails:
      add:
        - audit_logger
    condition:
      model: # exact match list
        - bedrock/claude-3
        - bedrock/claude-2
```

## Attachments

Policies don't do anything until you attach them. Attachments tell LiteLLM *where* to apply each policy.

**Global** - runs on every request:

```yaml showLineNumbers title="config.yaml"
policy_attachments:
  - policy: default
    scope: "*"
```

**Team-specific** (uses team alias from `/team/new`):

```yaml showLineNumbers title="config.yaml"
policy_attachments:
  - policy: hipaa-compliance
    teams:
      - healthcare-team # team alias
      - medical-research # team alias
```

**Key-specific** (uses key alias from `/key/generate`, wildcards supported):

```yaml showLineNumbers title="config.yaml"
policy_attachments:
  - policy: internal-testing
    keys:
      - "dev-*" # key alias pattern
      - "test-*" # key alias pattern
```

## Config Reference

### `policies`

```yaml
policies:
  <policy-name>:
    description: ...
    inherit: ...
    guardrails:
      add: [...]
      remove: [...]
    condition:
      model: ...
```

| Field | Type | Description |
|-------|------|-------------|
| `description` | `string` | Optional. What this policy does. |
| `inherit` | `string` | Optional. Parent policy to inherit guardrails from. |
| `guardrails.add` | `list[string]` | Guardrails to enable. |
| `guardrails.remove` | `list[string]` | Guardrails to disable (useful with inheritance). |
| `condition.model` | `string` or `list[string]` | Optional. Only apply when model matches. Supports regex. |

### `policy_attachments`

```yaml
policy_attachments:
  - policy: ...
    scope: ...
    teams: [...]
    keys: [...]
```

| Field | Type | Description |
|-------|------|-------------|
| `policy` | `string` | **Required.** Name of the policy to attach. |
| `scope` | `string` | Use `"*"` to apply globally. |
| `teams` | `list[string]` | Team aliases (from `/team/new`). |
| `keys` | `list[string]` | Key aliases (from `/key/generate`). Supports `*` wildcard. |

### Response Headers

| Header | Description |
|--------|-------------|
| `x-litellm-applied-policies` | Policies that matched this request |
| `x-litellm-applied-guardrails` | Guardrails that actually ran |

## How it works

Example config:

```yaml showLineNumbers title="config.yaml"
policies:
  base:
    guardrails:
      add: [pii_masking]

  finance-policy:
    inherit: base
    guardrails:
      add: [audit_logger]

policy_attachments:
  - policy: base
    scope: "*"
  - policy: finance-policy
    teams: [finance]
```

```mermaid
flowchart TD
    A["Request with team_alias='finance'"] --> B["Matches policies: base, finance-policy"]
    B --> C["Resolves guardrails: pii_masking, audit_logger"]
```

1. Request comes in with `team_alias='finance'`
2. Matches `base` (via `scope: "*"`) and `finance-policy` (via `teams: [finance]`)
3. Resolves guardrails: `base` adds `pii_masking`, `finance-policy` inherits and adds `audit_logger`
4.
Final guardrails: `pii_masking`, `audit_logger`

diff --git a/docs/my-website/docs/proxy/guardrails/quick_start.md b/docs/my-website/docs/proxy/guardrails/quick_start.md
index 4a8dc4e6fe4..cb6379d49f4 100644
--- a/docs/my-website/docs/proxy/guardrails/quick_start.md
+++ b/docs/my-website/docs/proxy/guardrails/quick_start.md
@@ -203,8 +203,12 @@ Your response headers will include `x-litellm-applied-guardrails` with the guard

```
x-litellm-applied-guardrails: aporia-pre-guard
```

### Guardrail Policies

Need more control? Use [Guardrail Policies](./guardrail_policies.md) to:
- Group guardrails into reusable policies
- Enable/disable guardrails for specific teams, keys, or models
- Inherit from existing policies and override specific guardrails

## **Using Guardrails Client Side**

diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js
index ad5019d880a..41502324fc5 100644
--- a/docs/my-website/sidebars.js
+++ b/docs/my-website/sidebars.js
@@ -42,6 +42,7 @@ const sidebars = {
      label: "Guardrails",
      items: [
        "proxy/guardrails/quick_start",
        "proxy/guardrails/guardrail_policies",
        "proxy/guardrails/guardrail_load_balancing",
        {
          type: "category",

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9455b32178e..e5c09702b9b 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -377,6 +377,9 @@ Dict[str, Union[float, "PriorityReservationDict"]]
] = None  # priority_reservation_settings is lazy-loaded via __getattr__

# Only declare for type checking - at runtime __getattr__ handles it
if TYPE_CHECKING:
    priority_reservation_settings: Optional["PriorityReservationSettings"] = None

######## Networking Settings ########

@@ -1273,6 +1276,7 @@ def set_global_gitlab_config(config: Dict[str, Any]) -> None:

if TYPE_CHECKING:
    from litellm.types.utils import ModelInfo as _ModelInfoType
    from litellm.types.utils import PriorityReservationSettings
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
    from litellm.caching.caching
import Cache diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index ab034d9f51b..6d87e0b5997 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -16094,6 +16094,181 @@ "output_cost_per_token": 0.0, "output_vector_size": 2560 }, + "gmi/anthropic/claude-opus-4.5": { + "input_cost_per_token": 5e-06, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 2.5e-05, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/anthropic/claude-sonnet-4.5": { + "input_cost_per_token": 3e-06, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/anthropic/claude-sonnet-4": { + "input_cost_per_token": 3e-06, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 1.5e-05, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/anthropic/claude-opus-4": { + "input_cost_per_token": 1.5e-05, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 7.5e-05, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/openai/gpt-5.2": { + "input_cost_per_token": 1.75e-06, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 1.4e-05, + "supports_function_calling": true + }, + "gmi/openai/gpt-5.1": { + "input_cost_per_token": 1.25e-06, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + 
"max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 1e-05, + "supports_function_calling": true + }, + "gmi/openai/gpt-5": { + "input_cost_per_token": 1.25e-06, + "litellm_provider": "gmi", + "max_input_tokens": 409600, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "output_cost_per_token": 1e-05, + "supports_function_calling": true + }, + "gmi/openai/gpt-4o": { + "input_cost_per_token": 2.5e-06, + "litellm_provider": "gmi", + "max_input_tokens": 131072, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1e-05, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/openai/gpt-4o-mini": { + "input_cost_per_token": 1.5e-07, + "litellm_provider": "gmi", + "max_input_tokens": 131072, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 6e-07, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/deepseek-ai/DeepSeek-V3.2": { + "input_cost_per_token": 2.8e-07, + "litellm_provider": "gmi", + "max_input_tokens": 163840, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 4e-07, + "supports_function_calling": true + }, + "gmi/deepseek-ai/DeepSeek-V3-0324": { + "input_cost_per_token": 2.8e-07, + "litellm_provider": "gmi", + "max_input_tokens": 163840, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 8.8e-07, + "supports_function_calling": true + }, + "gmi/google/gemini-3-pro-preview": { + "input_cost_per_token": 2e-06, + "litellm_provider": "gmi", + "max_input_tokens": 1048576, + "max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "output_cost_per_token": 1.2e-05, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/google/gemini-3-flash-preview": { + "input_cost_per_token": 5e-07, + "litellm_provider": "gmi", + "max_input_tokens": 1048576, + 
"max_output_tokens": 65536, + "max_tokens": 65536, + "mode": "chat", + "output_cost_per_token": 3e-06, + "supports_function_calling": true, + "supports_vision": true + }, + "gmi/moonshotai/Kimi-K2-Thinking": { + "input_cost_per_token": 8e-07, + "litellm_provider": "gmi", + "max_input_tokens": 262144, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1.2e-06 + }, + "gmi/MiniMaxAI/MiniMax-M2.1": { + "input_cost_per_token": 3e-07, + "litellm_provider": "gmi", + "max_input_tokens": 196608, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1.2e-06 + }, + "gmi/Qwen/Qwen3-VL-235B-A22B-Instruct-FP8": { + "input_cost_per_token": 3e-07, + "litellm_provider": "gmi", + "max_input_tokens": 262144, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 1.4e-06, + "supports_vision": true + }, + "gmi/zai-org/GLM-4.7-FP8": { + "input_cost_per_token": 4e-07, + "litellm_provider": "gmi", + "max_input_tokens": 202752, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "output_cost_per_token": 2e-06 + }, "google.gemma-3-12b-it": { "input_cost_per_token": 9e-08, "litellm_provider": "bedrock_converse", diff --git a/litellm/proxy/common_utils/callback_utils.py b/litellm/proxy/common_utils/callback_utils.py index cb434da55b3..c0cff84bacd 100644 --- a/litellm/proxy/common_utils/callback_utils.py +++ b/litellm/proxy/common_utils/callback_utils.py @@ -380,6 +380,11 @@ def get_logging_caching_headers(request_data: Dict) -> Optional[Dict]: _metadata["applied_guardrails"] ) + if "applied_policies" in _metadata: + headers["x-litellm-applied-policies"] = ",".join( + _metadata["applied_policies"] + ) + if "semantic-similarity" in _metadata: headers["x-litellm-semantic-similarity"] = str(_metadata["semantic-similarity"]) @@ -406,6 +411,27 @@ def add_guardrail_to_applied_guardrails_header( request_data["metadata"] = _metadata +def 
add_policy_to_applied_policies_header( + request_data: Dict, policy_name: Optional[str] +): + """ + Add a policy name to the applied_policies list in request metadata. + + This is used to track which policies were applied to a request, + similar to how applied_guardrails tracks guardrails. + """ + if policy_name is None: + return + _metadata = request_data.get("metadata", None) or {} + if "applied_policies" in _metadata: + if policy_name not in _metadata["applied_policies"]: + _metadata["applied_policies"].append(policy_name) + else: + _metadata["applied_policies"] = [policy_name] + # Ensure metadata is set back to request_data (important when metadata didn't exist) + request_data["metadata"] = _metadata + + def add_guardrail_response_to_standard_logging_object( litellm_logging_obj: Optional["LiteLLMLogging"], guardrail_response: StandardLoggingGuardrailInformation, diff --git a/litellm/proxy/hooks/dynamic_rate_limiter_v3.py b/litellm/proxy/hooks/dynamic_rate_limiter_v3.py index a659d62e3eb..5a1d6bec5d5 100644 --- a/litellm/proxy/hooks/dynamic_rate_limiter_v3.py +++ b/litellm/proxy/hooks/dynamic_rate_limiter_v3.py @@ -4,7 +4,7 @@ import os from datetime import datetime -from typing import Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union from fastapi import HTTPException @@ -24,6 +24,25 @@ from litellm.types.router import ModelGroupInfo from litellm.types.utils import CallTypesLiteral +if TYPE_CHECKING: + from litellm.types.utils import PriorityReservationSettings + + +def _get_priority_settings() -> "PriorityReservationSettings": + """ + Get the priority reservation settings, guaranteed to be non-None. + + The settings are lazy-loaded in litellm.__init__ and always return an instance. + This helper provides proper type narrowing for mypy. 
+ """ + settings = litellm.priority_reservation_settings + if settings is None: + # This should never happen due to lazy loading, but satisfy mypy + from litellm.types.utils import PriorityReservationSettings + + return PriorityReservationSettings() + return settings + class _PROXY_DynamicRateLimitHandlerV3(CustomLogger): """ @@ -60,7 +79,7 @@ def update_variables(self, llm_router: Router): def _get_saturation_check_cache_ttl(self) -> int: """Get the configurable TTL for local cache when reading saturation values.""" - return litellm.priority_reservation_settings.saturation_check_cache_ttl + return _get_priority_settings().saturation_check_cache_ttl async def _get_saturation_value_from_cache( self, @@ -91,7 +110,7 @@ def _get_priority_weight( self, priority: Optional[str], model_info: Optional[ModelGroupInfo] = None ) -> float: """Get the weight for a given priority from litellm.priority_reservation""" - weight: float = litellm.priority_reservation_settings.default_priority + weight: float = _get_priority_settings().default_priority if ( litellm.priority_reservation is None or priority not in litellm.priority_reservation @@ -201,7 +220,7 @@ def _get_priority_allocation( priority_key = f"{model}:{priority}" else: # No explicit priority: share the default_priority pool with ALL other default keys - priority_weight = litellm.priority_reservation_settings.default_priority + priority_weight = _get_priority_settings().default_priority # Use shared key for all default-priority requests priority_key = f"{model}:default_pool" @@ -418,9 +437,7 @@ async def _check_rate_limits( """ import json - saturation_threshold = ( - litellm.priority_reservation_settings.saturation_threshold - ) + saturation_threshold = _get_priority_settings().saturation_threshold should_enforce_priority = saturation >= saturation_threshold # Build ALL descriptors upfront @@ -593,9 +610,7 @@ async def async_pre_call_hook( # STEP 1: Check current saturation level saturation = await 
self._check_model_saturation(model, model_group_info) - saturation_threshold = ( - litellm.priority_reservation_settings.saturation_threshold - ) + saturation_threshold = _get_priority_settings().saturation_threshold verbose_proxy_logger.debug( f"[Dynamic Rate Limiter] Model={model}, Saturation={saturation:.1%}, " diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 1fbd8ee72c2..32cddc0ef58 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1082,13 +1082,20 @@ async def add_litellm_data_to_request( # noqa: PLR0915 if disabled_callbacks and isinstance(disabled_callbacks, list): data["litellm_disabled_callbacks"] = disabled_callbacks - # Guardrails + # Guardrails from key/team metadata move_guardrails_to_metadata( data=data, _metadata_variable_name=_metadata_variable_name, user_api_key_dict=user_api_key_dict, ) + # Guardrails from policy engine + add_guardrails_from_policy_engine( + data=data, + metadata_variable_name=_metadata_variable_name, + user_api_key_dict=user_api_key_dict, + ) + # Team Model Aliases _update_model_if_team_alias_exists( data=data, @@ -1314,6 +1321,7 @@ def move_guardrails_to_metadata( - If guardrails set on API Key metadata then sets guardrails on request metadata - If guardrails not set on API key, then checks request metadata + - Adds guardrails from policy engine based on team/key/model context """ # Check key-level guardrails _add_guardrails_from_key_or_team_metadata( @@ -1323,6 +1331,15 @@ def move_guardrails_to_metadata( metadata_variable_name=_metadata_variable_name, ) + ######################################################################################### + # Add guardrails from policy engine based on team/key/model context + ######################################################################################### + add_guardrails_from_policy_engine( + data=data, + metadata_variable_name=_metadata_variable_name, + 
user_api_key_dict=user_api_key_dict, + ) + ######################################################################################### # User's might send "guardrails" in the request body, we need to add them to the request metadata. # Since downstream logic requires "guardrails" to be in the request metadata @@ -1351,6 +1368,103 @@ def move_guardrails_to_metadata( ] = request_body_guardrail_config +def add_guardrails_from_policy_engine( + data: dict, + metadata_variable_name: str, + user_api_key_dict: UserAPIKeyAuth, +) -> None: + """ + Add guardrails from the policy engine based on request context. + + This function: + 1. Gets matching policies based on team_alias, key_alias, and model + 2. Resolves guardrails from matching policies (including inheritance) + 3. Adds guardrails to request metadata + 4. Tracks applied policies in metadata for response headers + + Args: + data: The request data to update + metadata_variable_name: The name of the metadata field in data + user_api_key_dict: The user's API key authentication info + """ + from litellm._logging import verbose_proxy_logger + from litellm.proxy.common_utils.callback_utils import ( + add_policy_to_applied_policies_header, + ) + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + from litellm.types.proxy.policy_engine import PolicyMatchContext + + registry = get_policy_registry() + verbose_proxy_logger.debug( + f"Policy engine: registry initialized={registry.is_initialized()}, " + f"policy_count={len(registry.get_all_policies())}" + ) + if not registry.is_initialized(): + verbose_proxy_logger.debug("Policy engine not initialized, skipping policy matching") + return + + # Build context from request + context = PolicyMatchContext( + team_alias=user_api_key_dict.team_alias, + key_alias=user_api_key_dict.key_alias, + model=data.get("model"), + ) + 
+ verbose_proxy_logger.debug( + f"Policy engine: matching policies for context team_alias={context.team_alias}, " + f"key_alias={context.key_alias}, model={context.model}" + ) + + # Get matching policies via attachments + matching_policy_names = PolicyMatcher.get_matching_policies(context=context) + + verbose_proxy_logger.debug(f"Policy engine: matched policies via attachments: {matching_policy_names}") + + if not matching_policy_names: + return + + # Filter to only policies whose conditions match the context + applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions( + policy_names=matching_policy_names, + context=context, + ) + + verbose_proxy_logger.debug(f"Policy engine: applied policies (conditions matched): {applied_policy_names}") + + # Track applied policies in metadata for response headers + for policy_name in applied_policy_names: + add_policy_to_applied_policies_header( + request_data=data, policy_name=policy_name + ) + + # Resolve guardrails from matching policies + resolved_guardrails = PolicyResolver.resolve_guardrails_for_context(context=context) + + verbose_proxy_logger.debug(f"Policy engine: resolved guardrails: {resolved_guardrails}") + + if not resolved_guardrails: + return + + # Add resolved guardrails to request metadata + if metadata_variable_name not in data: + data[metadata_variable_name] = {} + + existing_guardrails = data[metadata_variable_name].get("guardrails", []) + if not isinstance(existing_guardrails, list): + existing_guardrails = [] + + # Combine existing guardrails with policy-resolved guardrails (no duplicates) + combined = set(existing_guardrails) + combined.update(resolved_guardrails) + data[metadata_variable_name]["guardrails"] = list(combined) + + verbose_proxy_logger.debug( + f"Policy engine: added guardrails to request metadata: {list(combined)}" + ) + + def add_provider_specific_headers_to_request( data: dict, headers: dict, diff --git a/litellm/proxy/management_endpoints/policy_endpoints.py 
b/litellm/proxy/management_endpoints/policy_endpoints.py new file mode 100644 index 00000000000..f9487dc2e59 --- /dev/null +++ b/litellm/proxy/management_endpoints/policy_endpoints.py @@ -0,0 +1,259 @@ +""" +POLICY MANAGEMENT + +All /policy management endpoints + +/policy/validate - Validate a policy configuration +/policy/list - List all loaded policies +/policy/info - Get information about a specific policy +""" + +from fastapi import APIRouter, Depends, HTTPException, Request + +from litellm._logging import verbose_proxy_logger +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.management_helpers.utils import management_endpoint_wrapper +from litellm.types.proxy.policy_engine import ( + PolicyGuardrailsResponse, + PolicyInfoResponse, + PolicyListResponse, + PolicyMatchContext, + PolicyScopeResponse, + PolicySummaryItem, + PolicyTestResponse, + PolicyValidateRequest, + PolicyValidationResponse, +) + +router = APIRouter() + + +@router.post( + "/policy/validate", + tags=["policy management"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyValidationResponse, +) +@management_endpoint_wrapper +async def validate_policy( + request: Request, + data: PolicyValidateRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +) -> PolicyValidationResponse: + """ + Validate a policy configuration before applying it. 
+ + Checks: + - All referenced guardrails exist in the guardrail registry + - All non-wildcard team aliases exist in the database + - All non-wildcard key aliases exist in the database + - Inheritance chains are valid (no cycles, parents exist) + - Scope patterns are syntactically valid + + Returns: + - valid: True if the policy configuration is valid (no blocking errors) + - errors: List of blocking validation errors + - warnings: List of non-blocking validation warnings + + Example request: + ```json + { + "policies": { + "global-baseline": { + "guardrails": { + "add": ["pii_blocker", "phi_blocker"] + }, + "scope": { + "teams": ["*"], + "keys": ["*"], + "models": ["*"] + } + }, + "healthcare-compliance": { + "inherit": "global-baseline", + "guardrails": { + "add": ["hipaa_audit"] + }, + "scope": { + "teams": ["healthcare-team"] + } + } + } + } + ``` + """ + from litellm.proxy.policy_engine.policy_validator import PolicyValidator + from litellm.proxy.proxy_server import prisma_client + + verbose_proxy_logger.debug( + f"Validating policy configuration with {len(data.policies)} policies" + ) + + validator = PolicyValidator(prisma_client=prisma_client) + + result = await validator.validate_policy_config( + data.policies, + validate_db=prisma_client is not None, + ) + + return result + + +@router.get( + "/policy/list", + tags=["policy management"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyListResponse, +) +@management_endpoint_wrapper +async def list_policies( + request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +) -> PolicyListResponse: + """ + List all loaded policies with their resolved guardrails. 
+ + Returns information about each policy including: + - Inheritance configuration + - Scope (teams, keys, models) + - Guardrails to add/remove + - Resolved guardrails (after inheritance) + - Inheritance chain + """ + from litellm.proxy.policy_engine.init_policies import get_policies_summary + + summary = get_policies_summary() + return PolicyListResponse( + policies={ + name: PolicySummaryItem( + inherit=data.get("inherit"), + scope=PolicyScopeResponse(**data.get("scope", {})), + guardrails=PolicyGuardrailsResponse(**data.get("guardrails", {})), + resolved_guardrails=data.get("resolved_guardrails", []), + inheritance_chain=data.get("inheritance_chain", []), + ) + for name, data in summary.get("policies", {}).items() + }, + total_count=summary.get("total_count", 0), + ) + + +@router.get( + "/policy/info/{policy_name}", + tags=["policy management"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyInfoResponse, +) +@management_endpoint_wrapper +async def get_policy_info( + request: Request, + policy_name: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +) -> PolicyInfoResponse: + """ + Get detailed information about a specific policy. + + Returns: + - Policy configuration + - Resolved guardrails (after inheritance) + - Inheritance chain + """ + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + + registry = get_policy_registry() + + if not registry.is_initialized(): + raise HTTPException( + status_code=404, + detail="Policy engine not initialized. 
No policies loaded.", + ) + + policy = registry.get_policy(policy_name) + if policy is None: + raise HTTPException( + status_code=404, + detail=f"Policy '{policy_name}' not found", + ) + + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name=policy_name, policies=registry.get_all_policies() + ) + + return PolicyInfoResponse( + policy_name=policy_name, + inherit=policy.inherit, + scope=PolicyScopeResponse( + teams=[], + keys=[], + models=[], + ), + guardrails=PolicyGuardrailsResponse( + add=policy.guardrails.get_add(), + remove=policy.guardrails.get_remove(), + ), + resolved_guardrails=resolved.guardrails, + inheritance_chain=resolved.inheritance_chain, + ) + + +@router.post( + "/policy/test", + tags=["policy management"], + dependencies=[Depends(user_api_key_auth)], + response_model=PolicyTestResponse, +) +@management_endpoint_wrapper +async def test_policy_matching( + request: Request, + context: PolicyMatchContext, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +) -> PolicyTestResponse: + """ + Test which policies would match a given request context. + + This is useful for debugging and understanding policy behavior. + + Request body: + ```json + { + "team_alias": "healthcare-team", + "key_alias": "my-api-key", + "model": "gpt-4" + } + ``` + + Returns: + - matching_policies: List of policy names that match + - resolved_guardrails: Final list of guardrails that would be applied + """ + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + + registry = get_policy_registry() + + if not registry.is_initialized(): + return PolicyTestResponse( + context=context, + matching_policies=[], + resolved_guardrails=[], + message="Policy engine not initialized. 
No policies loaded.", + ) + + policies = registry.get_all_policies() + + # Get matching policies + matching_policy_names = PolicyMatcher.get_matching_policies(context=context) + + # Resolve guardrails + resolved_guardrails = PolicyResolver.resolve_guardrails_for_context( + context=context, policies=policies + ) + + return PolicyTestResponse( + context=context, + matching_policies=matching_policy_names, + resolved_guardrails=resolved_guardrails, + ) diff --git a/litellm/proxy/policy_engine/__init__.py b/litellm/proxy/policy_engine/__init__.py new file mode 100644 index 00000000000..9ef5fd02f78 --- /dev/null +++ b/litellm/proxy/policy_engine/__init__.py @@ -0,0 +1,60 @@ +""" +LiteLLM Policy Engine + +The Policy Engine allows administrators to define policies that combine guardrails +with scoping rules. Policies can target specific teams, API keys, and models using +wildcard patterns, and support inheritance from base policies. + +Configuration structure: +- `policies`: Define WHAT guardrails to apply (with inheritance and conditions) +- `policy_attachments`: Define WHERE policies apply (teams, keys, models) + +Example: +```yaml +policies: + global-baseline: + description: "Base guardrails for all requests" + guardrails: + add: [pii_blocker] + + gpt4-safety: + inherit: global-baseline + description: "Extra safety for GPT-4" + guardrails: + add: [toxicity_filter] + condition: + model: "gpt-4.*" # regex pattern + +policy_attachments: + - policy: global-baseline + scope: "*" + - policy: gpt4-safety + scope: "*" +``` +""" + +from litellm.proxy.policy_engine.attachment_registry import ( + AttachmentRegistry, + get_attachment_registry, +) +from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator +from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher +from litellm.proxy.policy_engine.policy_registry import ( + PolicyRegistry, + get_policy_registry, +) +from litellm.proxy.policy_engine.policy_resolver import PolicyResolver +from 
litellm.proxy.policy_engine.policy_validator import PolicyValidator + +__all__ = [ + # Registries + "PolicyRegistry", + "get_policy_registry", + "AttachmentRegistry", + "get_attachment_registry", + # Core components + "PolicyMatcher", + "PolicyResolver", + "PolicyValidator", + "ConditionEvaluator", +] diff --git a/litellm/proxy/policy_engine/architecture.md b/litellm/proxy/policy_engine/architecture.md new file mode 100644 index 00000000000..fa9cbeecf5d --- /dev/null +++ b/litellm/proxy/policy_engine/architecture.md @@ -0,0 +1,54 @@ +# Policy Engine Architecture + +## Overview + +The Policy Engine allows administrators to define policies that combine guardrails with scoping rules. Policies can target specific teams, API keys, and models using wildcard patterns, and support inheritance from base policies. + +## Architecture Diagram + +```mermaid +flowchart TD + subgraph Config["config.yaml"] + PC[policies config] + end + + subgraph PolicyEngine["Policy Engine"] + PR[PolicyRegistry] + PV[PolicyValidator] + PM[PolicyMatcher] + PRe[PolicyResolver] + end + + subgraph Request["Incoming Request"] + CTX[Context: team_alias, key_alias, model] + end + + subgraph Output["Output"] + GR[Guardrails to Apply] + end + + PC -->|load| PR + PC -->|validate| PV + PV -->|errors/warnings| PR + + CTX -->|match| PM + PM -->|matching policies| PRe + PR -->|policies| PM + PR -->|policies| PRe + PRe -->|resolve inheritance + add/remove| GR +``` + +## Components + +| Component | File | Description | +|-----------|------|-------------| +| **PolicyRegistry** | `policy_registry.py` | In-memory singleton store for parsed policies | +| **PolicyValidator** | `policy_validator.py` | Validates configs (guardrails, inheritance, teams/keys/models) | +| **PolicyMatcher** | `policy_matcher.py` | Matches request context against policy scopes | +| **PolicyResolver** | `policy_resolver.py` | Resolves final guardrails via inheritance chain | + +## Flow + +1. 
**Startup**: `init_policies()` loads policies from config, validates, and populates `PolicyRegistry`
2. **Request**: `PolicyMatcher` finds policies matching the request's team/key/model
3. **Resolution**: `PolicyResolver` traverses inheritance and applies add/remove to get final guardrails
diff --git a/litellm/proxy/policy_engine/attachment_registry.py b/litellm/proxy/policy_engine/attachment_registry.py
new file mode 100644
index 00000000000..b5d6f2fb745
--- /dev/null
+++ b/litellm/proxy/policy_engine/attachment_registry.py
@@ -0,0 +1,206 @@
+"""
+Attachment Registry - Manages policy attachments from YAML config.
+
+Attachments define WHERE policies apply, separate from the policy definitions.
+This allows the same policy to be attached to multiple scopes.
+"""
+
+from typing import Any, Dict, List, Optional
+
+from litellm._logging import verbose_proxy_logger
+from litellm.types.proxy.policy_engine import (
+    PolicyAttachment,
+    PolicyMatchContext,
+)
+
+
+class AttachmentRegistry:
+    """
+    In-memory registry for storing and managing policy attachments.
+
+    Attachments define the relationship between policies and their scopes.
+    A single policy can have multiple attachments (applied to different scopes).
+
+    Example YAML:
+    ```yaml
+    policy_attachments:
+      - policy: global-baseline
+        scope: "*"
+      - policy: healthcare-compliance
+        teams: [healthcare-team]
+      - policy: dev-safety
+        keys: ["dev-key-*"]
+    ```
+    """
+
+    def __init__(self):
+        self._attachments: List[PolicyAttachment] = []
+        self._initialized: bool = False
+
+    def load_attachments(self, attachments_config: List[Dict[str, Any]]) -> None:
+        """
+        Load attachments from a configuration list.
+
+        Args:
+            attachments_config: List of attachment dictionaries from YAML.
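+
+        Example (illustrative usage; assumes the `policy_attachments` section of
+        the config has already been parsed from YAML into a list of dicts):
+
+            registry = get_attachment_registry()
+            registry.load_attachments([
+                {"policy": "global-baseline", "scope": "*"},
+                {"policy": "dev-safety", "keys": ["dev-key-*"]},
+            ])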
+ """ + self._attachments = [] + + for attachment_data in attachments_config: + try: + attachment = self._parse_attachment(attachment_data) + self._attachments.append(attachment) + verbose_proxy_logger.debug( + f"Loaded attachment for policy: {attachment.policy}" + ) + except Exception as e: + verbose_proxy_logger.error( + f"Error loading attachment: {str(e)}" + ) + raise ValueError(f"Invalid attachment: {str(e)}") from e + + self._initialized = True + verbose_proxy_logger.info(f"Loaded {len(self._attachments)} policy attachments") + + def _parse_attachment(self, attachment_data: Dict[str, Any]) -> PolicyAttachment: + """ + Parse an attachment from raw configuration data. + + Args: + attachment_data: Raw attachment configuration + + Returns: + Parsed PolicyAttachment object + """ + return PolicyAttachment( + policy=attachment_data.get("policy", ""), + scope=attachment_data.get("scope"), + teams=attachment_data.get("teams"), + keys=attachment_data.get("keys"), + models=attachment_data.get("models"), + ) + + def get_attached_policies(self, context: PolicyMatchContext) -> List[str]: + """ + Get list of policy names attached to the given context. 
+ + Args: + context: The request context to match against + + Returns: + List of policy names that are attached to matching scopes + """ + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + + attached_policies: List[str] = [] + + for attachment in self._attachments: + scope = attachment.to_policy_scope() + if PolicyMatcher.scope_matches(scope=scope, context=context): + if attachment.policy not in attached_policies: + attached_policies.append(attachment.policy) + verbose_proxy_logger.debug( + f"Attachment matched: policy={attachment.policy}, " + f"context=(team={context.team_alias}, key={context.key_alias}, model={context.model})" + ) + + return attached_policies + + def is_policy_attached( + self, policy_name: str, context: PolicyMatchContext + ) -> bool: + """ + Check if a specific policy is attached to the given context. + + Args: + policy_name: Name of the policy to check + context: The request context to match against + + Returns: + True if the policy is attached to a matching scope + """ + attached = self.get_attached_policies(context) + return policy_name in attached + + def get_all_attachments(self) -> List[PolicyAttachment]: + """ + Get all loaded attachments. + + Returns: + List of all PolicyAttachment objects + """ + return self._attachments.copy() + + def get_attachments_for_policy(self, policy_name: str) -> List[PolicyAttachment]: + """ + Get all attachments for a specific policy. + + Args: + policy_name: Name of the policy + + Returns: + List of attachments for the policy + """ + return [a for a in self._attachments if a.policy == policy_name] + + def is_initialized(self) -> bool: + """ + Check if the registry has been initialized with attachments. + + Returns: + True if attachments have been loaded, False otherwise + """ + return self._initialized + + def clear(self) -> None: + """ + Clear all attachments from the registry. 
+ """ + self._attachments = [] + self._initialized = False + + def add_attachment(self, attachment: PolicyAttachment) -> None: + """ + Add a single attachment. + + Args: + attachment: PolicyAttachment object to add + """ + self._attachments.append(attachment) + verbose_proxy_logger.debug(f"Added attachment for policy: {attachment.policy}") + + def remove_attachments_for_policy(self, policy_name: str) -> int: + """ + Remove all attachments for a specific policy. + + Args: + policy_name: Name of the policy + + Returns: + Number of attachments removed + """ + original_count = len(self._attachments) + self._attachments = [a for a in self._attachments if a.policy != policy_name] + removed_count = original_count - len(self._attachments) + if removed_count > 0: + verbose_proxy_logger.debug( + f"Removed {removed_count} attachment(s) for policy: {policy_name}" + ) + return removed_count + + +# Global singleton instance +_attachment_registry: Optional[AttachmentRegistry] = None + + +def get_attachment_registry() -> AttachmentRegistry: + """ + Get the global AttachmentRegistry singleton. + + Returns: + The global AttachmentRegistry instance + """ + global _attachment_registry + if _attachment_registry is None: + _attachment_registry = AttachmentRegistry() + return _attachment_registry diff --git a/litellm/proxy/policy_engine/condition_evaluator.py b/litellm/proxy/policy_engine/condition_evaluator.py new file mode 100644 index 00000000000..1f1dea15a1d --- /dev/null +++ b/litellm/proxy/policy_engine/condition_evaluator.py @@ -0,0 +1,111 @@ +""" +Condition Evaluator - Evaluates policy conditions. + +Supports model-based conditions with exact match or regex patterns. +""" + +import re +from typing import List, Optional, Union + +from litellm._logging import verbose_proxy_logger +from litellm.types.proxy.policy_engine import ( + PolicyCondition, + PolicyMatchContext, +) + + +class ConditionEvaluator: + """ + Evaluates policy conditions against request context. 
+ + Supports model conditions with: + - Exact string match: "gpt-4" + - Regex pattern: "gpt-4.*" + - List of values: ["gpt-4", "gpt-4-turbo"] + """ + + @staticmethod + def evaluate( + condition: Optional[PolicyCondition], + context: PolicyMatchContext, + ) -> bool: + """ + Evaluate a policy condition against a request context. + + Args: + condition: The condition to evaluate (None = always matches) + context: The request context with team, key, model + + Returns: + True if condition matches, False otherwise + """ + # No condition means always matches + if condition is None: + return True + + # Check model condition + if condition.model is not None: + if not ConditionEvaluator._evaluate_model_condition( + condition=condition.model, + model=context.model, + ): + verbose_proxy_logger.debug( + f"Condition failed: model={context.model} did not match {condition.model}" + ) + return False + + return True + + @staticmethod + def _evaluate_model_condition( + condition: Union[str, List[str]], + model: Optional[str], + ) -> bool: + """ + Evaluate a model condition. + + Args: + condition: String (exact or regex) or list of strings + model: The model name to check + + Returns: + True if model matches condition, False otherwise + """ + if model is None: + return False + + # Handle list of values + if isinstance(condition, list): + return any( + ConditionEvaluator._matches_pattern(pattern, model) + for pattern in condition + ) + + # Single value - check as pattern + return ConditionEvaluator._matches_pattern(condition, model) + + @staticmethod + def _matches_pattern(pattern: str, value: str) -> bool: + """ + Check if value matches pattern (exact match or regex). 
+ + Args: + pattern: Pattern to match (exact string or regex) + value: Value to check + + Returns: + True if matches, False otherwise + """ + # First try exact match + if pattern == value: + return True + + # Try as regex pattern + try: + if re.fullmatch(pattern, value): + return True + except re.error: + # Invalid regex, treat as literal string (already checked above) + pass + + return False diff --git a/litellm/proxy/policy_engine/init_policies.py b/litellm/proxy/policy_engine/init_policies.py new file mode 100644 index 00000000000..b734c0cb5cc --- /dev/null +++ b/litellm/proxy/policy_engine/init_policies.py @@ -0,0 +1,276 @@ +""" +Policy Initialization - Loads policies from config and validates on startup. + +Configuration structure: +- policies: Define WHAT guardrails to apply (with inheritance and conditions) +- policy_attachments: Define WHERE policies apply (teams, keys, models) +""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry +from litellm.proxy.policy_engine.policy_registry import get_policy_registry +from litellm.proxy.policy_engine.policy_validator import PolicyValidator +from litellm.types.proxy.policy_engine import PolicyValidationResponse + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient + +# ANSI color codes for terminal output +_green_color_code = "\033[92m" +_blue_color_code = "\033[94m" +_yellow_color_code = "\033[93m" +_reset_color_code = "\033[0m" + + +def _print_policies_on_startup( + policies_config: Dict[str, Any], + policy_attachments_config: Optional[List[Dict[str, Any]]] = None, +) -> None: + """ + Print loaded policies to console on startup (similar to model list). 
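+
+    Example output (illustrative, using the policy names from the docs; ANSI
+    color codes omitted):
+
+        LiteLLM Policy Engine: Loaded 2 policies
+
+         - global-baseline
+            guardrails.add: ['pii_masking']
+         - finance-team-policy (inherits: global-baseline)
+            guardrails.add: ['audit_logger']
+
+        Policy Attachments: 1 attachment(s)
+         - global-baseline -> scope=* (global)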
+ """ + import sys + + print( # noqa: T201 + f"{_green_color_code}\nLiteLLM Policy Engine: Loaded {len(policies_config)} policies{_reset_color_code}\n" + ) + sys.stdout.flush() + + for policy_name, policy_data in policies_config.items(): + guardrails = policy_data.get("guardrails", {}) + inherit = policy_data.get("inherit") + condition = policy_data.get("condition") + description = policy_data.get("description") + + guardrails_add = guardrails.get("add", []) if isinstance(guardrails, dict) else [] + guardrails_remove = guardrails.get("remove", []) if isinstance(guardrails, dict) else [] + inherit_str = f" (inherits: {inherit})" if inherit else "" + + print( # noqa: T201 + f"{_blue_color_code} - {policy_name}{inherit_str}{_reset_color_code}" + ) + if description: + print(f" description: {description}") # noqa: T201 + if guardrails_add: + print(f" guardrails.add: {guardrails_add}") # noqa: T201 + if guardrails_remove: + print(f" guardrails.remove: {guardrails_remove}") # noqa: T201 + if condition: + model_condition = condition.get("model") if isinstance(condition, dict) else None + if model_condition: + print(f" condition.model: {model_condition}") # noqa: T201 + + # Print attachments + if policy_attachments_config: + print( # noqa: T201 + f"\n{_yellow_color_code}Policy Attachments: {len(policy_attachments_config)} attachment(s){_reset_color_code}" + ) + for attachment in policy_attachments_config: + policy = attachment.get("policy", "unknown") + scope = attachment.get("scope") + teams = attachment.get("teams") + keys = attachment.get("keys") + models = attachment.get("models") + + scope_parts = [] + if scope == "*": + scope_parts.append("scope=* (global)") + if teams: + scope_parts.append(f"teams={teams}") + if keys: + scope_parts.append(f"keys={keys}") + if models: + scope_parts.append(f"models={models}") + scope_str = ", ".join(scope_parts) if scope_parts else "all" + + print(f" - {policy} -> {scope_str}") # noqa: T201 + else: + print( # noqa: T201 + 
f"\n{_yellow_color_code}Warning: No policy_attachments configured. Policies will not be applied to any requests.{_reset_color_code}" + ) + + print() # noqa: T201 + sys.stdout.flush() + + +async def init_policies( + policies_config: Dict[str, Any], + policy_attachments_config: Optional[List[Dict[str, Any]]] = None, + prisma_client: Optional["PrismaClient"] = None, + validate_db: bool = True, + fail_on_error: bool = True, +) -> PolicyValidationResponse: + """ + Initialize policies from configuration. + + This function: + 1. Parses the policy configuration + 2. Validates policies (guardrails exist, teams/keys exist in DB) + 3. Loads policies into the global registry + 4. Loads attachments into the attachment registry (if provided) + + Args: + policies_config: Dictionary mapping policy names to policy definitions + policy_attachments_config: Optional list of policy attachment configurations + prisma_client: Optional Prisma client for database validation + validate_db: Whether to validate team/key aliases against database + fail_on_error: If True, raise exception on validation errors + + Returns: + PolicyValidationResponse with validation results + + Raises: + ValueError: If fail_on_error is True and validation errors are found + """ + verbose_proxy_logger.info(f"Initializing {len(policies_config)} policies...") + + # Print policies to console on startup + _print_policies_on_startup(policies_config, policy_attachments_config) + + # Get the global registries + policy_registry = get_policy_registry() + attachment_registry = get_attachment_registry() + + # Create validator + validator = PolicyValidator(prisma_client=prisma_client) + + # Validate the configuration + validation_result = await validator.validate_policy_config( + policies_config, + validate_db=validate_db, + ) + + # Log validation results + if validation_result.errors: + for error in validation_result.errors: + verbose_proxy_logger.error( + f"Policy validation error in '{error.policy_name}': " + 
f"[{error.error_type}] {error.message}" + ) + + if validation_result.warnings: + for warning in validation_result.warnings: + verbose_proxy_logger.warning( + f"Policy validation warning in '{warning.policy_name}': " + f"[{warning.error_type}] {warning.message}" + ) + + # Fail if there are errors and fail_on_error is True + if not validation_result.valid and fail_on_error: + error_messages = [ + f"[{e.policy_name}] {e.message}" for e in validation_result.errors + ] + raise ValueError( + f"Policy validation failed with {len(validation_result.errors)} error(s):\n" + + "\n".join(error_messages) + ) + + # Load policies into registry (even with warnings) + try: + policy_registry.load_policies(policies_config) + verbose_proxy_logger.info( + f"Successfully loaded {len(policies_config)} policies" + ) + except Exception as e: + verbose_proxy_logger.error(f"Failed to load policies: {str(e)}") + raise + + # Load attachments if provided + if policy_attachments_config: + try: + attachment_registry.load_attachments(policy_attachments_config) + verbose_proxy_logger.info( + f"Successfully loaded {len(policy_attachments_config)} policy attachments" + ) + except Exception as e: + verbose_proxy_logger.error(f"Failed to load policy attachments: {str(e)}") + raise + + return validation_result + + +def init_policies_sync( + policies_config: Dict[str, Any], + policy_attachments_config: Optional[List[Dict[str, Any]]] = None, + fail_on_error: bool = True, +) -> None: + """ + Synchronous version of init_policies (without DB validation). + + Use this when async is not available or DB validation is not needed. 
+ + Args: + policies_config: Dictionary mapping policy names to policy definitions + policy_attachments_config: Optional list of policy attachment configurations + fail_on_error: If True, raise exception on validation errors + """ + import asyncio + + # Run the async function without DB validation + try: + loop = asyncio.get_event_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + loop.run_until_complete( + init_policies( + policies_config=policies_config, + policy_attachments_config=policy_attachments_config, + prisma_client=None, + validate_db=False, + fail_on_error=fail_on_error, + ) + ) + + +def get_policies_summary() -> Dict[str, Any]: + """ + Get a summary of loaded policies for debugging/display. + + Returns: + Dictionary with policy information + """ + from litellm.proxy.policy_engine.policy_resolver import PolicyResolver + + policy_registry = get_policy_registry() + attachment_registry = get_attachment_registry() + + if not policy_registry.is_initialized(): + return {"initialized": False, "policies": {}, "attachments": []} + + resolved = PolicyResolver.get_all_resolved_policies() + + summary: Dict[str, Any] = { + "initialized": True, + "policy_count": len(resolved), + "attachment_count": len(attachment_registry.get_all_attachments()), + "policies": {}, + "attachments": [], + } + + for policy_name, resolved_policy in resolved.items(): + policy = policy_registry.get_policy(policy_name) + summary["policies"][policy_name] = { + "inherit": policy.inherit if policy else None, + "description": policy.description if policy else None, + "guardrails_add": policy.guardrails.get_add() if policy else [], + "guardrails_remove": policy.guardrails.get_remove() if policy else [], + "condition": policy.condition.model_dump() if policy and policy.condition else None, + "resolved_guardrails": resolved_policy.guardrails, + "inheritance_chain": resolved_policy.inheritance_chain, + } + + # Add attachment info + for attachment in 
attachment_registry.get_all_attachments(): + summary["attachments"].append({ + "policy": attachment.policy, + "scope": attachment.scope, + "teams": attachment.teams, + "keys": attachment.keys, + "models": attachment.models, + }) + + return summary diff --git a/litellm/proxy/policy_engine/policy_matcher.py b/litellm/proxy/policy_engine/policy_matcher.py new file mode 100644 index 00000000000..ab73970bfab --- /dev/null +++ b/litellm/proxy/policy_engine/policy_matcher.py @@ -0,0 +1,168 @@ +""" +Policy Matcher - Matches requests against policy attachments. + +Uses existing wildcard pattern matching helpers to determine which policies +apply to a given request based on team alias, key alias, and model. + +Policies are matched via policy_attachments which define WHERE each policy applies. +""" + +from typing import Dict, List, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.auth.route_checks import RouteChecks +from litellm.types.proxy.policy_engine import Policy, PolicyMatchContext, PolicyScope + + +class PolicyMatcher: + """ + Matches incoming requests against policy attachments. + + Supports wildcard patterns: + - "*" matches everything + - "prefix-*" matches anything starting with "prefix-" + + Uses policy_attachments to determine which policies apply to a request. + """ + + @staticmethod + def matches_pattern(value: Optional[str], patterns: List[str]) -> bool: + """ + Check if a value matches any of the given patterns. + + Uses the existing RouteChecks._route_matches_wildcard_pattern helper. 
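+
+        Example (illustrative; per the wildcard semantics above, "prefix-*"
+        matches anything starting with "prefix-", and a None value matches
+        only the "*" pattern):
+
+            PolicyMatcher.matches_pattern("finance-team", ["finance-*"])  # True
+            PolicyMatcher.matches_pattern(None, ["*"])                    # True
+            PolicyMatcher.matches_pattern(None, ["finance-*"])            # False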
+ + Args: + value: The value to check (e.g., team alias, key alias, model) + patterns: List of patterns to match against + + Returns: + True if value matches any pattern, False otherwise + """ + # If no value provided, only match if patterns include "*" + if value is None: + return "*" in patterns + + for pattern in patterns: + # Use existing wildcard pattern matching helper + if RouteChecks._route_matches_wildcard_pattern( + route=value, pattern=pattern + ): + return True + + return False + + @staticmethod + def scope_matches(scope: PolicyScope, context: PolicyMatchContext) -> bool: + """ + Check if a policy scope matches the given context. + + A scope matches if ALL of its fields match: + - teams matches context.team_alias + - keys matches context.key_alias + - models matches context.model + + Args: + scope: The policy scope to check + context: The request context + + Returns: + True if scope matches context, False otherwise + """ + # Check teams + if not PolicyMatcher.matches_pattern(context.team_alias, scope.get_teams()): + return False + + # Check keys + if not PolicyMatcher.matches_pattern(context.key_alias, scope.get_keys()): + return False + + # Check models + if not PolicyMatcher.matches_pattern(context.model, scope.get_models()): + return False + + return True + + @staticmethod + def get_matching_policies( + context: PolicyMatchContext, + ) -> List[str]: + """ + Get list of policy names that match the given context via attachments. 
+ + Args: + context: The request context to match against + + Returns: + List of policy names that match the context + """ + from litellm.proxy.policy_engine.attachment_registry import ( + get_attachment_registry, + ) + + registry = get_attachment_registry() + if not registry.is_initialized(): + verbose_proxy_logger.debug( + "AttachmentRegistry not initialized, returning empty list" + ) + return [] + + return registry.get_attached_policies(context) + + @staticmethod + def get_matching_policies_from_registry( + context: PolicyMatchContext, + ) -> List[str]: + """ + Get list of policy names that match the given context from the global registry. + + Args: + context: The request context to match against + + Returns: + List of policy names that match the context + """ + return PolicyMatcher.get_matching_policies(context=context) + + @staticmethod + def get_policies_with_matching_conditions( + policy_names: List[str], + context: PolicyMatchContext, + policies: Optional[Dict[str, Policy]] = None, + ) -> List[str]: + """ + Filter policies to only those whose conditions match the context. 
+ + A policy's condition matches if: + - The policy has no condition (condition is None), OR + - The policy's condition evaluates to True for the given context + + Args: + policy_names: List of policy names to filter + context: The request context to evaluate conditions against + policies: Dictionary of all policies (if None, uses global registry) + + Returns: + List of policy names whose conditions match the context + """ + from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + + if policies is None: + registry = get_policy_registry() + if not registry.is_initialized(): + return [] + policies = registry.get_all_policies() + + matching_policies = [] + for policy_name in policy_names: + policy = policies.get(policy_name) + if policy is None: + continue + # Policy matches if it has no condition OR condition evaluates to True + if policy.condition is None or ConditionEvaluator.evaluate( + policy.condition, context + ): + matching_policies.append(policy_name) + + return matching_policies diff --git a/litellm/proxy/policy_engine/policy_registry.py b/litellm/proxy/policy_engine/policy_registry.py new file mode 100644 index 00000000000..68485f92489 --- /dev/null +++ b/litellm/proxy/policy_engine/policy_registry.py @@ -0,0 +1,196 @@ +""" +Policy Registry - In-memory storage for policies. + +Handles storing, retrieving, and managing policies. + +Policies define WHAT guardrails to apply. WHERE they apply is defined +by policy_attachments (see AttachmentRegistry). +""" + +from typing import Any, Dict, List, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.types.proxy.policy_engine import ( + Policy, + PolicyCondition, + PolicyGuardrails, +) + + +class PolicyRegistry: + """ + In-memory registry for storing and managing policies. + + This is a singleton that holds all loaded policies and provides + methods to access them. 
+ + Policies define WHAT guardrails to apply: + - Base guardrails via guardrails.add/remove + - Inheritance via inherit field + - Conditional guardrails via condition.model + """ + + def __init__(self): + self._policies: Dict[str, Policy] = {} + self._initialized: bool = False + + def load_policies(self, policies_config: Dict[str, Any]) -> None: + """ + Load policies from a configuration dictionary. + + Args: + policies_config: Dictionary mapping policy names to policy definitions. + This is the raw config from the YAML file. + """ + self._policies = {} + + for policy_name, policy_data in policies_config.items(): + try: + policy = self._parse_policy(policy_name, policy_data) + self._policies[policy_name] = policy + verbose_proxy_logger.debug(f"Loaded policy: {policy_name}") + except Exception as e: + verbose_proxy_logger.error( + f"Error loading policy '{policy_name}': {str(e)}" + ) + raise ValueError(f"Invalid policy '{policy_name}': {str(e)}") from e + + self._initialized = True + verbose_proxy_logger.info(f"Loaded {len(self._policies)} policies") + + def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy: + """ + Parse a policy from raw configuration data. 
+ + Args: + policy_name: Name of the policy + policy_data: Raw policy configuration + + Returns: + Parsed Policy object + """ + # Parse guardrails + guardrails_data = policy_data.get("guardrails", {}) + if isinstance(guardrails_data, dict): + guardrails = PolicyGuardrails( + add=guardrails_data.get("add"), + remove=guardrails_data.get("remove"), + ) + else: + # Handle legacy format where guardrails might be a list + guardrails = PolicyGuardrails(add=guardrails_data if guardrails_data else None) + + # Parse condition (simple model-based condition) + condition = None + condition_data = policy_data.get("condition") + if condition_data: + condition = PolicyCondition(model=condition_data.get("model")) + + return Policy( + inherit=policy_data.get("inherit"), + description=policy_data.get("description"), + guardrails=guardrails, + condition=condition, + ) + + def get_policy(self, policy_name: str) -> Optional[Policy]: + """ + Get a policy by name. + + Args: + policy_name: Name of the policy to retrieve + + Returns: + Policy object if found, None otherwise + """ + return self._policies.get(policy_name) + + def get_all_policies(self) -> Dict[str, Policy]: + """ + Get all loaded policies. + + Returns: + Dictionary mapping policy names to Policy objects + """ + return self._policies.copy() + + def get_policy_names(self) -> List[str]: + """ + Get list of all policy names. + + Returns: + List of policy names + """ + return list(self._policies.keys()) + + def has_policy(self, policy_name: str) -> bool: + """ + Check if a policy exists. + + Args: + policy_name: Name of the policy to check + + Returns: + True if policy exists, False otherwise + """ + return policy_name in self._policies + + def is_initialized(self) -> bool: + """ + Check if the registry has been initialized with policies. + + Returns: + True if policies have been loaded, False otherwise + """ + return self._initialized + + def clear(self) -> None: + """ + Clear all policies from the registry. 
+ """ + self._policies = {} + self._initialized = False + + def add_policy(self, policy_name: str, policy: Policy) -> None: + """ + Add or update a single policy. + + Args: + policy_name: Name of the policy + policy: Policy object to add + """ + self._policies[policy_name] = policy + verbose_proxy_logger.debug(f"Added/updated policy: {policy_name}") + + def remove_policy(self, policy_name: str) -> bool: + """ + Remove a policy by name. + + Args: + policy_name: Name of the policy to remove + + Returns: + True if policy was removed, False if it didn't exist + """ + if policy_name in self._policies: + del self._policies[policy_name] + verbose_proxy_logger.debug(f"Removed policy: {policy_name}") + return True + return False + + +# Global singleton instance +_policy_registry: Optional[PolicyRegistry] = None + + +def get_policy_registry() -> PolicyRegistry: + """ + Get the global PolicyRegistry singleton. + + Returns: + The global PolicyRegistry instance + """ + global _policy_registry + if _policy_registry is None: + _policy_registry = PolicyRegistry() + return _policy_registry diff --git a/litellm/proxy/policy_engine/policy_resolver.py b/litellm/proxy/policy_engine/policy_resolver.py new file mode 100644 index 00000000000..cfdedc467d8 --- /dev/null +++ b/litellm/proxy/policy_engine/policy_resolver.py @@ -0,0 +1,227 @@ +""" +Policy Resolver - Resolves final guardrail list from policies. + +Handles: +- Inheritance chain resolution (inherit with add/remove) +- Applying add/remove guardrails +- Evaluating model conditions +- Combining guardrails from multiple matching policies +""" + +from typing import Dict, List, Optional, Set + +from litellm._logging import verbose_proxy_logger +from litellm.types.proxy.policy_engine import ( + Policy, + PolicyMatchContext, + ResolvedPolicy, +) + + +class PolicyResolver: + """ + Resolves the final list of guardrails from policies. 
+ + Handles: + - Inheritance chains with add/remove operations + - Model-based conditions + """ + + @staticmethod + def resolve_inheritance_chain( + policy_name: str, + policies: Dict[str, Policy], + visited: Optional[Set[str]] = None, + ) -> List[str]: + """ + Get the inheritance chain for a policy (from root to policy). + + Args: + policy_name: Name of the policy + policies: Dictionary of all policies + visited: Set of visited policies (for cycle detection) + + Returns: + List of policy names from root ancestor to the given policy + """ + if visited is None: + visited = set() + + if policy_name in visited: + verbose_proxy_logger.warning( + f"Circular inheritance detected for policy '{policy_name}'" + ) + return [] + + policy = policies.get(policy_name) + if policy is None: + return [] + + visited.add(policy_name) + + if policy.inherit: + parent_chain = PolicyResolver.resolve_inheritance_chain( + policy_name=policy.inherit, policies=policies, visited=visited + ) + return parent_chain + [policy_name] + + return [policy_name] + + @staticmethod + def resolve_policy_guardrails( + policy_name: str, + policies: Dict[str, Policy], + context: Optional[PolicyMatchContext] = None, + ) -> ResolvedPolicy: + """ + Resolve the final guardrails for a single policy, including inheritance. + + This method: + 1. Resolves the inheritance chain + 2. Applies add/remove from each policy in the chain + 3. 
Evaluates model conditions (if context provided) + + Args: + policy_name: Name of the policy to resolve + policies: Dictionary of all policies + context: Optional request context for evaluating conditions + + Returns: + ResolvedPolicy with final guardrails list + """ + from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator + + inheritance_chain = PolicyResolver.resolve_inheritance_chain( + policy_name=policy_name, policies=policies + ) + + # Start with empty set of guardrails + guardrails: Set[str] = set() + + # Apply each policy in the chain (from root to leaf) + for chain_policy_name in inheritance_chain: + policy = policies.get(chain_policy_name) + if policy is None: + continue + + # Check if policy condition matches (if context provided) + if context is not None and policy.condition is not None: + if not ConditionEvaluator.evaluate( + condition=policy.condition, + context=context, + ): + verbose_proxy_logger.debug( + f"Policy '{chain_policy_name}' condition did not match, skipping guardrails" + ) + continue + + # Add guardrails from guardrails.add + for guardrail in policy.guardrails.get_add(): + guardrails.add(guardrail) + + # Remove guardrails from guardrails.remove + for guardrail in policy.guardrails.get_remove(): + guardrails.discard(guardrail) + + return ResolvedPolicy( + policy_name=policy_name, + guardrails=list(guardrails), + inheritance_chain=inheritance_chain, + ) + + @staticmethod + def resolve_guardrails_for_context( + context: PolicyMatchContext, + policies: Optional[Dict[str, Policy]] = None, + ) -> List[str]: + """ + Resolve the final list of guardrails for a request context. + + This: + 1. Finds all policies that match the context via policy_attachments + 2. Resolves each policy's guardrails (including inheritance) + 3. Evaluates model conditions + 4. 
Combines all guardrails (union) + + Args: + context: The request context + policies: Dictionary of all policies (if None, uses global registry) + + Returns: + List of guardrail names to apply + """ + from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + + if policies is None: + registry = get_policy_registry() + if not registry.is_initialized(): + return [] + policies = registry.get_all_policies() + + # Get matching policies via attachments + matching_policy_names = PolicyMatcher.get_matching_policies(context=context) + + if not matching_policy_names: + verbose_proxy_logger.debug( + f"No policies match context: team_alias={context.team_alias}, " + f"key_alias={context.key_alias}, model={context.model}" + ) + return [] + + # Resolve each matching policy and combine guardrails + all_guardrails: Set[str] = set() + + for policy_name in matching_policy_names: + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name=policy_name, + policies=policies, + context=context, + ) + all_guardrails.update(resolved.guardrails) + verbose_proxy_logger.debug( + f"Policy '{policy_name}' contributes guardrails: {resolved.guardrails}" + ) + + result = list(all_guardrails) + verbose_proxy_logger.debug( + f"Final guardrails for context: {result}" + ) + + return result + + @staticmethod + def get_all_resolved_policies( + policies: Optional[Dict[str, Policy]] = None, + context: Optional[PolicyMatchContext] = None, + ) -> Dict[str, ResolvedPolicy]: + """ + Resolve all policies and return their final guardrails. + + Useful for debugging and displaying policy configurations. 
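The resolution order matters: each policy in the inheritance chain (root first) applies its `add` entries and then its `remove` entries, and the results from multiple matching policies are unioned. A self-contained sketch of the chain walk over hypothetical dict-based policies (the real code uses `Policy` objects):

```python
# Hypothetical policy data mirroring the docs examples
policies = {
    "global-baseline": {"inherit": None, "add": ["pii_masking", "prompt_injection"], "remove": []},
    "internal-dev": {"inherit": "global-baseline", "add": ["toxicity_filter"], "remove": ["pii_masking"]},
}

def chain(name):
    """Inheritance chain from root ancestor to the named policy."""
    p = policies[name]
    return (chain(p["inherit"]) if p["inherit"] else []) + [name]

def resolve(name):
    """Apply add/remove from each policy in the chain, root first."""
    result = set()
    for n in chain(name):
        result |= set(policies[n]["add"])
        result -= set(policies[n]["remove"])
    return sorted(result)

resolve("internal-dev")  # ['prompt_injection', 'toxicity_filter']
```

Note that a child's `remove` only affects guardrails already accumulated, so removing and then re-adding in a later chain entry would restore the guardrail.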
+ + Args: + policies: Dictionary of all policies (if None, uses global registry) + context: Optional context for evaluating conditions + + Returns: + Dictionary mapping policy names to ResolvedPolicy objects + """ + from litellm.proxy.policy_engine.policy_registry import get_policy_registry + + if policies is None: + registry = get_policy_registry() + if not registry.is_initialized(): + return {} + policies = registry.get_all_policies() + + resolved: Dict[str, ResolvedPolicy] = {} + + for policy_name in policies: + resolved[policy_name] = PolicyResolver.resolve_policy_guardrails( + policy_name=policy_name, + policies=policies, + context=context, + ) + + return resolved diff --git a/litellm/proxy/policy_engine/policy_validator.py b/litellm/proxy/policy_engine/policy_validator.py new file mode 100644 index 00000000000..47787655cba --- /dev/null +++ b/litellm/proxy/policy_engine/policy_validator.py @@ -0,0 +1,348 @@ +""" +Policy Validator - Validates policy configurations. + +Validates: +- Guardrail names exist in the guardrail registry +- Non-wildcard team aliases exist in the database +- Non-wildcard key aliases exist in the database +- Non-wildcard model names exist in the router or match a wildcard route +- Inheritance chains are valid (no cycles, parents exist) +""" + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set + +from litellm._logging import verbose_proxy_logger +from litellm.types.proxy.policy_engine import ( + Policy, + PolicyValidationError, + PolicyValidationErrorType, + PolicyValidationResponse, +) + +if TYPE_CHECKING: + from litellm.proxy.utils import PrismaClient + from litellm.router import Router + + +class PolicyValidator: + """ + Validates policy configurations against actual data. + """ + + def __init__( + self, + prisma_client: Optional["PrismaClient"] = None, + llm_router: Optional["Router"] = None, + ): + """ + Initialize the validator. 
+ + Args: + prisma_client: Optional Prisma client for database validation + llm_router: Optional LLM router for model validation + """ + self.prisma_client = prisma_client + self.llm_router = llm_router + + @staticmethod + def is_wildcard_pattern(pattern: str) -> bool: + """ + Check if a pattern contains wildcards. + + Args: + pattern: The pattern to check + + Returns: + True if the pattern contains wildcard characters + """ + return "*" in pattern or "?" in pattern + + def get_available_guardrails(self) -> Set[str]: + """ + Get set of available guardrail names from the guardrail registry. + + Returns: + Set of guardrail names + """ + try: + from litellm.proxy.guardrails.guardrail_registry import ( + IN_MEMORY_GUARDRAIL_HANDLER, + ) + + guardrails = IN_MEMORY_GUARDRAIL_HANDLER.list_in_memory_guardrails() + return {g.get("guardrail_name", "") for g in guardrails if g.get("guardrail_name")} + except Exception as e: + verbose_proxy_logger.warning( + f"Could not get guardrails from registry: {str(e)}" + ) + return set() + + async def check_team_alias_exists(self, team_alias: str) -> bool: + """ + Check if a specific team alias exists in the database. + + Args: + team_alias: The team alias to check + + Returns: + True if the team alias exists + """ + if self.prisma_client is None: + return True # Can't validate without DB, assume valid + + try: + team = await self.prisma_client.db.litellm_teamtable.find_first( + where={"team_alias": team_alias}, + ) + return team is not None + except Exception as e: + verbose_proxy_logger.warning( + f"Could not check team alias '{team_alias}': {str(e)}" + ) + return True # Assume valid on error + + async def check_key_alias_exists(self, key_alias: str) -> bool: + """ + Check if a specific key alias exists in the database. 
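`is_wildcard_pattern` distinguishes exact aliases (validated against the database) from shell-style patterns containing `*` or `?`. The attachment matcher itself (`PolicyMatcher`) is not part of this hunk, so the following is only an illustrative sketch of the matching semantics described in the scope docs (empty scope defaults to match-all), using the stdlib `fnmatch`:

```python
from fnmatch import fnmatch
from typing import List, Optional

def scope_matches(patterns: Optional[List[str]], value: Optional[str]) -> bool:
    """True if value matches any pattern; a missing/empty scope means match-all."""
    patterns = patterns or ["*"]
    return value is not None and any(fnmatch(value, p) for p in patterns)

scope_matches(["dev-key-*", "test-key-*"], "dev-key-42")  # True
scope_matches(["healthcare-team"], "finance")             # False
```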
+ + Args: + key_alias: The key alias to check + + Returns: + True if the key alias exists + """ + if self.prisma_client is None: + return True # Can't validate without DB, assume valid + + try: + key = await self.prisma_client.db.litellm_verificationtoken.find_first( + where={"key_alias": key_alias}, + ) + return key is not None + except Exception as e: + verbose_proxy_logger.warning( + f"Could not check key alias '{key_alias}': {str(e)}" + ) + return True # Assume valid on error + + def check_model_exists(self, model: str) -> bool: + """ + Check if a model exists in the router or matches a wildcard pattern. + + Args: + model: The model name to check + + Returns: + True if the model exists or matches a pattern in the router + """ + if self.llm_router is None: + return True # Can't validate without router, assume valid + + try: + # Check if model is in router's model names + if model in self.llm_router.model_names: + return True + + # Check if model matches any pattern via pattern router + if hasattr(self.llm_router, "pattern_router"): + pattern_deployments = self.llm_router.pattern_router.get_deployments_by_pattern( + model=model + ) + if pattern_deployments: + return True + + return False + except Exception as e: + verbose_proxy_logger.warning( + f"Could not check model '{model}': {str(e)}" + ) + return True # Assume valid on error + + def _validate_inheritance_chain( + self, + policy_name: str, + policies: Dict[str, Policy], + visited: Optional[Set[str]] = None, + max_depth: int = 100, + ) -> List[PolicyValidationError]: + """ + Validate the inheritance chain for a policy. 
+
+        Checks for:
+        - Parent policy exists
+        - No circular inheritance
+        - Max depth not exceeded
+
+        Args:
+            policy_name: Name of the policy to validate
+            policies: All policies
+            visited: Set of already visited policy names (for cycle detection)
+            max_depth: Maximum recursion depth to prevent infinite loops
+
+        Returns:
+            List of validation errors
+        """
+        errors: List[PolicyValidationError] = []
+
+        # Prevent infinite recursion
+        if max_depth <= 0:
+            errors.append(
+                PolicyValidationError(
+                    policy_name=policy_name,
+                    error_type=PolicyValidationErrorType.CIRCULAR_INHERITANCE,
+                    message="Inheritance chain too deep (exceeded max depth of 100)",
+                    field="inherit",
+                )
+            )
+            return errors
+
+        if visited is None:
+            visited = set()
+
+        if policy_name in visited:
+            # visited is a set, so sort it for a deterministic error message
+            errors.append(
+                PolicyValidationError(
+                    policy_name=policy_name,
+                    error_type=PolicyValidationErrorType.CIRCULAR_INHERITANCE,
+                    message=f"Circular inheritance detected: {' -> '.join(sorted(visited))} -> {policy_name}",
+                    field="inherit",
+                )
+            )
+            return errors
+
+        policy = policies.get(policy_name)
+        if policy is None:
+            return errors
+
+        if policy.inherit:
+            if policy.inherit not in policies:
+                errors.append(
+                    PolicyValidationError(
+                        policy_name=policy_name,
+                        error_type=PolicyValidationErrorType.INVALID_INHERITANCE,
+                        message=f"Parent policy '{policy.inherit}' not found",
+                        field="inherit",
+                        value=policy.inherit,
+                    )
+                )
+            else:
+                # Recursively check parent with decremented depth
+                visited.add(policy_name)
+                errors.extend(
+                    self._validate_inheritance_chain(
+                        policy.inherit, policies, visited, max_depth - 1
+                    )
+                )
+
+        return errors
+
+    async def validate_policies(
+        self,
+        policies: Dict[str, Policy],
+        validate_db: bool = True,
+    ) -> PolicyValidationResponse:
+        """
+        Validate a set of policies.
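The visited-set walk in `_validate_inheritance_chain` can be exercised in isolation. A simplified sketch (hypothetical dict-based policies, no error objects) that returns the offending chain when inheritance loops back on itself, or `None` when the chain is acyclic:

```python
def find_cycle(policies, name, visited=None):
    """Return the policy-name chain up to the repeated entry, or None if no cycle."""
    visited = visited or []
    if name in visited:
        return visited + [name]
    parent = policies.get(name, {}).get("inherit")
    return find_cycle(policies, parent, visited + [name]) if parent else None

find_cycle({"a": {"inherit": "b"}, "b": {"inherit": "a"}}, "a")  # ['a', 'b', 'a']
```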
+ + Args: + policies: Dictionary mapping policy names to Policy objects + validate_db: Whether to validate against database (teams, keys) + + Returns: + PolicyValidationResponse with errors and warnings + """ + errors: List[PolicyValidationError] = [] + warnings: List[PolicyValidationError] = [] + + # Get available guardrails + available_guardrails = self.get_available_guardrails() + + for policy_name, policy in policies.items(): + # Validate guardrails + for guardrail in policy.guardrails.get_add(): + if available_guardrails and guardrail not in available_guardrails: + errors.append( + PolicyValidationError( + policy_name=policy_name, + error_type=PolicyValidationErrorType.INVALID_GUARDRAIL, + message=f"Guardrail '{guardrail}' not found in guardrail registry", + field="guardrails.add", + value=guardrail, + ) + ) + + for guardrail in policy.guardrails.get_remove(): + if available_guardrails and guardrail not in available_guardrails: + warnings.append( + PolicyValidationError( + policy_name=policy_name, + error_type=PolicyValidationErrorType.INVALID_GUARDRAIL, + message=f"Guardrail '{guardrail}' in remove list not found in guardrail registry", + field="guardrails.remove", + value=guardrail, + ) + ) + + # Note: Team, key, and model validation is done via policy_attachments + # Policies no longer have scope - attachments define where policies apply + + # Validate inheritance + inheritance_errors = self._validate_inheritance_chain( + policy_name=policy_name, policies=policies + ) + errors.extend(inheritance_errors) + + return PolicyValidationResponse( + valid=len(errors) == 0, + errors=errors, + warnings=warnings, + ) + + async def validate_policy_config( + self, + policy_config: Dict[str, Any], + validate_db: bool = True, + ) -> PolicyValidationResponse: + """ + Validate a raw policy configuration dictionary. + + This parses the config and then validates it. 
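The parse-then-validate shape (parse every entry, collect per-entry errors, bail out before semantic validation if anything failed to parse) can be sketched generically; the names below are hypothetical, not the actual `PolicyRegistry` API:

```python
from typing import Any, Callable, Dict, List, Tuple

def parse_all(
    config: Dict[str, Any],
    parse_one: Callable[[str, Any], dict],
) -> Tuple[Dict[str, dict], List[Tuple[str, str]]]:
    """Parse every entry, collecting per-entry errors instead of failing fast."""
    parsed: Dict[str, dict] = {}
    errors: List[Tuple[str, str]] = []
    for name, raw in config.items():
        try:
            parsed[name] = parse_one(name, raw)
        except Exception as e:
            errors.append((name, str(e)))
    return parsed, errors

def parse_one(name: str, raw: Any) -> dict:
    if not isinstance(raw, dict):
        raise ValueError("policy body must be a mapping")
    return raw

parse_all({"ok": {}, "bad": "not-a-dict"}, parse_one)
# → ({'ok': {}}, [('bad', 'policy body must be a mapping')])
```

Collecting all parse errors (rather than raising on the first) lets the validation response report every broken policy in one pass.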
+ + Args: + policy_config: Raw policy configuration from YAML + validate_db: Whether to validate against database + + Returns: + PolicyValidationResponse with errors and warnings + """ + from litellm.proxy.policy_engine.policy_registry import PolicyRegistry + + # First, try to parse the policies + errors: List[PolicyValidationError] = [] + policies: Dict[str, Policy] = {} + + temp_registry = PolicyRegistry() + + for policy_name, policy_data in policy_config.items(): + try: + policy = temp_registry._parse_policy(policy_name, policy_data) + policies[policy_name] = policy + except Exception as e: + errors.append( + PolicyValidationError( + policy_name=policy_name, + error_type=PolicyValidationErrorType.INVALID_SYNTAX, + message=f"Failed to parse policy: {str(e)}", + ) + ) + + # If there were parsing errors, return early + if errors: + return PolicyValidationResponse( + valid=False, + errors=errors, + warnings=[], + ) + + # Validate the parsed policies + return await self.validate_policies(policies, validate_db=validate_db) diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 958ddbf613c..ea405c1dea4 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -1,42 +1,101 @@ model_list: - # Anthropic direct - - model_name: anthropic-claude + - model_name: "*" litellm_params: - model: anthropic/claude-sonnet-4-20250514 - api_key: os.environ/ANTHROPIC_API_KEY - - # Azure AI Anthropic - - model_name: azure-ai-claude + model: "*" + - model_name: "gpt-4" litellm_params: - model: azure_ai/claude-3-5-sonnet - api_base: https://krish-mh44t553-eastus2.services.ai.azure.com/ - api_key: os.environ/AZURE_ANTHROPIC_API_KEY - - # Azure AI Anthropic (alternate endpoint format) - - model_name: claude-4.5-haiku + model: "gpt-4" + api_key: os.environ/OPENAI_API_KEY + - model_name: "gpt-3.5-turbo" litellm_params: - model: anthropic/claude-haiku-4-5 - api_base: 
https://krish-mh44t553-eastus2.services.ai.azure.com/anthropic/v1/messages - api_version: "2023-06-01" - api_key: os.environ/AZURE_ANTHROPIC_API_KEY + model: "gpt-3.5-turbo" + api_key: os.environ/OPENAI_API_KEY +general_settings: + master_key: sk-1234 +# ─────────────────────────────────────────────── +# POLICIES - Define WHAT guardrails to apply +# ─────────────────────────────────────────────── +# +# Policies define guardrails with: +# - inherit: Inherit guardrails from another policy +# - description: Human-readable description +# - guardrails.add: Add guardrails (on top of inherited) +# - guardrails.remove: Remove guardrails (from inherited) +# - condition.model: Model pattern (exact or regex) for when guardrails apply +# +policies: + # Global baseline policy + global-baseline: + description: "Base guardrails for all requests" + guardrails: + add: + - pii_blocker -# Search Tools Configuration - Define search providers for WebSearch interception -# search_tools: -# - search_tool_name: "my-perplexity-search" -# litellm_params: -# search_provider: "perplexity" # Can be: perplexity, brave, etc. 
+ # Healthcare policy - inherits from global-baseline + healthcare-compliance: + inherit: global-baseline + description: "HIPAA compliance for healthcare teams" + guardrails: + add: + - hipaa_audit -litellm_settings: - callbacks: ["websearch_interception"] - # WebSearch Interception - Automatically intercepts and executes WebSearch tool calls - # for models that don't natively support web search (e.g., Bedrock/Claude) - websearch_interception_params: - enabled_providers: ["bedrock"] # List of providers to enable interception for - search_tool_name: "my-perplexity-search" # Optional: Name of search tool from search_tools config + # Dev policy - inherits but removes PII blocker for testing + internal-dev: + inherit: global-baseline + description: "Relaxed policy for internal development" + guardrails: + add: + - toxicity_filter + remove: + - pii_blocker -general_settings: - store_prompts_in_spend_logs: true - forward_client_headers_to_llm_api: true + # Policy with model condition (regex pattern) + gpt4-safety: + description: "Extra safety for GPT-4 models" + guardrails: + add: + - toxicity_filter + condition: + model: "gpt-4.*" # regex: matches gpt-4, gpt-4-turbo, gpt-4o, etc. + + # Policy with model condition (exact match list) + bedrock-compliance: + description: "Compliance for Bedrock models" + guardrails: + add: + - strict_pii_blocker + condition: + model: ["bedrock/claude-3", "bedrock/claude-2"] # exact matches + +# ─────────────────────────────────────────────── +# POLICY ATTACHMENTS - Define WHERE policies apply +# ─────────────────────────────────────────────── +# +# Attachments are REQUIRED to make policies active. +# A policy without an attachment will not be applied. 
+# +policy_attachments: + # Global attachment - applies to all requests + - policy: global-baseline + scope: "*" + + # Team-specific attachment + - policy: healthcare-compliance + teams: + - healthcare-team + - medical-research + + # Key pattern attachment + - policy: internal-dev + keys: + - "dev-key-*" + - "test-key-*" + + # Model-specific policies (attached globally, condition filters by model) + - policy: gpt4-safety + scope: "*" + - policy: bedrock-compliance + scope: "*" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ed7c5f8c2f3..8500d820a4c 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -203,13 +203,13 @@ def generate_feedback_box(): from litellm.proxy.analytics_endpoints.analytics_endpoints import ( router as analytics_router, ) +from litellm.proxy.anthropic_endpoints.claude_code_endpoints import ( + claude_code_marketplace_router, +) from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router from litellm.proxy.anthropic_endpoints.skills_endpoints import ( router as anthropic_skills_router, ) -from litellm.proxy.anthropic_endpoints.claude_code_endpoints import ( - claude_code_marketplace_router, -) from litellm.proxy.auth.auth_checks import ( ExperimentalUIJWTToken, get_team_object, @@ -334,6 +334,7 @@ def generate_feedback_box(): from litellm.proxy.management_endpoints.organization_endpoints import ( router as organization_router, ) +from litellm.proxy.management_endpoints.policy_endpoints import router as policy_router from litellm.proxy.management_endpoints.router_settings_endpoints import ( router as router_settings_router, ) @@ -2771,6 +2772,13 @@ async def load_config( # noqa: PLR0915 llm_router=router, ) + # Policy Engine settings + await self._init_policy_engine( + config=config, + prisma_client=prisma_client, + llm_router=router, + ) + ## Prompt settings prompts: Optional[List[Dict]] = None if config is not None: @@ -2831,6 +2839,45 @@ async def 
_init_non_llm_configs(self, config: dict): ) pass + async def _init_policy_engine( + self, + config: Optional[dict], + prisma_client: Optional["PrismaClient"], + llm_router: Optional["Router"], + ): + """ + Initialize the policy engine from config. + + Args: + config: The proxy configuration dictionary + prisma_client: Optional Prisma client for DB validation + llm_router: Optional LLM router for model validation + """ + + from litellm.proxy.policy_engine.init_policies import init_policies + from litellm.proxy.policy_engine.policy_validator import PolicyValidator + if config is None: + verbose_proxy_logger.debug("Policy engine: config is None, skipping") + return + + policies_config = config.get("policies", None) + if not policies_config: + verbose_proxy_logger.debug("Policy engine: no policies in config, skipping") + return + + policy_attachments_config = config.get("policy_attachments", None) + + verbose_proxy_logger.info(f"Policy engine: found {len(policies_config)} policies in config") + + # Initialize policies + await init_policies( + policies_config=policies_config, + policy_attachments_config=policy_attachments_config, + prisma_client=prisma_client, + validate_db=prisma_client is not None, + fail_on_error=True, + ) + def _load_alerting_settings(self, general_settings: dict): """ Initialize alerting settings @@ -10655,6 +10702,7 @@ async def get_routes(): app.include_router(caching_router) app.include_router(analytics_router) app.include_router(guardrails_router) +app.include_router(policy_router) app.include_router(search_tool_management_router) app.include_router(prompts_router) app.include_router(callback_management_endpoints_router) diff --git a/litellm/types/policy_engine.py b/litellm/types/policy_engine.py new file mode 100644 index 00000000000..d5eb7e2b140 --- /dev/null +++ b/litellm/types/policy_engine.py @@ -0,0 +1,36 @@ +""" +Type definitions for the LiteLLM Policy Engine. 
+ +This module re-exports types from litellm.types.proxy.policy_engine for backward compatibility. +The canonical location for these types is litellm/types/proxy/policy_engine/. +""" + +# Re-export all types from the new location +from litellm.types.proxy.policy_engine import ( # Policy types; Validation types; Resolver types + Policy, + PolicyConfig, + PolicyGuardrails, + PolicyMatchContext, + PolicyScope, + PolicyValidateRequest, + PolicyValidationError, + PolicyValidationErrorType, + PolicyValidationResponse, + ResolvedPolicy, +) + +__all__ = [ + # Policy types + "Policy", + "PolicyConfig", + "PolicyGuardrails", + "PolicyScope", + # Validation types + "PolicyValidateRequest", + "PolicyValidationError", + "PolicyValidationErrorType", + "PolicyValidationResponse", + # Resolver types + "PolicyMatchContext", + "ResolvedPolicy", +] diff --git a/litellm/types/proxy/policy_engine/__init__.py b/litellm/types/proxy/policy_engine/__init__.py new file mode 100644 index 00000000000..50ed4581013 --- /dev/null +++ b/litellm/types/proxy/policy_engine/__init__.py @@ -0,0 +1,61 @@ +""" +Type definitions for the LiteLLM Policy Engine. + +The Policy Engine allows administrators to define policies that combine guardrails +with scoping rules. Policies can target specific teams, API keys, and models using +wildcard patterns, and support inheritance from base policies. 
+ +Configuration: +- `policies`: Define WHAT guardrails to apply (with inheritance and conditions) +- `policy_attachments`: Define WHERE policies apply (teams, keys, models) +""" + +from litellm.types.proxy.policy_engine.policy_types import ( + Policy, + PolicyAttachment, + PolicyCondition, + PolicyConfig, + PolicyGuardrails, + PolicyScope, +) +from litellm.types.proxy.policy_engine.resolver_types import ( + PolicyGuardrailsResponse, + PolicyInfoResponse, + PolicyListResponse, + PolicyMatchContext, + PolicyScopeResponse, + PolicySummaryItem, + PolicyTestResponse, + ResolvedPolicy, +) +from litellm.types.proxy.policy_engine.validation_types import ( + PolicyValidateRequest, + PolicyValidationError, + PolicyValidationErrorType, + PolicyValidationResponse, +) + +__all__ = [ + # Policy types + "Policy", + "PolicyConfig", + "PolicyGuardrails", + "PolicyScope", + "PolicyCondition", + "PolicyAttachment", + # Validation types + "PolicyValidateRequest", + "PolicyValidationError", + "PolicyValidationErrorType", + "PolicyValidationResponse", + # Resolver types + "PolicyMatchContext", + "ResolvedPolicy", + # API Response types + "PolicyGuardrailsResponse", + "PolicyInfoResponse", + "PolicyListResponse", + "PolicyScopeResponse", + "PolicySummaryItem", + "PolicyTestResponse", +] diff --git a/litellm/types/proxy/policy_engine/policy_types.py b/litellm/types/proxy/policy_engine/policy_types.py new file mode 100644 index 00000000000..1c01f89e8b4 --- /dev/null +++ b/litellm/types/proxy/policy_engine/policy_types.py @@ -0,0 +1,299 @@ +""" +Core policy type definitions. 
+ +Policy Engine Configuration: +```yaml +policies: + global-baseline: + description: "Base guardrails for all requests" + guardrails: + add: [pii_blocker] + + healthcare-compliance: + inherit: global-baseline + guardrails: + add: [hipaa_audit] + condition: + model: "gpt-4" # exact match or regex pattern + +policy_attachments: + - policy: global-baseline + scope: "*" + - policy: healthcare-compliance + teams: [healthcare-team] +``` + +Key concepts: +- `policies`: Define WHAT guardrails to apply (with inheritance via `inherit` and `guardrails.add`/`remove`) +- `policy_attachments`: Define WHERE policies apply (teams, keys, models) +- `condition`: Optional model condition for when guardrails apply +""" + +from typing import Any, Dict, List, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Condition +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyCondition(BaseModel): + """ + Condition for when a policy's guardrails apply. + + Currently supports model-based conditions with exact match or regex. + + Example YAML: + ```yaml + condition: + model: "gpt-4" # exact match + model: "gpt-4.*" # regex pattern + model: ["gpt-4", "gpt-4-turbo"] # list of exact matches + ``` + """ + + model: Optional[Union[str, List[str]]] = Field( + default=None, + description="Model name(s) to match. Can be exact string, regex pattern, or list.", + ) + + model_config = ConfigDict(extra="forbid") + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Scope (used internally by attachments) +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyScope(BaseModel): + """ + Defines the scope for matching requests. + + Used internally by PolicyAttachment to define WHERE a policy applies. 
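`PolicyCondition.model` accepts an exact name, a regex pattern, or a list of exact names. The evaluator (`ConditionEvaluator`) is not part of this hunk, so the following is an illustrative sketch of those semantics, not the shipped implementation:

```python
import re
from typing import List, Optional, Union

def model_condition_matches(condition: Optional[Union[str, List[str]]], model: str) -> bool:
    """No condition matches everything; a list is exact-match only; a string
    matches exactly or as a full regex match."""
    if condition is None:
        return True
    if isinstance(condition, list):
        return model in condition
    return model == condition or re.fullmatch(condition, model) is not None

model_condition_matches("gpt-4.*", "gpt-4-turbo")       # True
model_condition_matches(["bedrock/claude-3"], "gpt-4")  # False
```

Because a plain string is tried both as an exact name and as a regex, a pattern like `"gpt-4.*"` also matches the bare `gpt-4` (`.*` matches the empty string).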
+ + Scope Fields: + | Field | What it matches | Wildcard support | + |--------|-----------------|----------------------| + | teams | Team aliases | *, healthcare-* | + | keys | Key aliases | *, dev-key-* | + | models | Model names | *, bedrock/*, gpt-* | + + If a field is None or empty, it defaults to matching everything (["*"]). + A request must match ALL specified scope fields for the attachment to apply. + """ + + teams: Optional[List[str]] = Field( + default=None, + description="Team aliases or wildcard patterns. Use '*' for all teams.", + ) + keys: Optional[List[str]] = Field( + default=None, + description="Key aliases or wildcard patterns. Use '*' for all keys.", + ) + models: Optional[List[str]] = Field( + default=None, + description="Model names or wildcard patterns. Use '*' for all models.", + ) + + model_config = ConfigDict(extra="forbid") + + def get_teams(self) -> List[str]: + """Returns teams list, defaulting to ['*'] if not specified.""" + return self.teams if self.teams else ["*"] + + def get_keys(self) -> List[str]: + """Returns keys list, defaulting to ['*'] if not specified.""" + return self.keys if self.keys else ["*"] + + def get_models(self) -> List[str]: + """Returns models list, defaulting to ['*'] if not specified.""" + return self.models if self.models else ["*"] + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Guardrails +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyGuardrails(BaseModel): + """ + Defines guardrails to add or remove in a policy. 
+ + - `add`: List of guardrail names to add (on top of inherited guardrails) + - `remove`: List of guardrail names to remove (from inherited guardrails) + + This supports the inheritance pattern where child policies can: + - Add new guardrails on top of parent's guardrails + - Remove specific guardrails inherited from parent + """ + + add: Optional[List[str]] = Field( + default=None, + description="Guardrail names to add to this policy.", + ) + remove: Optional[List[str]] = Field( + default=None, + description="Guardrail names to remove (typically from inherited policy).", + ) + + model_config = ConfigDict(extra="forbid") + + def get_add(self) -> List[str]: + """Returns add list, defaulting to empty list if not specified.""" + return self.add if self.add else [] + + def get_remove(self) -> List[str]: + """Returns remove list, defaulting to empty list if not specified.""" + return self.remove if self.remove else [] + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy +# ───────────────────────────────────────────────────────────────────────────── + + +class Policy(BaseModel): + """ + A policy that defines WHAT guardrails to apply. + + Policies define guardrails but NOT where they apply - that's done via policy_attachments. + + Policies can inherit from other policies using the `inherit` field. + When inheriting: + - Guardrails from `guardrails.add` are added to the inherited guardrails + - Guardrails from `guardrails.remove` are removed from the inherited guardrails + + Policies can have a `condition` for model-based guardrail application. 
+ + Example configuration: + ```yaml + policies: + global-baseline: + description: "Base guardrails for all requests" + guardrails: + add: + - pii_blocker + - phi_blocker + + healthcare-compliance: + inherit: global-baseline + description: "HIPAA compliance for healthcare" + guardrails: + add: + - hipaa_audit + + gpt4-safety: + description: "Extra safety for GPT-4 models" + guardrails: + add: + - toxicity_filter + condition: + model: "gpt-4.*" # regex pattern + + policy_attachments: + - policy: global-baseline + scope: "*" + - policy: healthcare-compliance + teams: [healthcare-team] + - policy: gpt4-safety + scope: "*" + ``` + """ + + inherit: Optional[str] = Field( + default=None, + description="Name of the parent policy to inherit from.", + ) + description: Optional[str] = Field( + default=None, + description="Human-readable description of the policy.", + ) + guardrails: PolicyGuardrails = Field( + default_factory=PolicyGuardrails, + description="Guardrails configuration with add/remove lists.", + ) + condition: Optional[PolicyCondition] = Field( + default=None, + description="Optional condition for when this policy's guardrails apply.", + ) + + model_config = ConfigDict(extra="forbid") + + +# ───────────────────────────────────────────────────────────────────────────── +# Policy Attachments +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyAttachment(BaseModel): + """ + Attaches a policy to a scope - defines WHERE a policy applies. + + Attachments are REQUIRED to make policies active. A policy without + an attachment will not be applied to any requests. 
+ + Example YAML: + ```yaml + policy_attachments: + - policy: global-baseline + scope: "*" # applies to all requests + - policy: healthcare-compliance + teams: [healthcare-team, medical-research] + - policy: dev-safety + keys: ["dev-key-*", "test-key-*"] + - policy: gpt4-specific + models: ["gpt-4", "gpt-4-turbo"] + ``` + """ + + policy: str = Field( + description="Name of the policy to attach.", + ) + scope: Optional[str] = Field( + default=None, + description="Use '*' for global scope (applies to all requests).", + ) + teams: Optional[List[str]] = Field( + default=None, + description="Team aliases or patterns this attachment applies to.", + ) + keys: Optional[List[str]] = Field( + default=None, + description="Key aliases or patterns this attachment applies to.", + ) + models: Optional[List[str]] = Field( + default=None, + description="Model names or patterns this attachment applies to.", + ) + + model_config = ConfigDict(extra="forbid") + + def is_global(self) -> bool: + """Check if this is a global attachment (scope='*').""" + return self.scope == "*" + + def to_policy_scope(self) -> PolicyScope: + """Convert attachment to a PolicyScope for matching.""" + if self.is_global(): + return PolicyScope(teams=["*"], keys=["*"], models=["*"]) + return PolicyScope( + teams=self.teams, + keys=self.keys, + models=self.models, + ) + + +class PolicyConfig(BaseModel): + """ + Root configuration for all policies. + + Maps policy names to their Policy definitions. + """ + + policies: Dict[str, Policy] = Field( + default_factory=dict, + description="Map of policy names to Policy objects.", + ) + + model_config = ConfigDict(extra="forbid") diff --git a/litellm/types/proxy/policy_engine/resolver_types.py b/litellm/types/proxy/policy_engine/resolver_types.py new file mode 100644 index 00000000000..81ae248d436 --- /dev/null +++ b/litellm/types/proxy/policy_engine/resolver_types.py @@ -0,0 +1,110 @@ +""" +Policy resolver type definitions. 
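`to_policy_scope` normalizes an attachment into a uniform scope: a global attachment (`scope: "*"`) expands to wildcard lists, and omitted `teams`/`keys`/`models` default to match-all at lookup time. A dict-based sketch of the same normalization (eagerly applying the `["*"]` defaults that the real `PolicyScope` getters apply lazily):

```python
from typing import Any, Dict, List

def to_scope(attachment: Dict[str, Any]) -> Dict[str, List[str]]:
    """Normalize a policy attachment into explicit teams/keys/models lists."""
    if attachment.get("scope") == "*":
        return {"teams": ["*"], "keys": ["*"], "models": ["*"]}
    return {
        "teams": attachment.get("teams") or ["*"],
        "keys": attachment.get("keys") or ["*"],
        "models": attachment.get("models") or ["*"],
    }

to_scope({"policy": "global-baseline", "scope": "*"})
# {'teams': ['*'], 'keys': ['*'], 'models': ['*']}
to_scope({"policy": "healthcare-compliance", "teams": ["healthcare-team"]})
# {'teams': ['healthcare-team'], 'keys': ['*'], 'models': ['*']}
```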
+ +These types are used for matching requests to policies and resolving +the final guardrails list. +""" + +from typing import Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class PolicyMatchContext(BaseModel): + """ + Context used to match a request against policies. + + Contains the team alias, key alias, and model from the incoming request. + """ + + team_alias: Optional[str] = Field( + default=None, + description="Team alias from the request.", + ) + key_alias: Optional[str] = Field( + default=None, + description="API key alias from the request.", + ) + model: Optional[str] = Field( + default=None, + description="Model name from the request.", + ) + + model_config = ConfigDict(extra="forbid") + + +class ResolvedPolicy(BaseModel): + """ + Result of resolving a policy with its inheritance chain. + + Contains the final list of guardrails after applying all add/remove operations. + """ + + policy_name: str = Field(description="Name of the resolved policy.") + guardrails: List[str] = Field( + default_factory=list, + description="Final list of guardrail names to apply.", + ) + inheritance_chain: List[str] = Field( + default_factory=list, + description="List of policy names in the inheritance chain (from root to this policy).", + ) + + model_config = ConfigDict(extra="forbid") + + +# ───────────────────────────────────────────────────────────────────────────── +# API Response Types +# ───────────────────────────────────────────────────────────────────────────── + + +class PolicyScopeResponse(BaseModel): + """Scope configuration for a policy.""" + + teams: List[str] = Field(default_factory=list) + keys: List[str] = Field(default_factory=list) + models: List[str] = Field(default_factory=list) + + +class PolicyGuardrailsResponse(BaseModel): + """Guardrails configuration for a policy.""" + + add: List[str] = Field(default_factory=list) + remove: List[str] = Field(default_factory=list) + + +class PolicyInfoResponse(BaseModel): + """Response 
for /policy/info/{policy_name} endpoint.""" + + policy_name: str + inherit: Optional[str] = None + scope: PolicyScopeResponse + guardrails: PolicyGuardrailsResponse + resolved_guardrails: List[str] + inheritance_chain: List[str] + + +class PolicySummaryItem(BaseModel): + """Summary of a single policy for list endpoint.""" + + inherit: Optional[str] = None + scope: PolicyScopeResponse + guardrails: PolicyGuardrailsResponse + resolved_guardrails: List[str] + inheritance_chain: List[str] + + +class PolicyListResponse(BaseModel): + """Response for /policy/list endpoint.""" + + policies: Dict[str, PolicySummaryItem] + total_count: int + + +class PolicyTestResponse(BaseModel): + """Response for /policy/test endpoint.""" + + context: PolicyMatchContext + matching_policies: List[str] + resolved_guardrails: List[str] + message: Optional[str] = None diff --git a/litellm/types/proxy/policy_engine/validation_types.py b/litellm/types/proxy/policy_engine/validation_types.py new file mode 100644 index 00000000000..e079febcc9e --- /dev/null +++ b/litellm/types/proxy/policy_engine/validation_types.py @@ -0,0 +1,80 @@ +""" +Policy validation type definitions. + +These types are used for validating policy configurations and returning +validation results. +""" + +from enum import Enum +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class PolicyValidationErrorType(str, Enum): + """Types of validation errors that can occur.""" + + INVALID_GUARDRAIL = "invalid_guardrail" + INVALID_TEAM = "invalid_team" + INVALID_KEY = "invalid_key" + INVALID_MODEL = "invalid_model" + INVALID_INHERITANCE = "invalid_inheritance" + CIRCULAR_INHERITANCE = "circular_inheritance" + INVALID_SCOPE = "invalid_scope" + INVALID_SYNTAX = "invalid_syntax" + + +class PolicyValidationError(BaseModel): + """ + Represents a validation error or warning for a policy. 
+ """ + + policy_name: str = Field(description="Name of the policy with the issue.") + error_type: PolicyValidationErrorType = Field( + description="Type of validation error." + ) + message: str = Field(description="Human-readable error message.") + field: Optional[str] = Field( + default=None, + description="Specific field that caused the error (e.g., 'guardrails.add', 'scope.teams').", + ) + value: Optional[str] = Field( + default=None, + description="The invalid value that caused the error.", + ) + + model_config = ConfigDict(extra="forbid") + + +class PolicyValidationResponse(BaseModel): + """ + Response from policy validation. + + - `valid`: True if no blocking errors were found + - `errors`: List of blocking errors (prevent policy from being applied) + - `warnings`: List of non-blocking warnings (policy can still be applied) + """ + + valid: bool = Field(description="True if the policy configuration is valid.") + errors: List[PolicyValidationError] = Field( + default_factory=list, + description="List of blocking validation errors.", + ) + warnings: List[PolicyValidationError] = Field( + default_factory=list, + description="List of non-blocking validation warnings.", + ) + + model_config = ConfigDict(extra="forbid") + + +class PolicyValidateRequest(BaseModel): + """ + Request body for the /policy/validate endpoint. + """ + + policies: Dict[str, Any] = Field( + description="Policy configuration to validate. Map of policy names to policy definitions." 
+ ) + + model_config = ConfigDict(extra="forbid") diff --git a/proxy_config.yaml b/proxy_config.yaml index 57397181cda..ea405c1dea4 100644 --- a/proxy_config.yaml +++ b/proxy_config.yaml @@ -2,6 +2,100 @@ model_list: - model_name: "*" litellm_params: model: "*" + - model_name: "gpt-4" + litellm_params: + model: "gpt-4" + api_key: os.environ/OPENAI_API_KEY + - model_name: "gpt-3.5-turbo" + litellm_params: + model: "gpt-3.5-turbo" + api_key: os.environ/OPENAI_API_KEY general_settings: master_key: sk-1234 + +# ─────────────────────────────────────────────── +# POLICIES - Define WHAT guardrails to apply +# ─────────────────────────────────────────────── +# +# Policies define guardrails with: +# - inherit: Inherit guardrails from another policy +# - description: Human-readable description +# - guardrails.add: Add guardrails (on top of inherited) +# - guardrails.remove: Remove guardrails (from inherited) +# - condition.model: Model pattern (exact or regex) for when guardrails apply +# +policies: + # Global baseline policy + global-baseline: + description: "Base guardrails for all requests" + guardrails: + add: + - pii_blocker + + # Healthcare policy - inherits from global-baseline + healthcare-compliance: + inherit: global-baseline + description: "HIPAA compliance for healthcare teams" + guardrails: + add: + - hipaa_audit + + # Dev policy - inherits but removes PII blocker for testing + internal-dev: + inherit: global-baseline + description: "Relaxed policy for internal development" + guardrails: + add: + - toxicity_filter + remove: + - pii_blocker + + # Policy with model condition (regex pattern) + gpt4-safety: + description: "Extra safety for GPT-4 models" + guardrails: + add: + - toxicity_filter + condition: + model: "gpt-4.*" # regex: matches gpt-4, gpt-4-turbo, gpt-4o, etc. 
+ + # Policy with model condition (exact match list) + bedrock-compliance: + description: "Compliance for Bedrock models" + guardrails: + add: + - strict_pii_blocker + condition: + model: ["bedrock/claude-3", "bedrock/claude-2"] # exact matches + +# ─────────────────────────────────────────────── +# POLICY ATTACHMENTS - Define WHERE policies apply +# ─────────────────────────────────────────────── +# +# Attachments are REQUIRED to make policies active. +# A policy without an attachment will not be applied. +# +policy_attachments: + # Global attachment - applies to all requests + - policy: global-baseline + scope: "*" + + # Team-specific attachment + - policy: healthcare-compliance + teams: + - healthcare-team + - medical-research + + # Key pattern attachment + - policy: internal-dev + keys: + - "dev-key-*" + - "test-key-*" + + # Model-specific policies (attached globally, condition filters by model) + - policy: gpt4-safety + scope: "*" + + - policy: bedrock-compliance + scope: "*" diff --git a/tests/code_coverage_tests/recursive_detector.py b/tests/code_coverage_tests/recursive_detector.py index 2a460f621b2..d5640f4256c 100644 --- a/tests/code_coverage_tests/recursive_detector.py +++ b/tests/code_coverage_tests/recursive_detector.py @@ -39,6 +39,7 @@ "_delete_nested_value_custom", # max depth set (bounded by number of path segments). "filter_exceptions_from_params", # max depth set (default 20) to prevent infinite recursion. "__getattr__", # lazy loading pattern in litellm/__init__.py with proper caching to prevent infinite recursion. + "_validate_inheritance_chain", # max depth set (default 100) to prevent infinite recursion in policy inheritance validation. 
] diff --git a/tests/test_litellm/proxy/policy_engine/__init__.py b/tests/test_litellm/proxy/policy_engine/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py b/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py new file mode 100644 index 00000000000..1ed956fe99f --- /dev/null +++ b/tests/test_litellm/proxy/policy_engine/test_attachment_registry.py @@ -0,0 +1,202 @@ +""" +Unit tests for AttachmentRegistry - tests policy attachment matching. + +Tests the main entry point: get_attached_policies() +""" + +import pytest + +from litellm.proxy.policy_engine.attachment_registry import ( + AttachmentRegistry, + get_attachment_registry, +) +from litellm.types.proxy.policy_engine import PolicyMatchContext + + +class TestGetAttachedPolicies: + """Test get_attached_policies - the main entry point.""" + + def test_global_scope_matches_all_requests(self): + """Test global scope (*) matches any request context.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "global-baseline", "scope": "*"}, + ]) + + # Should match any context + context = PolicyMatchContext( + team_alias="any-team", key_alias="any-key", model="any-model" + ) + attached = registry.get_attached_policies(context) + assert "global-baseline" in attached + + def test_team_specific_attachment(self): + """Test team-specific attachment matches only that team.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "healthcare-policy", "teams": ["healthcare-team"]}, + ]) + + # Match + context = PolicyMatchContext( + team_alias="healthcare-team", key_alias="key", model="gpt-4" + ) + assert "healthcare-policy" in registry.get_attached_policies(context) + + # No match - different team + context_other = PolicyMatchContext( + team_alias="finance-team", key_alias="key", model="gpt-4" + ) + assert "healthcare-policy" not in registry.get_attached_policies(context_other) 
+ + def test_key_wildcard_pattern_attachment(self): + """Test key pattern attachment with wildcard.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "dev-policy", "keys": ["dev-key-*"]}, + ]) + + # Match - key starts with dev-key- + context = PolicyMatchContext( + team_alias="team", key_alias="dev-key-123", model="gpt-4" + ) + assert "dev-policy" in registry.get_attached_policies(context) + + # No match - different prefix + context_prod = PolicyMatchContext( + team_alias="team", key_alias="prod-key-123", model="gpt-4" + ) + assert "dev-policy" not in registry.get_attached_policies(context_prod) + + def test_model_specific_attachment(self): + """Test model-specific attachment.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "gpt4-policy", "models": ["gpt-4", "gpt-4-turbo"]}, + ]) + + # Match + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4" + ) + assert "gpt4-policy" in registry.get_attached_policies(context) + + # No match + context_other = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-3.5" + ) + assert "gpt4-policy" not in registry.get_attached_policies(context_other) + + def test_model_wildcard_pattern(self): + """Test model wildcard pattern like bedrock/*.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "bedrock-policy", "models": ["bedrock/*"]}, + ]) + + # Match + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="bedrock/claude-3" + ) + assert "bedrock-policy" in registry.get_attached_policies(context) + + # No match + context_other = PolicyMatchContext( + team_alias="team", key_alias="key", model="openai/gpt-4" + ) + assert "bedrock-policy" not in registry.get_attached_policies(context_other) + + def test_multiple_attachments_match_same_context(self): + """Test multiple attachments can match the same context.""" + registry = AttachmentRegistry() + registry.load_attachments([ + 
{"policy": "global-baseline", "scope": "*"}, + {"policy": "healthcare-policy", "teams": ["healthcare-team"]}, + {"policy": "gpt4-policy", "models": ["gpt-4"]}, + ]) + + context = PolicyMatchContext( + team_alias="healthcare-team", key_alias="key", model="gpt-4" + ) + attached = registry.get_attached_policies(context) + + # All three should match + assert "global-baseline" in attached + assert "healthcare-policy" in attached + assert "gpt4-policy" in attached + assert len(attached) == 3 + + def test_same_policy_multiple_attachments_no_duplicates(self): + """Test same policy attached multiple ways doesn't duplicate.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "multi-policy", "scope": "*"}, + {"policy": "multi-policy", "teams": ["healthcare-team"]}, + ]) + + context = PolicyMatchContext( + team_alias="healthcare-team", key_alias="key", model="gpt-4" + ) + attached = registry.get_attached_policies(context) + + # Should only appear once + assert attached.count("multi-policy") == 1 + + def test_no_attachments_returns_empty(self): + """Test empty attachments returns empty list.""" + registry = AttachmentRegistry() + registry.load_attachments([]) + + context = PolicyMatchContext( + team_alias="team", key_alias="key", model="gpt-4" + ) + attached = registry.get_attached_policies(context) + assert attached == [] + + def test_no_matching_attachments_returns_empty(self): + """Test no matching attachments returns empty list.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "healthcare-policy", "teams": ["healthcare-team"]}, + ]) + + context = PolicyMatchContext( + team_alias="finance-team", key_alias="key", model="gpt-4" + ) + attached = registry.get_attached_policies(context) + assert attached == [] + + def test_combined_team_and_model_attachment(self): + """Test attachment with both team and model constraints.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "strict-policy", 
"teams": ["healthcare-team"], "models": ["gpt-4"]}, + ]) + + # Match - both team and model match + context = PolicyMatchContext( + team_alias="healthcare-team", key_alias="key", model="gpt-4" + ) + assert "strict-policy" in registry.get_attached_policies(context) + + # No match - team matches but model doesn't + context_wrong_model = PolicyMatchContext( + team_alias="healthcare-team", key_alias="key", model="gpt-3.5" + ) + assert "strict-policy" not in registry.get_attached_policies(context_wrong_model) + + # No match - model matches but team doesn't + context_wrong_team = PolicyMatchContext( + team_alias="finance-team", key_alias="key", model="gpt-4" + ) + assert "strict-policy" not in registry.get_attached_policies(context_wrong_team) + + +class TestAttachmentRegistrySingleton: + """Test global singleton behavior.""" + + def test_get_attachment_registry_returns_same_instance(self): + """Test get_attachment_registry returns same instance.""" + registry1 = get_attachment_registry() + registry2 = get_attachment_registry() + assert registry1 is registry2 diff --git a/tests/test_litellm/proxy/policy_engine/test_condition_evaluator.py b/tests/test_litellm/proxy/policy_engine/test_condition_evaluator.py new file mode 100644 index 00000000000..292f6e8f7da --- /dev/null +++ b/tests/test_litellm/proxy/policy_engine/test_condition_evaluator.py @@ -0,0 +1,113 @@ +""" +Unit tests for ConditionEvaluator - tests model condition evaluation. 
+ +Tests: +- Exact model match +- Regex pattern match +- List of models +""" + +import pytest + +from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator +from litellm.types.proxy.policy_engine import ( + PolicyCondition, + PolicyMatchContext, +) + + +class TestConditionEvaluator: + """Test condition evaluation.""" + + def test_no_condition_always_matches(self): + """Test that None condition always matches.""" + context = PolicyMatchContext(team_alias="team", key_alias="key", model="gpt-4") + assert ConditionEvaluator.evaluate(None, context) is True + + def test_exact_model_match(self): + """Test exact model string match.""" + condition = PolicyCondition(model="gpt-4") + + # Match + context = PolicyMatchContext(team_alias="team", key_alias="key", model="gpt-4") + assert ConditionEvaluator.evaluate(condition, context) is True + + # No match + context_other = PolicyMatchContext(team_alias="team", key_alias="key", model="gpt-3.5") + assert ConditionEvaluator.evaluate(condition, context_other) is False + + def test_regex_pattern_match(self): + """Test regex pattern matching.""" + condition = PolicyCondition(model="gpt-4.*") + + # Matches + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-4") + ) is True + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-4-turbo") + ) is True + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-4o") + ) is True + + # No match + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-3.5") + ) is False + + def test_list_of_models_match(self): + """Test list of model values.""" + condition = PolicyCondition(model=["gpt-4", "gpt-4-turbo", "claude-3"]) + + # Matches + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", 
model="gpt-4") + ) is True + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="claude-3") + ) is True + + # No match + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-3.5") + ) is False + + def test_list_with_regex_patterns(self): + """Test list can contain regex patterns.""" + condition = PolicyCondition(model=["gpt-4.*", "claude-.*"]) + + # Matches + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-4-turbo") + ) is True + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="claude-3") + ) is True + + # No match + assert ConditionEvaluator.evaluate( + condition, + PolicyMatchContext(team_alias="t", key_alias="k", model="llama-2") + ) is False + + def test_none_model_does_not_match(self): + """Test that None model value doesn't match conditions.""" + condition = PolicyCondition(model="gpt-4") + context = PolicyMatchContext(team_alias="t", key_alias="k", model=None) + assert ConditionEvaluator.evaluate(condition, context) is False + + def test_empty_condition_always_matches(self): + """Test condition with no model field always matches.""" + condition = PolicyCondition() # No model specified + context = PolicyMatchContext(team_alias="t", key_alias="k", model="any-model") + assert ConditionEvaluator.evaluate(condition, context) is True diff --git a/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py b/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py new file mode 100644 index 00000000000..c011f31af6a --- /dev/null +++ b/tests/test_litellm/proxy/policy_engine/test_policy_matcher.py @@ -0,0 +1,96 @@ +""" +Unit tests for PolicyMatcher - tests wildcard pattern matching via attachments. 
+ +Tests: +- Wildcard matching (*, prefix-*) +- Scope matching via attachments (teams, keys, models) +""" + +import pytest + +from litellm.proxy.policy_engine.attachment_registry import AttachmentRegistry +from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher +from litellm.types.proxy.policy_engine import ( + PolicyMatchContext, + PolicyScope, +) + + +class TestPolicyMatcherPatternMatching: + """Test pattern matching utilities.""" + + def test_matches_pattern_exact(self): + """Test exact pattern matching.""" + assert PolicyMatcher.matches_pattern("healthcare-team", ["healthcare-team"]) is True + assert PolicyMatcher.matches_pattern("finance-team", ["healthcare-team"]) is False + + def test_matches_pattern_wildcard(self): + """Test wildcard pattern matching.""" + assert PolicyMatcher.matches_pattern("any-team", ["*"]) is True + assert PolicyMatcher.matches_pattern("dev-key-123", ["dev-key-*"]) is True + assert PolicyMatcher.matches_pattern("prod-key-123", ["dev-key-*"]) is False + + def test_matches_pattern_none_value(self): + """Test None value only matches '*'.""" + assert PolicyMatcher.matches_pattern(None, ["*"]) is True + assert PolicyMatcher.matches_pattern(None, ["specific"]) is False + + +class TestPolicyMatcherScopeMatching: + """Test scope matching against context.""" + + def test_scope_matches_all_fields(self): + """Test scope matches when all fields match.""" + scope = PolicyScope(teams=["healthcare-team"], keys=["*"], models=["gpt-4"]) + context = PolicyMatchContext(team_alias="healthcare-team", key_alias="any-key", model="gpt-4") + assert PolicyMatcher.scope_matches(scope, context) is True + + def test_scope_does_not_match_team(self): + """Test scope doesn't match when team doesn't match.""" + scope = PolicyScope(teams=["healthcare-team"], keys=["*"], models=["*"]) + context = PolicyMatchContext(team_alias="finance-team", key_alias="any-key", model="gpt-4") + assert PolicyMatcher.scope_matches(scope, context) is False + + def 
test_scope_matches_with_wildcard_patterns(self): + """Test scope matches with wildcard patterns.""" + scope = PolicyScope(teams=["*"], keys=["dev-key-*"], models=["bedrock/*"]) + context = PolicyMatchContext(team_alias="any-team", key_alias="dev-key-123", model="bedrock/claude-3") + assert PolicyMatcher.scope_matches(scope, context) is True + + def test_scope_global_wildcard(self): + """Test global scope with all wildcards.""" + scope = PolicyScope(teams=["*"], keys=["*"], models=["*"]) + context = PolicyMatchContext(team_alias="any-team", key_alias="any-key", model="any-model") + assert PolicyMatcher.scope_matches(scope, context) is True + + +class TestPolicyMatcherWithAttachments: + """Test getting matching policies via attachments.""" + + def test_get_matching_policies_via_attachments(self): + """Test matching policies through attachment registry.""" + # Create and configure attachment registry + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "healthcare-policy", "teams": ["healthcare-team"]}, + {"policy": "global-policy", "scope": "*"}, + ]) + + # Test matching via the registry directly + context = PolicyMatchContext(team_alias="healthcare-team", key_alias="k", model="gpt-4") + attached = registry.get_attached_policies(context) + + assert "healthcare-policy" in attached + assert "global-policy" in attached + + def test_get_matching_policies_no_match(self): + """Test no policies match when attachments don't match context.""" + registry = AttachmentRegistry() + registry.load_attachments([ + {"policy": "healthcare-policy", "teams": ["healthcare-team"]}, + ]) + + context = PolicyMatchContext(team_alias="finance-team", key_alias="k", model="gpt-4") + attached = registry.get_attached_policies(context) + + assert "healthcare-policy" not in attached diff --git a/tests/test_litellm/proxy/policy_engine/test_policy_resolver.py b/tests/test_litellm/proxy/policy_engine/test_policy_resolver.py new file mode 100644 index 00000000000..9d672e018af 
--- /dev/null +++ b/tests/test_litellm/proxy/policy_engine/test_policy_resolver.py @@ -0,0 +1,193 @@ +""" +Unit tests for PolicyResolver - tests guardrail resolution. + +Tests: +- Inheritance chain resolution +- Inheritance with add/remove +- Model conditions +""" + +import pytest + +from litellm.proxy.policy_engine.policy_resolver import PolicyResolver +from litellm.types.proxy.policy_engine import ( + Policy, + PolicyCondition, + PolicyGuardrails, + PolicyMatchContext, +) + + +class TestPolicyResolverInheritance: + """Test resolve_policy_guardrails - inheritance and add/remove.""" + + def test_resolve_simple_policy(self): + """Test resolving guardrails for a simple policy.""" + policies = { + "global": Policy( + guardrails=PolicyGuardrails(add=["pii_blocker", "toxicity_filter"]), + ), + } + + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name="global", policies=policies + ) + + assert set(resolved.guardrails) == {"pii_blocker", "toxicity_filter"} + assert resolved.inheritance_chain == ["global"] + + def test_resolve_with_inheritance(self): + """Test child policy inherits and adds guardrails from parent.""" + policies = { + "base": Policy( + guardrails=PolicyGuardrails(add=["pii_blocker"]), + ), + "healthcare": Policy( + inherit="base", + guardrails=PolicyGuardrails(add=["hipaa_audit"]), + ), + } + + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name="healthcare", policies=policies + ) + + # Healthcare inherits pii_blocker from base and adds hipaa_audit + assert set(resolved.guardrails) == {"pii_blocker", "hipaa_audit"} + assert resolved.inheritance_chain == ["base", "healthcare"] + + def test_resolve_with_remove(self): + """Test child policy can remove guardrails from parent.""" + policies = { + "base": Policy( + guardrails=PolicyGuardrails(add=["pii_blocker", "phi_blocker"]), + ), + "dev": Policy( + inherit="base", + guardrails=PolicyGuardrails(add=["toxicity_filter"], remove=["phi_blocker"]), + ), + } + + resolved = 
PolicyResolver.resolve_policy_guardrails( + policy_name="dev", policies=policies + ) + + # dev inherits pii_blocker from base, adds toxicity_filter, removes phi_blocker + assert "pii_blocker" in resolved.guardrails + assert "toxicity_filter" in resolved.guardrails + assert "phi_blocker" not in resolved.guardrails + + def test_resolve_deep_inheritance_chain(self): + """Test multi-level inheritance chain.""" + policies = { + "root": Policy( + guardrails=PolicyGuardrails(add=["root_guardrail"]), + ), + "middle": Policy( + inherit="root", + guardrails=PolicyGuardrails(add=["middle_guardrail"]), + ), + "leaf": Policy( + inherit="middle", + guardrails=PolicyGuardrails(add=["leaf_guardrail"]), + ), + } + + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name="leaf", policies=policies + ) + + assert set(resolved.guardrails) == {"root_guardrail", "middle_guardrail", "leaf_guardrail"} + assert resolved.inheritance_chain == ["root", "middle", "leaf"] + + +class TestPolicyResolverWithConditions: + """Test resolve_policy_guardrails with model conditions.""" + + def test_condition_matches(self): + """Test guardrails are added when condition matches.""" + policies = { + "gpt4-policy": Policy( + guardrails=PolicyGuardrails(add=["toxicity_filter"]), + condition=PolicyCondition(model="gpt-4.*"), + ), + } + + # GPT-4 should get guardrails + context = PolicyMatchContext(team_alias="team", key_alias="k", model="gpt-4") + resolved = PolicyResolver.resolve_policy_guardrails( + policy_name="gpt4-policy", + policies=policies, + context=context, + ) + + assert "toxicity_filter" in resolved.guardrails + + def test_condition_does_not_match(self): + """Test guardrails are NOT added when condition doesn't match.""" + policies = { + "gpt4-policy": Policy( + guardrails=PolicyGuardrails(add=["toxicity_filter"]), + condition=PolicyCondition(model="gpt-4.*"), + ), + } + + # GPT-3.5 should NOT get guardrails + context = PolicyMatchContext(team_alias="team", key_alias="k", 
model="gpt-3.5")
+        resolved = PolicyResolver.resolve_policy_guardrails(
+            policy_name="gpt4-policy",
+            policies=policies,
+            context=context,
+        )
+
+        assert "toxicity_filter" not in resolved.guardrails
+
+    def test_no_condition_always_applies(self):
+        """Test policy without condition always applies."""
+        policies = {
+            "global": Policy(
+                guardrails=PolicyGuardrails(add=["pii_blocker"]),
+            ),
+        }
+
+        context = PolicyMatchContext(team_alias="any", key_alias="any", model="any")
+        resolved = PolicyResolver.resolve_policy_guardrails(
+            policy_name="global",
+            policies=policies,
+            context=context,
+        )
+
+        assert "pii_blocker" in resolved.guardrails
+
+    def test_inheritance_with_condition(self):
+        """Test inheritance works with conditions."""
+        policies = {
+            "base": Policy(
+                guardrails=PolicyGuardrails(add=["pii_blocker"]),
+            ),
+            "child": Policy(
+                inherit="base",
+                guardrails=PolicyGuardrails(add=["child_guardrail"]),
+                condition=PolicyCondition(model="gpt-4"),
+            ),
+        }
+
+        # GPT-4 should get both base and child guardrails
+        context_gpt4 = PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-4")
+        resolved_gpt4 = PolicyResolver.resolve_policy_guardrails(
+            policy_name="child",
+            policies=policies,
+            context=context_gpt4,
+        )
+        assert "pii_blocker" in resolved_gpt4.guardrails
+        assert "child_guardrail" in resolved_gpt4.guardrails
+
+        # GPT-3.5 should only get base guardrails (child condition doesn't match)
+        context_gpt35 = PolicyMatchContext(team_alias="t", key_alias="k", model="gpt-3.5")
+        resolved_gpt35 = PolicyResolver.resolve_policy_guardrails(
+            policy_name="child",
+            policies=policies,
+            context=context_gpt35,
+        )
+        assert "pii_blocker" in resolved_gpt35.guardrails
+        assert "child_guardrail" not in resolved_gpt35.guardrails
diff --git a/tests/test_litellm/proxy/policy_engine/test_policy_validator.py b/tests/test_litellm/proxy/policy_engine/test_policy_validator.py
new file mode 100644
index 00000000000..1dbdf5a3ddf
--- /dev/null
+++ b/tests/test_litellm/proxy/policy_engine/test_policy_validator.py
@@ -0,0 +1,85 @@
+"""
+Unit tests for PolicyValidator - tests policy configuration validation.
+
+Tests validation of:
+- Inheritance chains (parent exists, no circular deps)
+- Guardrail names exist in registry
+"""
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from litellm.proxy.policy_engine.policy_validator import PolicyValidator
+from litellm.types.proxy.policy_engine import (
+    Policy,
+    PolicyGuardrails,
+    PolicyValidationErrorType,
+)
+
+
+class TestPolicyValidator:
+    """Test policy validation logic."""
+
+    @pytest.mark.asyncio
+    async def test_validate_missing_parent_policy(self):
+        """Test that referencing non-existent parent policy fails."""
+        policies = {
+            "child": Policy(
+                inherit="nonexistent-parent",
+                guardrails=PolicyGuardrails(add=["hipaa_audit"]),
+            ),
+        }
+
+        validator = PolicyValidator(prisma_client=None)
+        result = await validator.validate_policies(policies=policies, validate_db=False)
+
+        assert result.valid is False
+        assert any(
+            e.error_type == PolicyValidationErrorType.INVALID_INHERITANCE
+            for e in result.errors
+        )
+
+    @pytest.mark.asyncio
+    async def test_validate_invalid_guardrail(self):
+        """Test that referencing non-existent guardrail fails."""
+        policies = {
+            "test-policy": Policy(
+                guardrails=PolicyGuardrails(add=["nonexistent_guardrail"]),
+            ),
+        }
+
+        validator = PolicyValidator(prisma_client=None)
+        with patch.object(
+            validator, "get_available_guardrails", return_value={"pii_blocker", "toxicity_filter"}
+        ):
+            result = await validator.validate_policies(policies=policies, validate_db=False)
+
+        assert result.valid is False
+        assert any(
+            e.error_type == PolicyValidationErrorType.INVALID_GUARDRAIL
+            and e.value == "nonexistent_guardrail"
+            for e in result.errors
+        )
+
+    @pytest.mark.asyncio
+    async def test_validate_valid_policy(self):
+        """Test that a valid policy passes validation."""
+        policies = {
+            "base": Policy(
+                guardrails=PolicyGuardrails(add=["pii_blocker"]),
+            ),
+            "child": Policy(
+                inherit="base",
+                guardrails=PolicyGuardrails(add=["toxicity_filter"]),
+            ),
+        }
+
+        validator = PolicyValidator(prisma_client=None)
+        with patch.object(
+            validator, "get_available_guardrails", return_value={"pii_blocker", "toxicity_filter"}
+        ):
+            result = await validator.validate_policies(policies=policies, validate_db=False)
+
+        assert result.valid is True
+        assert len(result.errors) == 0
diff --git a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py
index 133fc07d340..b9485a2e4cb 100644
--- a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py
+++ b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py
@@ -16,6 +16,7 @@
     _get_dynamic_logging_metadata,
     _get_enforced_params,
     _update_model_if_key_alias_exists,
+    add_guardrails_from_policy_engine,
     add_litellm_data_to_request,
     check_if_token_is_service_account,
 )
@@ -1477,3 +1478,73 @@ async def test_embedding_header_forwarding_without_model_group_config():
     finally:
         # Restore original model_group_settings
         litellm.model_group_settings = original_model_group_settings
+
+
+def test_add_guardrails_from_policy_engine():
+    """
+    Test that add_guardrails_from_policy_engine adds guardrails from matching policies
+    and tracks applied policies in metadata.
+    """
+    from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
+    from litellm.proxy.policy_engine.policy_registry import get_policy_registry
+    from litellm.types.proxy.policy_engine import (
+        Policy,
+        PolicyAttachment,
+        PolicyGuardrails,
+    )
+
+    # Setup test data
+    data = {
+        "model": "gpt-4",
+        "messages": [{"role": "user", "content": "Hello"}],
+        "metadata": {},
+    }
+
+    user_api_key_dict = UserAPIKeyAuth(
+        api_key="test-key",
+        team_alias="healthcare-team",
+        key_alias="my-key",
+    )
+
+    # Setup mock policies in the registry (policies define WHAT guardrails to apply)
+    policy_registry = get_policy_registry()
+    policy_registry._policies = {
+        "global-baseline": Policy(
+            guardrails=PolicyGuardrails(add=["pii_blocker"]),
+        ),
+        "healthcare": Policy(
+            guardrails=PolicyGuardrails(add=["hipaa_audit"]),
+        ),
+    }
+    policy_registry._initialized = True
+
+    # Setup attachments in the attachment registry (attachments define WHERE policies apply)
+    attachment_registry = get_attachment_registry()
+    attachment_registry._attachments = [
+        PolicyAttachment(policy="global-baseline", scope="*"),  # applies to all
+        PolicyAttachment(policy="healthcare", teams=["healthcare-team"]),  # applies to healthcare team
+    ]
+    attachment_registry._initialized = True
+
+    # Call the function
+    add_guardrails_from_policy_engine(
+        data=data,
+        metadata_variable_name="metadata",
+        user_api_key_dict=user_api_key_dict,
+    )
+
+    # Verify guardrails were added
+    assert "guardrails" in data["metadata"]
+    assert "pii_blocker" in data["metadata"]["guardrails"]
+    assert "hipaa_audit" in data["metadata"]["guardrails"]
+
+    # Verify applied policies were tracked
+    assert "applied_policies" in data["metadata"]
+    assert "global-baseline" in data["metadata"]["applied_policies"]
+    assert "healthcare" in data["metadata"]["applied_policies"]
+
+    # Clean up registries
+    policy_registry._policies = {}
+    policy_registry._initialized = False
+    attachment_registry._attachments = []
+    attachment_registry._initialized = False
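
The inheritance, add/remove, and condition semantics that these tests pin down can be sketched in a few lines. Everything below (`Policy`, `resolve`, `condition_model`) is an illustrative stand-in under assumed semantics, not LiteLLM's actual `policy_engine` API:

```python
# Minimal sketch of the resolution semantics the tests above assert.
# `Policy`, `resolve`, and `condition_model` are hypothetical stand-ins.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class Policy:
    add: list = field(default_factory=list)
    remove: list = field(default_factory=list)
    inherit: Optional[str] = None
    condition_model: Optional[str] = None  # gate this policy's own add/remove


def resolve(policy_name: str, policies: dict, model: str) -> set:
    policy = policies[policy_name]
    # Resolve the parent chain first, so parents apply before children.
    guardrails = resolve(policy.inherit, policies, model) if policy.inherit else set()
    # A non-matching condition skips this policy's own add/remove,
    # but guardrails inherited from the parent still apply.
    if policy.condition_model is None or policy.condition_model == model:
        guardrails |= set(policy.add)
        guardrails -= set(policy.remove)
    return guardrails


policies = {
    "base": Policy(add=["pii_blocker"]),
    "child": Policy(inherit="base", add=["child_guardrail"], condition_model="gpt-4"),
    "internal": Policy(inherit="base", remove=["pii_blocker"]),
}

print(sorted(resolve("child", policies, "gpt-4")))    # ['child_guardrail', 'pii_blocker']
print(sorted(resolve("child", policies, "gpt-3.5")))  # ['pii_blocker']
print(sorted(resolve("internal", policies, "any")))   # []
```

This mirrors `test_inheritance_with_condition`: a failed condition on the child suppresses only the child's additions, never the inherited baseline, while `remove` (as in the docs' `internal-team-policy` example) strips an inherited guardrail after the parent resolves.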