diff --git a/litellm/integrations/custom_guardrail.py b/litellm/integrations/custom_guardrail.py index c157b3aa605..71c683cc9f4 100644 --- a/litellm/integrations/custom_guardrail.py +++ b/litellm/integrations/custom_guardrail.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Dict, List, Literal, Optional, Union from litellm._logging import verbose_logger @@ -186,20 +187,17 @@ def _validate_premium_user(self) -> bool: def add_standard_logging_guardrail_information_to_request_data( self, - guardrail_json_response: Union[Exception, str, dict], + guardrail_json_response: Union[Exception, str, dict, List[dict]], request_data: dict, guardrail_status: Literal["success", "failure"], + start_time: Optional[float] = None, + end_time: Optional[float] = None, + duration: Optional[float] = None, + masked_entity_count: Optional[Dict[str, int]] = None, ) -> None: """ Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc. """ - from litellm.proxy.proxy_server import premium_user - - if premium_user is not True: - verbose_logger.warning( - f"Guardrail Tracing is only available for premium users. 
This gets logged on downstream Langfuse, DataDog, etc.
+ """ + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=e, + request_data=request_data, + guardrail_status="failure", + duration=duration, + start_time=start_time, + end_time=end_time, + ) + raise e + def log_guardrail_information(func): """ @@ -259,21 +309,7 @@ def log_guardrail_information(func): import asyncio import functools - def process_response(self, response, request_data): - self.add_standard_logging_guardrail_information_to_request_data( - guardrail_json_response=response, - request_data=request_data, - guardrail_status="success", - ) - return response - - def process_error(self, e, request_data): - self.add_standard_logging_guardrail_information_to_request_data( - guardrail_json_response=e, - request_data=request_data, - guardrail_status="failure", - ) - raise e + start_time = datetime.now() @functools.wraps(func) async def async_wrapper(*args, **kwargs): @@ -283,9 +319,21 @@ async def async_wrapper(*args, **kwargs): ) try: response = await func(*args, **kwargs) - return process_response(self, response, request_data) + return self._process_response( + response=response, + request_data=request_data, + start_time=start_time.timestamp(), + end_time=datetime.now().timestamp(), + duration=(datetime.now() - start_time).total_seconds(), + ) except Exception as e: - return process_error(self, e, request_data) + return self._process_error( + e=e, + request_data=request_data, + start_time=start_time.timestamp(), + end_time=datetime.now().timestamp(), + duration=(datetime.now() - start_time).total_seconds(), + ) @functools.wraps(func) def sync_wrapper(*args, **kwargs): @@ -295,9 +343,17 @@ def sync_wrapper(*args, **kwargs): ) try: response = func(*args, **kwargs) - return process_response(self, response, request_data) + return self._process_response( + response=response, + request_data=request_data, + duration=(datetime.now() - start_time).total_seconds(), + ) except Exception as e: - return process_error(self, e, 
request_data) + return self._process_error( + e=e, + request_data=request_data, + duration=(datetime.now() - start_time).total_seconds(), + ) @functools.wraps(func) def wrapper(*args, **kwargs): diff --git a/litellm/integrations/langfuse/langfuse.py b/litellm/integrations/langfuse/langfuse.py index d0472ee6383..02862c52c5d 100644 --- a/litellm/integrations/langfuse/langfuse.py +++ b/litellm/integrations/langfuse/langfuse.py @@ -27,9 +27,12 @@ ) if TYPE_CHECKING: + from langfuse.client import StatefulTraceClient + from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache else: DynamicLoggingCache = Any + StatefulTraceClient = Any class LangFuseLogger: @@ -626,16 +629,17 @@ def _log_langfuse_v2( # noqa: PLR0915 if key.lower() not in ["authorization", "cookie", "referer"]: clean_headers[key] = value - # clean_metadata["request"] = { - # "method": method, - # "url": url, - # "headers": clean_headers, - # } - trace = self.Langfuse.trace(**trace_params) + trace: StatefulTraceClient = self.Langfuse.trace(**trace_params) # Log provider specific information as a span log_provider_specific_information_as_span(trace, clean_metadata) + # Log guardrail information as a span + self._log_guardrail_information_as_span( + trace=trace, + standard_logging_object=standard_logging_object, + ) + generation_id = None usage = None if response_obj is not None: @@ -809,6 +813,47 @@ def _get_langfuse_flush_interval(flush_interval: int) -> int: """ return int(os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval) + def _log_guardrail_information_as_span( + self, + trace: StatefulTraceClient, + standard_logging_object: Optional[StandardLoggingPayload], + ): + """ + Log guardrail information as a span + """ + if standard_logging_object is None: + verbose_logger.debug( + "Not logging guardrail information as span because standard_logging_object is None" + ) + return + + guardrail_information = standard_logging_object.get( + "guardrail_information", None + ) + if 
guardrail_information is None: + verbose_logger.debug( + "Not logging guardrail information as span because guardrail_information is None" + ) + return + + span = trace.span( + name="guardrail", + input=guardrail_information.get("guardrail_request", None), + output=guardrail_information.get("guardrail_response", None), + metadata={ + "guardrail_name": guardrail_information.get("guardrail_name", None), + "guardrail_mode": guardrail_information.get("guardrail_mode", None), + "guardrail_masked_entity_count": guardrail_information.get( + "masked_entity_count", None + ), + }, + start_time=guardrail_information.get("start_time", None), # type: ignore + end_time=guardrail_information.get("end_time", None), # type: ignore + ) + + verbose_logger.debug(f"Logged guardrail information as span: {span}") + span.end() + def _add_prompt_to_generation_params( generation_params: dict, diff --git a/litellm/proxy/guardrails/guardrail_hooks/presidio.py b/litellm/proxy/guardrails/guardrail_hooks/presidio.py index 95fc6915ad0..0b7e38ad6ca 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/presidio.py +++ b/litellm/proxy/guardrails/guardrail_hooks/presidio.py @@ -11,7 +11,8 @@ import asyncio import json import uuid -from typing import Any, Dict, List, Optional, Tuple, Union, cast +from datetime import datetime +from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast import aiohttp @@ -20,10 +21,7 @@ from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache from litellm.exceptions import BlockedPiiEntityError -from litellm.integrations.custom_guardrail import ( - CustomGuardrail, - log_guardrail_information, -) +from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.proxy._types import UserAPIKeyAuth from litellm.types.guardrails import ( GuardrailEventHooks, @@ -218,7 +216,11 @@ async def analyze_text( raise e async def anonymize_text( - self, text: str, analyze_results: Any, output_parse_pii: bool + self, 
+ text: str, + analyze_results: Any, + output_parse_pii: bool, + masked_entity_count: Dict[str, int], ) -> str: """ Send analysis results to the Presidio anonymizer endpoint to get redacted text @@ -256,6 +258,11 @@ async def anonymize_text( ] # get text it'll replace new_text = new_text[:start] + replacement + new_text[end:] + entity_type = item.get("entity_type", None) + if entity_type is not None: + masked_entity_count[entity_type] = ( + masked_entity_count.get(entity_type, 0) + 1 + ) return redacted_text["text"] else: raise Exception(f"Invalid anonymizer response: {redacted_text}") @@ -300,6 +307,11 @@ async def check_pii( """ Calls Presidio Analyze + Anonymize endpoints for PII Analysis + Masking """ + start_time = datetime.now() + analyze_results: Optional[Union[List[PresidioAnalyzeResponseItem], Dict]] = None + status: Literal["success", "failure"] = "success" + masked_entity_count: Dict[str, int] = {} + exception_str: str = "" try: if self.mock_redacted_text is not None: redacted_text = self.mock_redacted_text @@ -324,13 +336,33 @@ async def check_pii( text=text, analyze_results=analyze_results, output_parse_pii=output_parse_pii, + masked_entity_count=masked_entity_count, ) - return redacted_text["text"] except Exception as e: + status = "failure" + exception_str = str(e) raise e + finally: + #################################################### + # Create Guardrail Trace for logging on Langfuse, Datadog, etc. 
+ #################################################### + guardrail_json_response: Union[Exception, str, dict, List[dict]] = {} + if status == "success": + if isinstance(analyze_results, List): + guardrail_json_response = [dict(item) for item in analyze_results] + else: + guardrail_json_response = exception_str + self.add_standard_logging_guardrail_information_to_request_data( + guardrail_json_response=guardrail_json_response, + request_data=request_data, + guardrail_status=status, + start_time=start_time.timestamp(), + end_time=datetime.now().timestamp(), + duration=(datetime.now() - start_time).total_seconds(), + masked_entity_count=masked_entity_count, + ) - @log_guardrail_information async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, @@ -394,7 +426,6 @@ async def async_pre_call_hook( except Exception as e: raise e - @log_guardrail_information def logging_hook( self, kwargs: dict, result: Any, call_type: str ) -> Tuple[dict, Any]: @@ -427,7 +458,6 @@ def run_in_new_loop(): # No running event loop, we can safely run in this thread return run_in_new_loop() - @log_guardrail_information async def async_logging_hook( self, kwargs: dict, result: Any, call_type: str ) -> Tuple[dict, Any]: @@ -476,7 +506,6 @@ async def async_logging_hook( return kwargs, result - @log_guardrail_information async def async_post_call_success_hook( # type: ignore self, data: dict, diff --git a/litellm/types/utils.py b/litellm/types/utils.py index e6bd2802865..f96f35d1b29 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -633,23 +633,23 @@ def __init__( if audio is None: # delete audio from self # OpenAI compatible APIs like mistral API will raise an error if audio is passed in - if hasattr(self, 'audio'): + if hasattr(self, "audio"): del self.audio if annotations is None: # ensure default response matches OpenAI spec # Some OpenAI compatible APIs raise an error if annotations are passed in - if hasattr(self, 'annotations'): + if hasattr(self, 
"annotations"): del self.annotations if reasoning_content is None: # ensure default response matches OpenAI spec - if hasattr(self, 'reasoning_content'): + if hasattr(self, "reasoning_content"): del self.reasoning_content if thinking_blocks is None: # ensure default response matches OpenAI spec - if hasattr(self, 'thinking_blocks'): + if hasattr(self, "thinking_blocks"): del self.thinking_blocks add_provider_specific_fields(self, provider_specific_fields) @@ -1870,8 +1870,24 @@ class StandardLoggingPayloadErrorInformation(TypedDict, total=False): class StandardLoggingGuardrailInformation(TypedDict, total=False): guardrail_name: Optional[str] guardrail_mode: Optional[Union[GuardrailEventHooks, List[GuardrailEventHooks]]] - guardrail_response: Optional[Union[dict, str]] + guardrail_request: Optional[dict] + guardrail_response: Optional[Union[dict, str, List[dict]]] guardrail_status: Literal["success", "failure"] + start_time: Optional[float] + end_time: Optional[float] + duration: Optional[float] + """ + Duration of the guardrail in seconds + """ + + masked_entity_count: Optional[Dict[str, int]] + """ + Count of masked entities + { + "CREDIT_CARD": 2, + "PHONE": 1 + } + """ StandardLoggingPayloadStatus = Literal["success", "failure"] diff --git a/tests/guardrails_tests/test_presidio_pii.py b/tests/guardrails_tests/test_presidio_pii.py index 1a8e4723068..25e027c0c2f 100644 --- a/tests/guardrails_tests/test_presidio_pii.py +++ b/tests/guardrails_tests/test_presidio_pii.py @@ -236,20 +236,24 @@ async def test_presidio_pre_call_hook_with_different_call_types(call_type): def test_validate_environment_missing_http(base_url): pii_masking = _OPTIONAL_PresidioPIIMasking(mock_testing=True) - os.environ["PRESIDIO_ANALYZER_API_BASE"] = f"{base_url}/analyze" - os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = f"{base_url}/anonymize" - pii_masking.validate_environment() + # Use patch.dict to temporarily modify environment variables only for this test + env_vars = { + 
"PRESIDIO_ANALYZER_API_BASE": f"{base_url}/analyze", + "PRESIDIO_ANONYMIZER_API_BASE": f"{base_url}/anonymize" + } + with patch.dict(os.environ, env_vars): + pii_masking.validate_environment() - expected_url = base_url - if not (base_url.startswith("https://") or base_url.startswith("http://")): - expected_url = "http://" + base_url + expected_url = base_url + if not (base_url.startswith("https://") or base_url.startswith("http://")): + expected_url = "http://" + base_url - assert ( - pii_masking.presidio_anonymizer_api_base == f"{expected_url}/anonymize/" - ), "Got={}, Expected={}".format( - pii_masking.presidio_anonymizer_api_base, f"{expected_url}/anonymize/" - ) - assert pii_masking.presidio_analyzer_api_base == f"{expected_url}/analyze/" + assert ( + pii_masking.presidio_anonymizer_api_base == f"{expected_url}/anonymize/" + ), "Got={}, Expected={}".format( + pii_masking.presidio_anonymizer_api_base, f"{expected_url}/anonymize/" + ) + assert pii_masking.presidio_analyzer_api_base == f"{expected_url}/analyze/" @pytest.mark.asyncio @@ -433,6 +437,10 @@ async def test_presidio_pii_masking_logging_output_only_no_pre_api_hook(): @pytest.mark.asyncio +@patch.dict(os.environ, { + "PRESIDIO_ANALYZER_API_BASE": "http://localhost:5002", + "PRESIDIO_ANONYMIZER_API_BASE": "http://localhost:5001" +}) async def test_presidio_pii_masking_logging_output_only_logged_response_guardrails_config(): from typing import Dict, List, Optional @@ -445,9 +453,8 @@ async def test_presidio_pii_masking_logging_output_only_logged_response_guardrai ) litellm.set_verbose = True - os.environ["PRESIDIO_ANALYZER_API_BASE"] = "http://localhost:5002" - os.environ["PRESIDIO_ANONYMIZER_API_BASE"] = "http://localhost:5001" - + # Environment variables are now patched via the decorator instead of setting them directly + guardrails_config: List[Dict[str, GuardrailItemSpec]] = [ { "pii_masking": { diff --git a/tests/guardrails_tests/test_tracing_guardrails.py 
b/tests/guardrails_tests/test_tracing_guardrails.py new file mode 100644 index 00000000000..9a607e7a18c --- /dev/null +++ b/tests/guardrails_tests/test_tracing_guardrails.py @@ -0,0 +1,184 @@ +import sys +import os +import io, asyncio +import json +import pytest +import time +from litellm import mock_completion +from unittest.mock import MagicMock, AsyncMock, patch +sys.path.insert(0, os.path.abspath("../..")) +import litellm +from litellm.proxy.guardrails.guardrail_hooks.presidio import _OPTIONAL_PresidioPIIMasking, PresidioPerRequestConfig +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.utils import StandardLoggingPayload, StandardLoggingGuardrailInformation +from litellm.types.guardrails import GuardrailEventHooks +from typing import Optional + + +class TestCustomLogger(CustomLogger): + def __init__(self, *args, **kwargs): + self.standard_logging_payload: Optional[StandardLoggingPayload] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.standard_logging_payload = kwargs.get("standard_logging_object") + pass + +@pytest.mark.asyncio +async def test_standard_logging_payload_includes_guardrail_information(): + """ + Test that the standard logging payload includes the guardrail information when a guardrail is applied + """ + test_custom_logger = TestCustomLogger() + litellm.callbacks = [test_custom_logger] + presidio_guard = _OPTIONAL_PresidioPIIMasking( + guardrail_name="presidio_guard", + event_hook=GuardrailEventHooks.pre_call, + presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"), + presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"), + ) + # 1. 
call the pre call hook with guardrail + request_data = { + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"}, + ], + "mock_response": "Hello", + "guardrails": ["presidio_guard"], + "metadata": {}, + } + await presidio_guard.async_pre_call_hook( + user_api_key_dict={}, + cache=None, + data=request_data, + call_type="acompletion" + ) + + # 2. call litellm.acompletion + response = await litellm.acompletion(**request_data) + + # 3. assert that the standard logging payload includes the guardrail information + await asyncio.sleep(1) + print("got standard logging payload=", json.dumps(test_custom_logger.standard_logging_payload, indent=4, default=str)) + assert test_custom_logger.standard_logging_payload is not None + assert test_custom_logger.standard_logging_payload["guardrail_information"] is not None + assert test_custom_logger.standard_logging_payload["guardrail_information"]["guardrail_name"] == "presidio_guard" + assert test_custom_logger.standard_logging_payload["guardrail_information"]["guardrail_mode"] == GuardrailEventHooks.pre_call + + # assert that the guardrail_response is a response from presidio analyze + presidio_response = test_custom_logger.standard_logging_payload["guardrail_information"]["guardrail_response"] + assert isinstance(presidio_response, list) + for response_item in presidio_response: + assert "analysis_explanation" in response_item + assert "start" in response_item + assert "end" in response_item + assert "score" in response_item + assert "entity_type" in response_item + assert "recognition_metadata" in response_item + + + # assert that the duration is not None + assert test_custom_logger.standard_logging_payload["guardrail_information"]["duration"] is not None + assert test_custom_logger.standard_logging_payload["guardrail_information"]["duration"] > 0 + + # assert that we get the count of masked entities + assert 
test_custom_logger.standard_logging_payload["guardrail_information"]["masked_entity_count"] is not None + assert test_custom_logger.standard_logging_payload["guardrail_information"]["masked_entity_count"]["PHONE_NUMBER"] == 1 + + + + +@pytest.mark.asyncio +@pytest.mark.skip(reason="Local only test") +async def test_langfuse_trace_includes_guardrail_information(): + """ + Test that the langfuse trace includes the guardrail information when a guardrail is applied + """ + import httpx + from unittest.mock import AsyncMock, patch + from litellm.integrations.langfuse.langfuse_prompt_management import LangfusePromptManagement + callback = LangfusePromptManagement(flush_interval=3) + import json + + # Create a mock Response object + mock_response = AsyncMock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.json.return_value = {"status": "success"} + + # Create mock for httpx.Client.post + mock_post = AsyncMock() + mock_post.return_value = mock_response + + with patch("httpx.Client.post", mock_post): + litellm._turn_on_debug() + litellm.callbacks = [callback] + presidio_guard = _OPTIONAL_PresidioPIIMasking( + guardrail_name="presidio_guard", + event_hook=GuardrailEventHooks.pre_call, + presidio_analyzer_api_base=os.getenv("PRESIDIO_ANALYZER_API_BASE"), + presidio_anonymizer_api_base=os.getenv("PRESIDIO_ANONYMIZER_API_BASE"), + ) + # 1. call the pre call hook with guardrail + request_data = { + "model": "gpt-4o", + "messages": [ + {"role": "user", "content": "Hello, my phone number is +1 412 555 1212"}, + ], + "mock_response": "Hello", + "guardrails": ["presidio_guard"], + "metadata": {}, + } + await presidio_guard.async_pre_call_hook( + user_api_key_dict={}, + cache=None, + data=request_data, + call_type="acompletion" + ) + + # 2. call litellm.acompletion + response = await litellm.acompletion(**request_data) + + # 3. Wait for async logging operations to complete + await asyncio.sleep(5) + + # 4. 
Verify the Langfuse payload + assert mock_post.call_count >= 1 + url = mock_post.call_args[0][0] + request_body = mock_post.call_args[1].get("content") + + # Parse the JSON body + actual_payload = json.loads(request_body) + print("\nLangfuse payload:", json.dumps(actual_payload, indent=2)) + + # Look for the guardrail span in the payload + guardrail_span = None + for item in actual_payload["batch"]: + if (item["type"] == "span-create" and + item["body"].get("name") == "guardrail"): + guardrail_span = item + break + + # Assert that the guardrail span exists + assert guardrail_span is not None, "No guardrail span found in Langfuse payload" + + # Validate the structure of the guardrail span + assert guardrail_span["body"]["name"] == "guardrail" + assert "metadata" in guardrail_span["body"] + assert guardrail_span["body"]["metadata"]["guardrail_name"] == "presidio_guard" + assert guardrail_span["body"]["metadata"]["guardrail_mode"] == GuardrailEventHooks.pre_call + assert "guardrail_masked_entity_count" in guardrail_span["body"]["metadata"] + assert guardrail_span["body"]["metadata"]["guardrail_masked_entity_count"]["PHONE_NUMBER"] == 1 + + # Validate the output format matches the expected structure + assert "output" in guardrail_span["body"] + assert isinstance(guardrail_span["body"]["output"], list) + assert len(guardrail_span["body"]["output"]) > 0 + + # Validate the first output item has the expected structure + output_item = guardrail_span["body"]["output"][0] + assert "entity_type" in output_item + assert output_item["entity_type"] == "PHONE_NUMBER" + assert "score" in output_item + assert "start" in output_item + assert "end" in output_item + assert "recognition_metadata" in output_item + assert "recognizer_name" in output_item["recognition_metadata"]