Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 83 additions & 27 deletions litellm/integrations/custom_guardrail.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Dict, List, Literal, Optional, Union

from litellm._logging import verbose_logger
Expand Down Expand Up @@ -186,27 +187,28 @@ def _validate_premium_user(self) -> bool:

def add_standard_logging_guardrail_information_to_request_data(
self,
guardrail_json_response: Union[Exception, str, dict],
guardrail_json_response: Union[Exception, str, dict, List[dict]],
request_data: dict,
guardrail_status: Literal["success", "failure"],
start_time: Optional[float] = None,
end_time: Optional[float] = None,
duration: Optional[float] = None,
masked_entity_count: Optional[Dict[str, int]] = None,
) -> None:
"""
Builds `StandardLoggingGuardrailInformation` and adds it to the request metadata so it can be used for logging to DataDog, Langfuse, etc.
"""
from litellm.proxy.proxy_server import premium_user

if premium_user is not True:
verbose_logger.warning(
f"Guardrail Tracing is only available for premium users. Skipping guardrail logging for guardrail={self.guardrail_name} event_hook={self.event_hook}"
)
return
if isinstance(guardrail_json_response, Exception):
guardrail_json_response = str(guardrail_json_response)
slg = StandardLoggingGuardrailInformation(
guardrail_name=self.guardrail_name,
guardrail_mode=self.event_hook,
guardrail_response=guardrail_json_response,
guardrail_status=guardrail_status,
start_time=start_time,
end_time=end_time,
duration=duration,
masked_entity_count=masked_entity_count,
)
if "metadata" in request_data:
request_data["metadata"]["standard_logging_guardrail_information"] = slg
Expand Down Expand Up @@ -244,6 +246,54 @@ async def apply_guardrail(
"""
return text

def _process_response(
    self,
    response: Optional[Dict],
    request_data: dict,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    duration: Optional[float] = None,
):
    """
    Record a successful guardrail run on the request metadata.

    Builds ``StandardLoggingGuardrailInformation`` from the guardrail's
    response and attaches it to ``request_data`` so downstream logging
    integrations (Langfuse, DataDog, etc.) can surface it.

    Args:
        response: The guardrail's JSON response; ``None`` is coerced to ``{}``.
        request_data: The in-flight request dict carrying logging metadata.
        start_time: Epoch timestamp (seconds) when the guardrail started.
        end_time: Epoch timestamp (seconds) when the guardrail finished.
        duration: Guardrail runtime in seconds.

    Returns:
        The original ``response``, unchanged, so callers can return it directly.
    """
    # The logging payload type does not accept None — substitute an empty dict.
    guardrail_json = response if response is not None else {}
    self.add_standard_logging_guardrail_information_to_request_data(
        guardrail_json_response=guardrail_json,
        request_data=request_data,
        guardrail_status="success",
        start_time=start_time,
        end_time=end_time,
        duration=duration,
    )
    return response

def _process_error(
    self,
    e: Exception,
    request_data: dict,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    duration: Optional[float] = None,
):
    """
    Record a failed guardrail run on the request metadata, then re-raise.

    The exception is stored as the guardrail response inside
    ``StandardLoggingGuardrailInformation`` so downstream logging
    integrations (Langfuse, DataDog, etc.) can surface the failure.

    Args:
        e: The exception raised by the guardrail.
        request_data: The in-flight request dict carrying logging metadata.
        start_time: Epoch timestamp (seconds) when the guardrail started.
        end_time: Epoch timestamp (seconds) when the guardrail finished.
        duration: Guardrail runtime in seconds.

    Raises:
        Exception: Always re-raises ``e`` after the failure is recorded.
    """
    self.add_standard_logging_guardrail_information_to_request_data(
        guardrail_json_response=e,
        request_data=request_data,
        guardrail_status="failure",
        start_time=start_time,
        end_time=end_time,
        duration=duration,
    )
    raise e


def log_guardrail_information(func):
"""
Expand All @@ -259,21 +309,7 @@ def log_guardrail_information(func):
import asyncio
import functools

def process_response(self, response, request_data):
self.add_standard_logging_guardrail_information_to_request_data(
guardrail_json_response=response,
request_data=request_data,
guardrail_status="success",
)
return response

def process_error(self, e, request_data):
self.add_standard_logging_guardrail_information_to_request_data(
guardrail_json_response=e,
request_data=request_data,
guardrail_status="failure",
)
raise e
start_time = datetime.now()

@functools.wraps(func)
async def async_wrapper(*args, **kwargs):
Expand All @@ -283,9 +319,21 @@ async def async_wrapper(*args, **kwargs):
)
try:
response = await func(*args, **kwargs)
return process_response(self, response, request_data)
return self._process_response(
response=response,
request_data=request_data,
start_time=start_time.timestamp(),
end_time=datetime.now().timestamp(),
duration=(datetime.now() - start_time).total_seconds(),
)
except Exception as e:
return process_error(self, e, request_data)
return self._process_error(
e=e,
request_data=request_data,
start_time=start_time.timestamp(),
end_time=datetime.now().timestamp(),
duration=(datetime.now() - start_time).total_seconds(),
)

