diff --git a/litellm/llms/anthropic/chat/guardrail_translation/handler.py b/litellm/llms/anthropic/chat/guardrail_translation/handler.py index 094b5842f07..b12dee1e3ae 100644 --- a/litellm/llms/anthropic/chat/guardrail_translation/handler.py +++ b/litellm/llms/anthropic/chat/guardrail_translation/handler.py @@ -253,20 +253,36 @@ async def process_output_response( task_mappings: List[Tuple[int, Optional[int]]] = [] # Track (content_index, None) for each text - response_content = response.get("content", []) + # Handle both dict and object responses + if hasattr(response, "get"): + response_content = response.get("content", []) + elif hasattr(response, "content"): + response_content = response.content or [] + else: + response_content = [] + if not response_content: return response # Step 1: Extract all text content and tool calls from response for content_idx, content_block in enumerate(response_content): - # Check if this is a text or tool_use block by checking the 'type' field - if isinstance(content_block, dict) and content_block.get("type") in [ - "text", - "tool_use", - ]: - # Cast to dict to handle the union type properly + # Handle both dict and Pydantic object content blocks + if isinstance(content_block, dict): + block_type = content_block.get("type") + block_dict = content_block + elif hasattr(content_block, "type"): + block_type = getattr(content_block, "type", None) + # Convert Pydantic object to dict for processing + if hasattr(content_block, "model_dump"): + block_dict = content_block.model_dump() + else: + block_dict = {"type": block_type, "text": getattr(content_block, "text", None)} + else: + continue + + if block_type in ["text", "tool_use"]: self._extract_output_text_and_images( - content_block=cast(Dict[str, Any], content_block), + content_block=cast(Dict[str, Any], block_dict), content_idx=content_idx, texts_to_check=texts_to_check, images_to_check=images_to_check, @@ -590,7 +606,14 @@ async def _apply_guardrail_responses_to_output( mapping = task_mappings[task_idx] content_idx = cast(int, mapping[0]) - response_content = response.get("content", []) + # Handle both dict and object responses + if hasattr(response, "get"): + response_content = response.get("content", []) + elif hasattr(response, "content"): + response_content = response.content or [] + else: + continue + if not response_content: continue @@ -601,7 +624,11 @@ async def _apply_guardrail_responses_to_output( content_block = response_content[content_idx] # Verify it's a text block and update the text field - if isinstance(content_block, dict) and content_block.get("type") == "text": - # Cast to dict to handle the union type properly for assignment - content_block = cast("AnthropicResponseTextBlock", content_block) - content_block["text"] = guardrail_response + # Handle both dict and Pydantic object content blocks + if isinstance(content_block, dict): + if content_block.get("type") == "text": + content_block["text"] = guardrail_response + elif hasattr(content_block, "type") and getattr(content_block, "type", None) == "text": + # Update Pydantic object's text attribute + if hasattr(content_block, "text"): + content_block.text = guardrail_response diff --git a/litellm/llms/openai/responses/guardrail_translation/handler.py b/litellm/llms/openai/responses/guardrail_translation/handler.py index 4480ec497c7..02899a67037 100644 --- a/litellm/llms/openai/responses/guardrail_translation/handler.py +++ b/litellm/llms/openai/responses/guardrail_translation/handler.py @@ -30,7 +30,6 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast -from openai.types.responses import ResponseFunctionToolCall from pydantic import BaseModel from litellm._logging import verbose_proxy_logger @@ -299,8 +298,25 @@ async def process_output_response( task_mappings: List[Tuple[int, int]] = [] # Track (output_item_index, content_index) for each text + # Handle both dict and Pydantic object responses + if isinstance(response, dict): + response_output = response.get("output", []) + elif hasattr(response, "output"): + response_output = response.output or [] + else: + verbose_proxy_logger.debug( + "OpenAI Responses API: No output found in response" + ) + return response + + if not response_output: + verbose_proxy_logger.debug( + "OpenAI Responses API: Empty output in response" + ) + return response + # Step 1: Extract all text content and tool calls from response output - for output_idx, output_item in enumerate(response.output): + for output_idx, output_item in enumerate(response_output): self._extract_output_text_and_images( output_item=output_item, output_idx=output_idx, @@ -538,13 +554,18 @@ def _extract_output_text_and_images( content: Optional[Union[List[OutputText], List[dict]]] = None if isinstance(output_item, BaseModel): try: + output_item_dump = output_item.model_dump() generic_response_output_item = GenericResponseOutputItem.model_validate( - output_item.model_dump() + output_item_dump ) if generic_response_output_item.content: content = generic_response_output_item.content except Exception: - return + # Try to extract content directly from output_item if validation fails + if hasattr(output_item, "content") and output_item.content: + content = output_item.content + else: + return elif isinstance(output_item, dict): content = output_item.get("content", []) else: @@ -582,22 +603,53 @@ async def _apply_guardrail_responses_to_output( Override this method to customize how responses are applied. """ + # Handle both dict and Pydantic object responses + if isinstance(response, dict): + response_output = response.get("output", []) + elif hasattr(response, "output"): + response_output = response.output or [] + else: + return + for task_idx, guardrail_response in enumerate(responses): mapping = task_mappings[task_idx] output_idx = cast(int, mapping[0]) content_idx = cast(int, mapping[1]) - output_item = response.output[output_idx] + if output_idx >= len(response_output): + continue + + output_item = response_output[output_idx] - # Handle both GenericResponseOutputItem and dict + # Handle both GenericResponseOutputItem, BaseModel, and dict if isinstance(output_item, GenericResponseOutputItem): - content_item = output_item.content[content_idx] - if isinstance(content_item, OutputText): - content_item.text = guardrail_response - elif isinstance(content_item, dict): - content_item["text"] = guardrail_response + if output_item.content and content_idx < len(output_item.content): + content_item = output_item.content[content_idx] + if isinstance(content_item, OutputText): + content_item.text = guardrail_response + elif isinstance(content_item, dict): + content_item["text"] = guardrail_response + elif isinstance(output_item, BaseModel): + # Handle other Pydantic models by converting to GenericResponseOutputItem + try: + generic_item = GenericResponseOutputItem.model_validate( + output_item.model_dump() + ) + if generic_item.content and content_idx < len(generic_item.content): + content_item = generic_item.content[content_idx] + if isinstance(content_item, OutputText): + content_item.text = guardrail_response + # Update the original response output + if hasattr(output_item, "content") and output_item.content: + original_content = output_item.content[content_idx] + if hasattr(original_content, "text"): + original_content.text = guardrail_response + except Exception: + pass elif isinstance(output_item, dict): content = output_item.get("content", []) if content and content_idx < len(content): if isinstance(content[content_idx], dict): content[content_idx]["text"] = guardrail_response + elif hasattr(content[content_idx], "text"): + content[content_idx].text = guardrail_response diff --git a/litellm/proxy/guardrails/guardrail_hooks/grayswan/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/grayswan/__init__.py index 389340014f8..99f58f654a7 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/grayswan/__init__.py +++ b/litellm/proxy/guardrails/guardrail_hooks/grayswan/__init__.py @@ -40,6 +40,12 @@ def initialize_guardrail( ), categories=_get_config_value(litellm_params, optional_params, "categories"), policy_id=_get_config_value(litellm_params, optional_params, "policy_id"), + streaming_end_of_stream_only=_get_config_value( + litellm_params, optional_params, "streaming_end_of_stream_only" + ) or False, + streaming_sampling_rate=_get_config_value( + litellm_params, optional_params, "streaming_sampling_rate" + ) or 5, event_hook=litellm_params.mode, default_on=litellm_params.default_on, ) diff --git a/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py b/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py index e1d91ee908d..4e9dd63d41a 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py +++ b/litellm/proxy/guardrails/guardrail_hooks/grayswan/grayswan.py @@ -1,26 +1,22 @@ """Gray Swan Cygnal guardrail integration.""" import os -from typing import Any, Dict, Literal, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from fastapi import HTTPException from litellm._logging import verbose_proxy_logger -from litellm.integrations.custom_guardrail import ( - CustomGuardrail, - log_guardrail_information, -) +from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.llms.custom_httpx.http_handler import ( get_async_httpx_client, httpxSpecialProvider, ) -from litellm.proxy._types import UserAPIKeyAuth -from litellm.proxy.common_utils.callback_utils import ( - add_guardrail_to_applied_guardrails_header, -) from litellm.types.guardrails import GuardrailEventHooks -from litellm.types.utils import Choices, LLMResponseTypes, ModelResponse +from litellm.types.utils import GenericGuardrailAPIInputs + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj class GraySwanGuardrailMissingSecrets(Exception): @@ -35,6 +31,15 @@ class GraySwanGuardrail(CustomGuardrail): """ Guardrail that calls Gray Swan's Cygnal monitoring endpoint. + Uses the unified guardrail system via `apply_guardrail` method, + which automatically works with all LiteLLM endpoints: + - OpenAI Chat Completions + - OpenAI Responses API + - OpenAI Text Completions + - Anthropic Messages + - Image Generation + - And more... + see: https://docs.grayswan.ai/cygnal/monitor-requests """ @@ -54,6 +59,8 @@ def __init__( reasoning_mode: Optional[str] = None, categories: Optional[Dict[str, str]] = None, policy_id: Optional[str] = None, + streaming_end_of_stream_only: bool = False, + streaming_sampling_rate: int = 5, **kwargs: Any, ) -> None: self.async_handler = get_async_httpx_client( @@ -88,6 +95,16 @@ def __init__( self.categories = categories self.policy_id = policy_id + # Streaming configuration + self.streaming_end_of_stream_only = streaming_end_of_stream_only + self.streaming_sampling_rate = streaming_sampling_rate + + verbose_proxy_logger.debug( + "GraySwan __init__: streaming_end_of_stream_only=%s, streaming_sampling_rate=%s", + streaming_end_of_stream_only, + streaming_sampling_rate, + ) + supported_event_hooks = [ GuardrailEventHooks.pre_call, GuardrailEventHooks.during_call, @@ -101,217 +118,107 @@ def __init__( ) # ------------------------------------------------------------------ - # Guardrail hook entry points + # Debug override to trace post_call issues # ------------------------------------------------------------------ - @log_guardrail_information - async def async_pre_call_hook( - self, - user_api_key_dict: UserAPIKeyAuth, - cache, - data: dict, - call_type: Literal[ - "completion", - "text_completion", - "embeddings", - "image_generation", - "moderation", - "audio_transcription", - "pass_through_endpoint", - "rerank", - "mcp_call", - "anthropic_messages", - ], - ) -> Optional[Union[Exception, str, dict]]: - if ( - self.should_run_guardrail( - data=data, event_type=GuardrailEventHooks.pre_call - ) - is not True - ): - return data + def should_run_guardrail(self, data, event_type) -> bool: + """Override to add debug logging.""" + result = super().should_run_guardrail(data, event_type) + # Check if apply_guardrail is in __dict__ + has_apply_guardrail = "apply_guardrail" in type(self).__dict__ + verbose_proxy_logger.debug( + "GraySwan DEBUG: should_run_guardrail event_type=%s, result=%s, event_hook=%s, has_apply_guardrail=%s, class=%s", + event_type, + result, + self.event_hook, + has_apply_guardrail, + type(self).__name__, + ) + return result - verbose_proxy_logger.debug("Gray Swan Guardrail: pre-call hook triggered") + # ------------------------------------------------------------------ + # Unified Guardrail Interface (works with ALL endpoints automatically) + # ------------------------------------------------------------------ - messages = data.get("messages") - if not messages: - verbose_proxy_logger.debug("Gray Swan Guardrail: No messages in data") - return data + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"], + logging_obj: Optional["LiteLLMLoggingObj"] = None, + ) -> GenericGuardrailAPIInputs: + """ + Apply Gray Swan guardrail to extracted text content. - dynamic_body = self.get_guardrail_dynamic_request_body_params(data) or {} + This method is called by the unified guardrail system which handles + extracting text from any request format (OpenAI, Anthropic, etc.). - payload = self._prepare_payload(messages, dynamic_body) - if payload is None: - verbose_proxy_logger.debug( - "Gray Swan Guardrail: no content to scan; skipping request" - ) - return data + Args: + inputs: Dictionary containing: + - texts: List of texts to scan + - images: Optional list of images (not currently used by GraySwan) + - tool_calls: Optional list of tool calls (not currently used) + request_data: The original request data + input_type: "request" for pre-call, "response" for post-call + logging_obj: Optional logging object - await self.run_grayswan_guardrail(payload, data, GuardrailEventHooks.pre_call) - add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=self.guardrail_name + Returns: + GenericGuardrailAPIInputs - texts may be replaced with violation message in passthrough mode + + Raises: + HTTPException: If content is blocked (block mode) + Exception: If guardrail check fails + """ + # DEBUG: Log when apply_guardrail is called + verbose_proxy_logger.debug( + "GraySwan DEBUG: apply_guardrail called with input_type=%s, texts=%s", + input_type, + inputs.get("texts", [])[:100] if inputs.get("texts") else "NONE", ) - return data - @log_guardrail_information - async def async_moderation_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - call_type: Literal[ - "completion", - "embeddings", - "image_generation", - "moderation", - "audio_transcription", - "responses", - "mcp_call", - "anthropic_messages", - ], - ) -> Optional[Union[Exception, str, dict]]: - if ( - self.should_run_guardrail( - data=data, event_type=GuardrailEventHooks.during_call - ) - is not True - ): - return data + texts = inputs.get("texts", []) + if not texts: + verbose_proxy_logger.debug("Gray Swan Guardrail: No texts to scan") + return inputs - verbose_proxy_logger.debug("GraySwan Guardrail: during-call hook triggered") + verbose_proxy_logger.debug( + "Gray Swan Guardrail: Scanning %d text(s) for %s", + len(texts), + input_type, + ) - messages = data.get("messages") - if not messages: - verbose_proxy_logger.debug("Gray Swan Guardrail: No messages in data") - return data + # Convert texts to messages format for GraySwan API + # Use "user" role for request content, "assistant" for response content + role = "assistant" if input_type == "response" else "user" + messages = [{"role": role, "content": text} for text in texts] - dynamic_body = self.get_guardrail_dynamic_request_body_params(data) or {} + # Get dynamic params from request metadata + dynamic_body = self.get_guardrail_dynamic_request_body_params(request_data) or {} + # Prepare and send payload payload = self._prepare_payload(messages, dynamic_body) if payload is None: - verbose_proxy_logger.debug( - "Gray Swan Guardrail: no content to scan; skipping request" - ) - return data - - await self.run_grayswan_guardrail( - payload, data, GuardrailEventHooks.during_call + return inputs + + # Call GraySwan API + response_json = await self._call_grayswan_api(payload) + # Process response + is_output = input_type == "response" + result = self._process_response( + response_json=response_json, + request_data=request_data, + inputs=inputs, + is_output=is_output, ) - add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=self.guardrail_name - ) - return data - - @log_guardrail_information - async def async_post_call_success_hook( - self, - data: dict, - user_api_key_dict: UserAPIKeyAuth, - response: LLMResponseTypes, - ) -> LLMResponseTypes: - if ( - self.should_run_guardrail( - data=data, event_type=GuardrailEventHooks.post_call - ) - is not True - ): - return response - - verbose_proxy_logger.debug("GraySwan Guardrail: post-call hook triggered") - - response_dict = response.model_dump() if hasattr(response, "model_dump") else {} # type: ignore[union-attr] - response_messages = [ - msg if isinstance(msg, dict) else msg.model_dump() - for choice in response_dict.get("choices", []) - if isinstance(choice, dict) - for msg in [choice.get("message")] - if msg is not None - ] - - if not response_messages: - verbose_proxy_logger.debug( - "Gray Swan Guardrail: no response messages detected; skipping post-call scan" - ) - return response - dynamic_body = self.get_guardrail_dynamic_request_body_params(data) or {} - - payload = self._prepare_payload(response_messages, dynamic_body) - if payload is None: - verbose_proxy_logger.debug( - "Gray Swan Guardrail: no content to scan; skipping request" - ) - return response - - await self.run_grayswan_guardrail(payload, data, GuardrailEventHooks.post_call) - - # If passthrough mode and detection info exists, replace response content with violation message - if self.on_flagged_action == "passthrough" and "metadata" in data: - guardrail_detections = data.get("metadata", {}).get( - "guardrail_detections", [] - ) - if guardrail_detections: - # Replace the model response content with guardrail violation message - violation_message = self._format_violation_message( - guardrail_detections, is_output=True - ) - - # Handle ModelResponse (OpenAI-style chat/text completions) - # Use isinstance to narrow the type for mypy - if isinstance(response, ModelResponse) and response.choices: - verbose_proxy_logger.debug( - "Gray Swan Guardrail: Replacing response content in ModelResponse format" - ) - for choice in response.choices: - # Handle chat completion format (message.content) - # Choices has message attribute, StreamingChoices has delta - if isinstance(choice, Choices) and hasattr(choice, "message") and hasattr( - choice.message, "content" - ): - choice.message.content = violation_message - # Handle text completion format (text) - # Text attribute might be set dynamically, use setattr - elif hasattr(choice, "text"): - setattr(choice, "text", violation_message) - - # Update finish_reason to indicate content filtering - if hasattr(choice, "finish_reason"): - choice.finish_reason = "content_filter" - - # Handle AnthropicMessagesResponse format - elif hasattr(response, "content") and isinstance(response.content, list): # type: ignore - verbose_proxy_logger.debug( - "Gray Swan Guardrail: Replacing response content in Anthropic Messages format" - ) - # Replace content blocks with text block containing violation message - response.content = [ # type: ignore - {"type": "text", "text": violation_message} - ] - # Update stop_reason if present - if hasattr(response, "stop_reason"): - response.stop_reason = "end_turn" # type: ignore - - else: - verbose_proxy_logger.warning( - "Gray Swan Guardrail: Passthrough mode enabled but response format not recognized. " - "Cannot replace content. Response type: %s", - type(response).__name__, - ) - - add_guardrail_to_applied_guardrails_header( - request_data=data, guardrail_name=self.guardrail_name - ) - return response + return result # ------------------------------------------------------------------ - # Core GraySwan interaction + # Core GraySwan API interaction # ------------------------------------------------------------------ - async def run_grayswan_guardrail( - self, - payload: dict, - data: Optional[dict] = None, - hook_type: Optional[GuardrailEventHooks] = None, - ): + async def _call_grayswan_api(self, payload: dict) -> Dict[str, Any]: + """Call the GraySwan monitoring API.""" headers = self._prepare_headers() try: @@ -326,66 +233,50 @@ async def run_grayswan_guardrail( verbose_proxy_logger.debug( "Gray Swan Guardrail: monitor response %s", safe_dumps(result) ) + return result except HTTPException: raise - except Exception as exc: # pragma: no cover - depends on HTTP client behaviour + except Exception as exc: verbose_proxy_logger.exception( "Gray Swan Guardrail: API request failed: %s", exc ) raise GraySwanGuardrailAPIError(str(exc)) from exc - self._process_grayswan_response(result, data, hook_type) - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _prepare_headers(self) -> Dict[str, str]: - return { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - "grayswan-api-key": self.api_key, - } - - def _prepare_payload( - self, messages: list[dict], dynamic_body: dict - ) -> Optional[Dict[str, Any]]: - payload: Dict[str, Any] = {} - payload["messages"] = messages - - categories = dynamic_body.get("categories") or self.categories - if categories: - payload["categories"] = categories - - policy_id = dynamic_body.get("policy_id") or self.policy_id - if policy_id: - payload["policy_id"] = policy_id + def _process_response( + self, + response_json: Dict[str, Any], + request_data: dict, + inputs: GenericGuardrailAPIInputs, + is_output: bool, + ) -> GenericGuardrailAPIInputs: + """ + Process GraySwan API response and handle violations. - reasoning_mode = dynamic_body.get("reasoning_mode") or self.reasoning_mode - if reasoning_mode: - payload["reasoning_mode"] = reasoning_mode + Args: + response_json: Response from GraySwan API + request_data: Original request data + inputs: The inputs being scanned + is_output: True if scanning model output, False for input - return payload + Returns: + GenericGuardrailAPIInputs - possibly modified with violation message - def _process_grayswan_response( - self, - response_json: Dict[str, Any], - data: Optional[dict] = None, - hook_type: Optional[GuardrailEventHooks] = None, - ) -> None: + Raises: + HTTPException: If content is blocked (block mode) + """ violation_score = float(response_json.get("violation", 0.0) or 0.0) - violated_rules = response_json.get("violated_rules", []) + violated_rules = response_json.get("violated_rule_descriptions", []) mutation_detected = response_json.get("mutation") ipi_detected = response_json.get("ipi") flagged = violation_score >= self.violation_threshold if not flagged: verbose_proxy_logger.debug( - "Gray Swan Guardrail: request passed (score=%s, rules=%s)", + "Gray Swan Guardrail: content passed (score=%s, threshold=%s)", violation_score, - violated_rules, + self.violation_threshold, ) - return + return inputs verbose_proxy_logger.warning( "Gray Swan Guardrail: violation score %.3f exceeds threshold %.3f", @@ -393,13 +284,17 @@ def _process_grayswan_response( self.violation_threshold, ) + detection_info = { + "guardrail": "grayswan", + "flagged": True, + "violation_score": violation_score, + "violated_rules": violated_rules, + "mutation": mutation_detected, + "ipi": ipi_detected, + } + if self.on_flagged_action == "block": - # Determine if violation was in input or output - violation_location = ( - "output" - if hook_type == GuardrailEventHooks.post_call - else "input" - ) + violation_location = "output" if is_output else "input" raise HTTPException( status_code=400, detail={ @@ -413,114 +308,144 @@ def _process_grayswan_response( ) elif self.on_flagged_action == "monitor": verbose_proxy_logger.info( - "Gray Swan Guardrail: Monitoring mode - allowing flagged content to proceed" + "Gray Swan Guardrail: Monitoring mode - allowing flagged content" ) + return inputs elif self.on_flagged_action == "passthrough": - # Store detection info - detection_info = { - "guardrail": "grayswan", - "flagged": True, - "violation_score": violation_score, - "violated_rules": violated_rules, - "mutation": mutation_detected, - "ipi": ipi_detected, - } - - # For pre_call and during_call, raise exception to short-circuit LLM call - if hook_type in ( - GuardrailEventHooks.pre_call, - GuardrailEventHooks.during_call, - ): - verbose_proxy_logger.info( - "Gray Swan Guardrail: Passthrough mode - raising exception to short-circuit LLM call" - ) - violation_message = self._format_violation_message( - [detection_info], is_output=False - ) + # Replace content with violation message + violation_message = self._format_violation_message( + detection_info, is_output=is_output + ) + verbose_proxy_logger.info( + "Gray Swan Guardrail: Passthrough mode - replacing content with violation message" + ) + + if not is_output: + # For pre-call (request), raise exception to short-circuit LLM call + # and return synthetic response with violation message self.raise_passthrough_exception( violation_message=violation_message, - request_data=data or {}, + request_data=request_data, detection_info=detection_info, ) - # For post_call, store in metadata to replace response later - verbose_proxy_logger.info( - "Gray Swan Guardrail: Passthrough mode - storing detection info in metadata" - ) - if data is not None: - if "metadata" not in data: - data["metadata"] = {} - if "guardrail_detections" not in data["metadata"]: - data["metadata"]["guardrail_detections"] = [] - data["metadata"]["guardrail_detections"].append(detection_info) + # For post-call (response), replace texts and let unified system apply them + inputs["texts"] = [violation_message] + return inputs + + return inputs + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + def _prepare_headers(self) -> Dict[str, str]: + return { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "grayswan-api-key": self.api_key, + } + + def _prepare_payload( + self, messages: List[Dict[str, str]], dynamic_body: dict + ) -> Optional[Dict[str, Any]]: + payload: Dict[str, Any] = {"messages": messages} + + categories = dynamic_body.get("categories") or self.categories + if categories: + payload["categories"] = categories + + policy_id = dynamic_body.get("policy_id") or self.policy_id + if policy_id: + payload["policy_id"] = policy_id + + reasoning_mode = dynamic_body.get("reasoning_mode") or self.reasoning_mode + if reasoning_mode: + payload["reasoning_mode"] = reasoning_mode + + return payload def _format_violation_message( - self, guardrail_detections: list, is_output: bool = False + self, detection_info: dict, is_output: bool = False ) -> str: """ - Format guardrail detections into a user-friendly violation message. + Format detection info into a user-friendly violation message. Args: - guardrail_detections: List of detection info dictionaries - is_output: True if violation is in model output (post_call), False if in input (pre_call/during_call) + detection_info: Detection info dictionary + is_output: True if violation is in model output, False if in input Returns: Formatted violation message string """ - if not guardrail_detections: - return "Content was flagged by guardrail" + violation_score = detection_info.get("violation_score", 0.0) + violated_rules = detection_info.get("violated_rules", []) + mutation = detection_info.get("mutation", False) + ipi = detection_info.get("ipi", False) - # Get the most recent detection (should be from this guardrail) - detection = guardrail_detections[-1] - - violation_score = detection.get("violation_score", 0.0) - violated_rules = detection.get("violated_rules", []) - mutation = detection.get("mutation", False) - ipi = detection.get("ipi", False) - - # Indicate whether violation was in input or output violation_location = "the model response" if is_output else "input query" message_parts = [ - f"Sorry I can't help with that. According to the Gray Swan Cygnal Guardrail, the {violation_location} has a violation score of {violation_score:.2f}.", + f"Sorry I can't help with that. According to the Gray Swan Cygnal Guardrail, " + f"the {violation_location} has a violation score of {violation_score:.2f}.", ] if violated_rules: - message_parts.append( - f"It was violating the rule(s): {', '.join(map(str, violated_rules))}." - ) + formatted_rules = self._format_violated_rules(violated_rules) + if formatted_rules: + message_parts.append(f"It was violating the rule(s): {formatted_rules}.") if mutation: message_parts.append( - "Mutation effort to make the harmful intention disguised was DETECTED." + "A potential prompt manipulation/mutation was detected." ) if ipi: - message_parts.append("Indirect Prompt Injection was DETECTED.") + message_parts.append( + "Indirect prompt injection indicators were detected." + ) return "\n".join(message_parts) - def _resolve_threshold(self, threshold: Optional[float]) -> float: - if threshold is not None: - return min(max(threshold, 0.0), 1.0) + def _format_violated_rules(self, violated_rules: List) -> str: + """Format violated rules list into a readable string.""" + formatted: List[str] = [] + for rule in violated_rules: + if isinstance(rule, dict): + # New format: {'rule': 6, 'name': 'Illegal Activities...', 'description': '...'} + rule_num = rule.get("rule", "") + rule_name = rule.get("name", "") + rule_desc = rule.get("description", "") + if rule_num and rule_name: + if rule_desc: + formatted.append(f"#{rule_num} {rule_name}: {rule_desc}") + else: + formatted.append(f"#{rule_num} {rule_name}") + elif rule_name: + formatted.append(rule_name) + else: + formatted.append(str(rule)) + else: + # Legacy format: simple value + formatted.append(str(rule)) + + return ", ".join(formatted) + + def _resolve_threshold(self, value: Optional[float]) -> float: + if value is not None: + return float(value) + env_val = os.getenv("GRAYSWAN_VIOLATION_THRESHOLD") + if env_val: + try: + return float(env_val) + except ValueError: + pass return 0.5 - def _resolve_reasoning_mode(self, candidate: Optional[str]) -> Optional[str]: - if candidate is None: - return None - normalised = candidate.strip().lower() - if normalised in self.SUPPORTED_REASONING_MODES: - return normalised - verbose_proxy_logger.warning( - "Gray Swan Guardrail: ignoring unsupported reasoning_mode '%s'", - candidate, - ) + def _resolve_reasoning_mode(self, value: Optional[str]) -> Optional[str]: + if value and value.lower() in self.SUPPORTED_REASONING_MODES: + return value.lower() + env_val = os.getenv("GRAYSWAN_REASONING_MODE") + if env_val and env_val.lower() in self.SUPPORTED_REASONING_MODES: + return env_val.lower() return None - - @staticmethod - def get_config_model(): - from litellm.types.proxy.guardrails.guardrail_hooks.grayswan import ( - GraySwanGuardrailConfigModel, - ) - - return GraySwanGuardrailConfigModel diff --git a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py index cece49e99cb..a09f0bdc5c4 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py @@ -180,7 +180,7 @@ async def async_post_call_success_hook( call_type: Optional[CallTypesLiteral] = None if user_api_key_dict.request_route is not None: call_types = get_call_types_for_route(user_api_key_dict.request_route) - if call_types is not None: + if call_types is not None and len(call_types) > 0: call_type = call_types[0] if call_type is None: call_type = _infer_call_type(call_type=None, completion_response=response) @@ -238,19 +238,36 @@ async def async_post_call_streaming_iterator_hook( "guardrail_to_apply", None ) - # Get sampling rate from guardrail config or optional_params, default to 5 + # Get streaming configuration from guardrail or optional_params sampling_rate = 5 + end_of_stream_only = False # If True, only apply guardrail at end of stream + if guardrail_to_apply is not None: - # Check guardrail config first - guardrail_config = getattr(guardrail_to_apply, "guardrail_config", {}) - sampling_rate = guardrail_config.get( - "streaming_sampling_rate", sampling_rate + # Check direct attributes on guardrail first + sampling_rate = getattr( + guardrail_to_apply, "streaming_sampling_rate", sampling_rate + ) + end_of_stream_only = getattr( + guardrail_to_apply, "streaming_end_of_stream_only", end_of_stream_only ) + # Also check guardrail_config dict if present + guardrail_config = getattr(guardrail_to_apply, "guardrail_config", {}) + if isinstance(guardrail_config, dict): + sampling_rate = guardrail_config.get( + "streaming_sampling_rate", sampling_rate + ) + end_of_stream_only = guardrail_config.get( + "streaming_end_of_stream_only", end_of_stream_only + ) + # Also check optional_params as fallback sampling_rate = self.optional_params.get( "streaming_sampling_rate", sampling_rate ) + end_of_stream_only = self.optional_params.get( + "streaming_end_of_stream_only", end_of_stream_only + ) if guardrail_to_apply is None: async for item in response: @@ -306,6 +323,11 @@ async def async_post_call_streaming_iterator_hook( yield remaining_item return + # If end_of_stream_only mode, yield chunks without processing + if end_of_stream_only: + yield item + continue + # Process chunk based on sampling rate if chunk_counter % sampling_rate == 0: diff --git a/litellm/proxy/guardrails/guardrail_registry.py b/litellm/proxy/guardrails/guardrail_registry.py index f8e86334f83..2d5f07dbf6e 100644 --- a/litellm/proxy/guardrails/guardrail_registry.py +++ b/litellm/proxy/guardrails/guardrail_registry.py @@ -19,6 +19,10 @@ LitellmParams, SupportedGuardrailIntegrations, ) +from litellm.proxy.guardrails.guardrail_hooks.grayswan import ( + GraySwanGuardrail, + initialize_guardrail as initialize_grayswan, +) from .guardrail_initializers import ( initialize_bedrock, @@ -36,9 +40,12 @@ SupportedGuardrailIntegrations.PRESIDIO.value: initialize_presidio, SupportedGuardrailIntegrations.HIDE_SECRETS.value: initialize_hide_secrets, SupportedGuardrailIntegrations.TOOL_PERMISSION.value: initialize_tool_permission, + SupportedGuardrailIntegrations.GRAYSWAN.value: initialize_grayswan, } -guardrail_class_registry: Dict[str, Type[CustomGuardrail]] = {} +guardrail_class_registry: Dict[str, Type[CustomGuardrail]] = { + SupportedGuardrailIntegrations.GRAYSWAN.value: GraySwanGuardrail +} def get_guardrail_initializer_from_hooks(): diff --git a/litellm/proxy/response_api_endpoints/endpoints.py b/litellm/proxy/response_api_endpoints/endpoints.py index 9d5bccecdf8..50eb729d8e4 100644 --- a/litellm/proxy/response_api_endpoints/endpoints.py +++ b/litellm/proxy/response_api_endpoints/endpoints.py @@ -1,12 +1,16 @@ import asyncio +import time +import uuid from typing import Any, AsyncIterator, cast from fastapi import APIRouter, Depends, HTTPException, Request, Response from litellm._logging import verbose_proxy_logger +from litellm.integrations.custom_guardrail import ModifyResponseException from litellm.proxy._types import * from litellm.proxy.auth.user_api_key_auth import UserAPIKeyAuth, user_api_key_auth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing +from litellm.types.llms.openai import ResponsesAPIResponse from litellm.types.responses.main import DeleteResponseResult router = APIRouter() @@ -169,6 +173,26 @@ async def responses_api( user_api_base=user_api_base, version=version, ) + except ModifyResponseException as e: + # Guardrail passthrough: return violation message in Responses API format (200) + _data = e.request_data + await proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=_data, + ) + + violation_text = e.message + response_obj = ResponsesAPIResponse( + id=f"resp_{uuid.uuid4()}", + object="response", + created_at=int(time.time()), + model=e.model or data.get("model"), + output=[{"content": [{"type": "text", "text": violation_text}]}], + status="completed", + usage={"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}, + ) + return response_obj except Exception as e: raise await processor._handle_llm_api_exception( e=e,