Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions litellm/proxy/common_request_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,11 +650,15 @@ async def base_process_llm_request(
)

tasks = []
# Start the moderation check (during_call_hook) as early as possible
# This gives it a head start to mask/validate input while the proxy handles routing
tasks.append(
proxy_logging_obj.during_call_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
call_type=route_type, # type: ignore
asyncio.create_task(
proxy_logging_obj.during_call_hook(
data=self.data,
user_api_key_dict=user_api_key_dict,
call_type=route_type, # type: ignore
)
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,15 @@ def __init__(
for pattern_config in normalized_patterns:
self._add_pattern(pattern_config)

# Warn if using during_call with MASK action (unstable)
if self.event_hook == GuardrailEventHooks.during_call and any(
p["action"] == ContentFilterAction.MASK for p in self.compiled_patterns
):
verbose_proxy_logger.warning(
f"ContentFilterGuardrail '{self.guardrail_name}': 'during_call' mode with 'MASK' action is unstable due to race conditions. "
"Use 'pre_call' mode for reliable request masking."
)

# Load blocked words - always initialize as dict
self.blocked_words: Dict[str, Tuple[ContentFilterAction, Optional[str]]] = {}
for word in normalized_blocked_words:
Expand Down Expand Up @@ -905,11 +914,15 @@ async def _process_images(
elif isinstance(e.detail, str):
e.detail = e.detail + " (Image description): " + description
else:
e.detail = "Content blocked: Image description detected" + description
e.detail = (
"Content blocked: Image description detected" + description
)
raise e

def _count_masked_entities(
self, detections: List[ContentFilterDetection], masked_entity_count: Dict[str, int]
self,
detections: List[ContentFilterDetection],
masked_entity_count: Dict[str, int],
) -> None:
"""
Count masked entities by type from detections.
Expand Down Expand Up @@ -964,9 +977,11 @@ def _log_guardrail_information(
dict(detection) for detection in detections
]
if status != "success":
guardrail_json_response = exception_str if exception_str else [
dict(detection) for detection in detections
]
guardrail_json_response = (
exception_str
if exception_str
else [dict(detection) for detection in detections]
)

self.add_standard_logging_guardrail_information_to_request_data(
guardrail_provider=self.guardrail_provider,
Expand Down Expand Up @@ -1066,99 +1081,84 @@ async def async_post_call_streaming_iterator_hook(
Process streaming response chunks and check for blocked content.

For BLOCK action: Raises HTTPException immediately when blocked content is detected.
For MASK action: Content passes through (masking streaming responses is not supported).
For MASK action: Content is buffered to handle patterns split across chunks.
"""
accumulated_full_text = ""
yielded_masked_text_len = 0
buffer_size = 50 # Increased buffer to catch patterns split across many chunks

# Accumulate content as we iterate through chunks
accumulated_content = ""
verbose_proxy_logger.info(
f"ContentFilterGuardrail: Starting robust streaming masking for model {request_data.get('model')}"
)

async for item in response:
# Accumulate content from this chunk before checking
if isinstance(item, ModelResponseStream) and item.choices:
delta_content = ""
is_final = False
for choice in item.choices:
if hasattr(choice, "delta") and choice.delta:
content = getattr(choice.delta, "content", None)
if content and isinstance(content, str):
accumulated_content += content

# Check accumulated content for blocked patterns/keywords after processing all choices
# Only check for BLOCK actions, not MASK (masking streaming is not supported)
if accumulated_content:
try:
# Check patterns
pattern_match = self._check_patterns(accumulated_content)
if pattern_match:
matched_text, pattern_name, action = pattern_match
if action == ContentFilterAction.BLOCK:
error_msg = (
f"Content blocked: {pattern_name} pattern detected"
)
verbose_proxy_logger.warning(error_msg)
raise HTTPException(
status_code=403,
detail={
"error": error_msg,
"pattern": pattern_name,
},
)

# Check blocked words
blocked_word_match = self._check_blocked_words(
accumulated_content
)
if blocked_word_match:
keyword, action, description = blocked_word_match
if action == ContentFilterAction.BLOCK:
error_msg = (
f"Content blocked: keyword '{keyword}' detected"
)
if description:
error_msg += f" ({description})"
verbose_proxy_logger.warning(error_msg)
raise HTTPException(
status_code=403,
detail={
"error": error_msg,
"keyword": keyword,
"description": description,
},
)

# Check category keywords
all_exceptions = []
for category in self.loaded_categories.values():
all_exceptions.extend(category.exceptions)
category_match = self._check_category_keywords(
accumulated_content, all_exceptions
)
if category_match:
keyword, category_name, severity, action = category_match
if action == ContentFilterAction.BLOCK:
error_msg = (
f"Content blocked: {category_name} category keyword '{keyword}' detected "
f"(severity: {severity})"
)
verbose_proxy_logger.warning(error_msg)
raise HTTPException(
status_code=403,
detail={
"error": error_msg,
"category": category_name,
"keyword": keyword,
"severity": severity,
},
)
except HTTPException:
# Re-raise HTTPException (blocked content detected)
raise
except Exception as e:
# Log other exceptions but don't block the stream
verbose_proxy_logger.warning(
f"Error checking content filter in streaming: {e}"
)
delta_content += content
if getattr(choice, "finish_reason", None):
is_final = True

accumulated_full_text += delta_content

# Check for blocking or apply masking
# Add a space at the end if it's the final chunk to trigger word boundaries (\b)
text_to_check = accumulated_full_text
if is_final:
text_to_check += " "

try:
masked_text = self._filter_single_text(text_to_check)
if is_final and masked_text.endswith(" "):
masked_text = masked_text[:-1]
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.error(
f"ContentFilterGuardrail: Error in masking: {e}"
)
masked_text = text_to_check # Fallback to current text

# Determine how much can be safely yielded
if is_final:
safe_to_yield_len = len(masked_text)
else:
safe_to_yield_len = max(0, len(masked_text) - buffer_size)

if safe_to_yield_len > yielded_masked_text_len:
new_masked_content = masked_text[
yielded_masked_text_len:safe_to_yield_len
]
# Modify the chunk to contain only the new masked content
if (
item.choices
and hasattr(item.choices[0], "delta")
and item.choices[0].delta
):
item.choices[0].delta.content = new_masked_content
yielded_masked_text_len = safe_to_yield_len
yield item
else:
# Hold content by yielding empty content chunk (keeps metadata/structure)
if (
item.choices
and hasattr(item.choices[0], "delta")
and item.choices[0].delta
):
item.choices[0].delta.content = ""
yield item
else:
# Not a ModelResponseStream or no choices - yield as is
yield item

# Yield the chunk (only if no exception was raised above)
yield item
# Any remaining content (should have been handled by is_final, but just in case)
if yielded_masked_text_len < len(accumulated_full_text):
# We already reached the end of the generator
pass

@staticmethod
def get_config_model():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
{
"name": "ipv6",
"display_name": "IP Address (IPv6)",
"pattern": "\\b(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}\\b",
"pattern": "(?<![0-9a-fA-F:])(?:(?:[0-9a-fA-F]{1,4}:){7}[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,7}:|:(?::[0-9a-fA-F]{1,4}){1,7}|::|(?:[0-9a-fA-F]{1,4}:){1,6}:[0-9a-fA-F]{1,4}|(?:[0-9a-fA-F]{1,4}:){1,5}(?::[0-9a-fA-F]{1,4}){1,2}|(?:[0-9a-fA-F]{1,4}:){1,4}(?::[0-9a-fA-F]{1,4}){1,3}|(?:[0-9a-fA-F]{1,4}:){1,3}(?::[0-9a-fA-F]{1,4}){1,4}|(?:[0-9a-fA-F]{1,4}:){1,2}(?::[0-9a-fA-F]{1,4}){1,5}|[0-9a-fA-F]{1,4}:(?::[0-9a-fA-F]{1,4}){1,6})(?![0-9a-fA-F:])",
"category": "Network Patterns",
"description": "Detects IPv6 addresses"
},
Expand All @@ -122,9 +122,9 @@
{
"name": "passport_us",
"display_name": "Passport (US)",
"pattern": "\\b[0-9]{9}\\b",
"pattern": "\\b([A-Z][0-9]{8}|[0-9]{9})\\b",
"category": "PII Patterns",
"description": "US passport numbers (9 digits)"
"description": "US passport numbers (9 digits, or 1 uppercase letter followed by 8 digits)"
},
{
"name": "passport_uk",
Expand Down Expand Up @@ -157,9 +157,9 @@
{
"name": "passport_canada",
"display_name": "Passport (Canada)",
"pattern": "\\b[A-Z]{2}[0-9]{6}\\b",
"pattern": "\\b([A-Z]{2}[0-9]{6}|[A-Z][0-9]{6}[A-Z]{2})\\b",
"category": "PII Patterns",
"description": "Canadian passport numbers (2 letters + 6 digits)"
"description": "Canadian passport numbers (old: 2 letters + 6 digits; new: 1 letter + 6 digits + 2 letters)"
},
{
"name": "passport_india",
Expand Down Expand Up @@ -369,4 +369,4 @@
"description": "Detects Brazilian RG identity card numbers (common pattern for SP, RJ, MG states)"
}
]
}
}
13 changes: 12 additions & 1 deletion litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1901,7 +1901,18 @@ async def async_post_call_streaming_iterator_hook(
) or _callback.should_run_guardrail(
data=request_data, event_type=GuardrailEventHooks.post_call
):
if "apply_guardrail" in type(callback).__dict__:
if (
"async_post_call_streaming_iterator_hook"
in type(callback).__dict__
):
current_response = (
_callback.async_post_call_streaming_iterator_hook(
user_api_key_dict=user_api_key_dict,
response=current_response,
request_data=request_data,
)
)
elif "apply_guardrail" in type(callback).__dict__:
request_data["guardrail_to_apply"] = callback
current_response = (
unified_guardrail.async_post_call_streaming_iterator_hook(
Expand Down
25 changes: 22 additions & 3 deletions litellm/types/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from litellm.types.proxy.guardrails.guardrail_hooks.qualifire import (
QualifireGuardrailConfigModel,
)
from litellm.types.proxy.guardrails.guardrail_hooks.litellm_content_filter import (
ContentFilterCategoryConfig,
)

"""
Pydantic object defining how to set guardrails on litellm proxy
Expand Down Expand Up @@ -547,9 +550,27 @@ class ContentFilterConfigModel(BaseModel):
blocked_words_file: Optional[str] = Field(
default=None, description="Path to YAML file containing blocked_words list"
)
categories: Optional[List[ContentFilterCategoryConfig]] = Field(
default=None,
description="List of prebuilt categories to enable (harmful_*, bias_*)",
)
severity_threshold: Optional[str] = Field(
default=None,
description="Minimum severity to block (high, medium, low)",
)
pattern_redaction_format: Optional[str] = Field(
default=None,
description="Format string for pattern redaction (use {pattern_name} placeholder)",
)
keyword_redaction_tag: Optional[str] = Field(
default=None,
description="Tag to use for keyword redaction",
)


class BaseLitellmParams(BaseModel): # works for new and patch update guardrails
class BaseLitellmParams(
ContentFilterConfigModel
): # works for new and patch update guardrails
api_key: Optional[str] = Field(
default=None, description="API key for the guardrail service"
)
Expand Down Expand Up @@ -630,7 +651,6 @@ class BaseLitellmParams(BaseModel): # works for new and patch update guardrails
description="Whether to fail the request if Model Armor encounters an error",
)

# Generic Guardrail API params
additional_provider_specific_params: Optional[Dict[str, Any]] = Field(
default=None,
description="Additional provider-specific parameters for generic guardrail APIs",
Expand All @@ -657,7 +677,6 @@ class LitellmParams(
ToolPermissionGuardrailConfigModel,
ZscalerAIGuardConfigModel,
JavelinGuardrailConfigModel,
ContentFilterConfigModel,
BaseLitellmParams,
EnkryptAIGuardrailConfigs,
IBMGuardrailsBaseConfigModel,
Expand Down
Loading
Loading