@functools.wraps(func)
def sync_wrapper(*args, **kwargs):
Expand All @@ -295,9 +343,17 @@ def sync_wrapper(*args, **kwargs):
)
try:
response = func(*args, **kwargs)
return process_response(self, response, request_data)
return self._process_response(
response=response,
request_data=request_data,
duration=(datetime.now() - start_time).total_seconds(),
)
except Exception as e:
return process_error(self, e, request_data)
return self._process_error(
e=e,
request_data=request_data,
duration=(datetime.now() - start_time).total_seconds(),
)

@functools.wraps(func)
def wrapper(*args, **kwargs):
Expand Down
57 changes: 51 additions & 6 deletions litellm/integrations/langfuse/langfuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
)

if TYPE_CHECKING:
from langfuse.client import StatefulTraceClient

from litellm.litellm_core_utils.litellm_logging import DynamicLoggingCache
else:
DynamicLoggingCache = Any
StatefulTraceClient = Any


class LangFuseLogger:
Expand Down Expand Up @@ -626,16 +629,17 @@ def _log_langfuse_v2( # noqa: PLR0915
if key.lower() not in ["authorization", "cookie", "referer"]:
clean_headers[key] = value

# clean_metadata["request"] = {
# "method": method,
# "url": url,
# "headers": clean_headers,
# }
trace = self.Langfuse.trace(**trace_params)
trace: StatefulTraceClient = self.Langfuse.trace(**trace_params)

# Log provider specific information as a span
log_provider_specific_information_as_span(trace, clean_metadata)

# Log guardrail information as a span
self._log_guardrail_information_as_span(
trace=trace,
standard_logging_object=standard_logging_object,
)

generation_id = None
usage = None
if response_obj is not None:
Expand Down Expand Up @@ -809,6 +813,47 @@ def _get_langfuse_flush_interval(flush_interval: int) -> int:
"""
return int(os.getenv("LANGFUSE_FLUSH_INTERVAL") or flush_interval)

def _log_guardrail_information_as_span(
    self,
    trace: StatefulTraceClient,
    standard_logging_object: Optional[StandardLoggingPayload],
):
    """
    Emit a "guardrail" span on the Langfuse trace when guardrail data exists.

    Reads ``guardrail_information`` from the standard logging payload and
    records the guardrail request/response along with its name, mode, and
    masked-entity counts. Silently no-ops (with a debug log) when either the
    payload or the guardrail information is missing.
    """
    # Guard clause: nothing to log without a standard logging payload.
    if standard_logging_object is None:
        verbose_logger.debug(
            "Not logging guardrail information as span because standard_logging_object is None"
        )
        return

    guardrail_information = standard_logging_object.get("guardrail_information")
    if guardrail_information is None:
        verbose_logger.debug(
            "Not logging guardrail information as span because guardrail_information is None"
        )
        return

    # Collect span metadata up front; .get() defaults to None for missing keys.
    span_metadata = {
        "guardrail_name": guardrail_information.get("guardrail_name"),
        "guardrail_mode": guardrail_information.get("guardrail_mode"),
        "guardrail_masked_entity_count": guardrail_information.get(
            "masked_entity_count"
        ),
    }
    span = trace.span(
        name="guardrail",
        input=guardrail_information.get("guardrail_request"),
        output=guardrail_information.get("guardrail_response"),
        metadata=span_metadata,
        start_time=guardrail_information.get("start_time"),  # type: ignore
        end_time=guardrail_information.get("end_time"),  # type: ignore
    )

    verbose_logger.debug(f"Logged guardrail information as span: {span}")
    span.end()


def _add_prompt_to_generation_params(
generation_params: dict,
Expand Down
51 changes: 40 additions & 11 deletions litellm/proxy/guardrails/guardrail_hooks/presidio.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import asyncio
import json
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, cast

import aiohttp

Expand All @@ -20,10 +21,7 @@
from litellm._logging import verbose_proxy_logger
from litellm.caching.caching import DualCache
from litellm.exceptions import BlockedPiiEntityError
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.guardrails import (
GuardrailEventHooks,
Expand Down Expand Up @@ -218,7 +216,11 @@ async def analyze_text(
raise e

async def anonymize_text(
self, text: str, analyze_results: Any, output_parse_pii: bool
self,
text: str,
analyze_results: Any,
output_parse_pii: bool,
masked_entity_count: Dict[str, int],
) -> str:
"""
Send analysis results to the Presidio anonymizer endpoint to get redacted text
Expand Down Expand Up @@ -256,6 +258,11 @@ async def anonymize_text(
] # get text it'll replace

new_text = new_text[:start] + replacement + new_text[end:]
entity_type = item.get("entity_type", None)
if entity_type is not None:
masked_entity_count[entity_type] = (
masked_entity_count.get(entity_type, 0) + 1
)
return redacted_text["text"]
else:
raise Exception(f"Invalid anonymizer response: {redacted_text}")
Expand Down Expand Up @@ -300,6 +307,11 @@ async def check_pii(
"""
Calls Presidio Analyze + Anonymize endpoints for PII Analysis + Masking
"""
start_time = datetime.now()
analyze_results: Optional[Union[List[PresidioAnalyzeResponseItem], Dict]] = None
status: Literal["success", "failure"] = "success"
masked_entity_count: Dict[str, int] = {}
exception_str: str = ""
try:
if self.mock_redacted_text is not None:
redacted_text = self.mock_redacted_text
Expand All @@ -324,13 +336,33 @@ async def check_pii(
text=text,
analyze_results=analyze_results,
output_parse_pii=output_parse_pii,
masked_entity_count=masked_entity_count,
)

return redacted_text["text"]
except Exception as e:
status = "failure"
exception_str = str(e)
raise e
finally:
####################################################
# Create Guardrail Trace for logging on Langfuse, Datadog, etc.
####################################################
guardrail_json_response: Union[Exception, str, dict, List[dict]] = {}
if status == "success":
if isinstance(analyze_results, List):
guardrail_json_response = [dict(item) for item in analyze_results]
else:
guardrail_json_response = exception_str
self.add_standard_logging_guardrail_information_to_request_data(
guardrail_json_response=guardrail_json_response,
request_data=request_data,
guardrail_status=status,
start_time=start_time.timestamp(),
end_time=datetime.now().timestamp(),
duration=(datetime.now() - start_time).total_seconds(),
masked_entity_count=masked_entity_count,
)

@log_guardrail_information
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
Expand Down Expand Up @@ -394,7 +426,6 @@ async def async_pre_call_hook(
except Exception as e:
raise e

@log_guardrail_information
def logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
Expand Down Expand Up @@ -427,7 +458,6 @@ def run_in_new_loop():
# No running event loop, we can safely run in this thread
return run_in_new_loop()

@log_guardrail_information
async def async_logging_hook(
self, kwargs: dict, result: Any, call_type: str
) -> Tuple[dict, Any]:
Expand Down Expand Up @@ -476,7 +506,6 @@ async def async_logging_hook(

return kwargs, result

@log_guardrail_information
async def async_post_call_success_hook( # type: ignore
self,
data: dict,
Expand Down
Loading
Loading