Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions litellm/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def __init__(
response: Optional[httpx.Response] = None,
litellm_debug_info: Optional[str] = None,
provider_specific_fields: Optional[dict] = None,
body: Optional[dict] = None,
):
self.status_code = 400
self.message = "litellm.ContentPolicyViolationError: {}".format(message)
Expand All @@ -466,6 +467,7 @@ def __init__(
llm_provider=self.llm_provider, # type: ignore
response=response,
litellm_debug_info=self.litellm_debug_info,
body=body,
) # Call the base class constructor with the parameters it needs

def __str__(self):
Expand Down
26 changes: 24 additions & 2 deletions litellm/litellm_core_utils/exception_mapping_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,14 @@ def get_error_message(error_obj) -> Optional[str]:
if hasattr(error_obj, "body"):
_error_obj_body = getattr(error_obj, "body")
if isinstance(_error_obj_body, dict):
return _error_obj_body.get("message")
# OpenAI-style: {"message": "...", "type": "...", ...}
if _error_obj_body.get("message"):
return _error_obj_body.get("message")

# Azure-style: {"error": {"message": "...", ...}}
nested_error = _error_obj_body.get("error")
if isinstance(nested_error, dict):
return nested_error.get("message")

# If all else fails, return None
return None
Expand Down Expand Up @@ -2044,6 +2051,20 @@ def exception_type( # type: ignore # noqa: PLR0915
else:
message = str(original_exception)

# Azure OpenAI (especially Images) often nests error details under
# body["error"]. Detect content policy violations using the structured
# payload in addition to string matching.
azure_error_code: Optional[str] = None
try:
body_dict = getattr(original_exception, "body", None) or {}
if isinstance(body_dict, dict):
if isinstance(body_dict.get("error"), dict):
azure_error_code = body_dict["error"].get("code") # type: ignore[index]
else:
azure_error_code = body_dict.get("code")
except Exception:
azure_error_code = None

if "Internal server error" in error_str:
exception_mapping_worked = True
raise litellm.InternalServerError(
Expand Down Expand Up @@ -2072,7 +2093,8 @@ def exception_type( # type: ignore # noqa: PLR0915
response=getattr(original_exception, "response", None),
)
elif (
ExceptionCheckers.is_azure_content_policy_violation_error(error_str)
azure_error_code == "content_policy_violation"
or ExceptionCheckers.is_azure_content_policy_violation_error(error_str)
):
exception_mapping_worked = True
from litellm.llms.azure.exception_mapping import (
Expand Down
75 changes: 62 additions & 13 deletions litellm/llms/azure/exception_mapping.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Any, Dict, Optional, Tuple

from litellm.exceptions import ContentPolicyViolationError

Expand All @@ -18,27 +18,76 @@ def create_content_policy_violation_error(
"""
Create a content policy violation error
"""
azure_error, inner_error = AzureOpenAIExceptionMapping._extract_azure_error(
original_exception
)

# Prefer the provider message/type/code when present.
provider_message = (
azure_error.get("message")
if isinstance(azure_error, dict)
else None
) or message
provider_type = (
azure_error.get("type") if isinstance(azure_error, dict) else None
)
provider_code = (
azure_error.get("code") if isinstance(azure_error, dict) else None
)

# Keep the OpenAI-style body fields populated so downstream (proxy + SDK)
# can surface `type` / `code` correctly.
openai_style_body: Dict[str, Any] = {
"message": provider_message,
"type": provider_type or "invalid_request_error",
"code": provider_code or "content_policy_violation",
"param": None,
}

raise ContentPolicyViolationError(
message=f"AzureException - {message}",
message=provider_message,
llm_provider="azure",
model=model,
litellm_debug_info=extra_information,
response=getattr(original_exception, "response", None),
provider_specific_fields={
"innererror": AzureOpenAIExceptionMapping._get_innererror_from_exception(
original_exception
)
# Preserve legacy key for backward compatibility.
"innererror": inner_error,
# Prefer Azure's current naming.
"inner_error": inner_error,
# Include the full Azure error object for clients that want it.
"azure_error": azure_error or None,
},
body=openai_style_body,
)

@staticmethod
def _get_innererror_from_exception(original_exception: Exception) -> Optional[dict]:
"""
Azure OpenAI returns the innererror in the body of the exception
This method extracts the innererror from the exception
def _extract_azure_error(
original_exception: Exception,
) -> Tuple[Dict[str, Any], Optional[dict]]:
"""Extract Azure OpenAI error payload and inner error details.

Azure error formats can vary by endpoint/version. Common shapes:
- {"innererror": {...}} (legacy)
- {"error": {"code": "...", "message": "...", "type": "...", "inner_error": {...}}}
- {"code": "...", "message": "...", "type": "..."} (already flattened)
"""
innererror = None
body_dict = getattr(original_exception, "body", None) or {}
if isinstance(body_dict, dict):
innererror = body_dict.get("innererror")
return innererror
if not isinstance(body_dict, dict):
return {}, None

# Some SDKs place the payload under "error".
azure_error: Dict[str, Any]
if isinstance(body_dict.get("error"), dict):
azure_error = body_dict.get("error", {}) # type: ignore[assignment]
else:
azure_error = body_dict

inner_error = (
azure_error.get("inner_error")
or azure_error.get("innererror")
or body_dict.get("innererror")
or body_dict.get("inner_error")
)

return azure_error, inner_error
51 changes: 50 additions & 1 deletion tests/test_litellm/llms/azure/test_azure_exception_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,4 +190,53 @@ def test_azure_content_policy_violation_non_dict_body(self):
print("got exception=", e)
print("exception fields=", vars(e))
assert e.provider_specific_fields is not None
assert e.provider_specific_fields.get("innererror") is None
assert e.provider_specific_fields.get("innererror") is None

def test_azure_images_content_policy_violation_preserves_nested_inner_error(self):
"""Azure Images endpoints return errors nested under body['error'] with inner_error.

Ensure we:
- Detect the violation via structured payload (code=content_policy_violation)
- Preserve code/type/message
- Surface inner_error + revised_prompt + content_filter_results
"""

mock_exception = Exception("Bad request") # does not include policy substrings
mock_exception.body = {
"error": {
"code": "content_policy_violation",
"inner_error": {
"code": "ResponsibleAIPolicyViolation",
"content_filter_results": {
"violence": {"filtered": True, "severity": "low"}
},
"revised_prompt": "revised",
},
"message": "Your request was rejected as a result of our safety system.",
"type": "invalid_request_error",
}
}

mock_response = MagicMock()
mock_response.status_code = 400
mock_exception.response = mock_response

with pytest.raises(ContentPolicyViolationError) as exc_info:
exception_type(
model="azure/dall-e-3",
original_exception=mock_exception,
custom_llm_provider="azure",
)

e = exc_info.value

# OpenAI-style error fields should be populated
assert getattr(e, "code", None) == "content_policy_violation"
assert getattr(e, "type", None) == "invalid_request_error"
assert "safety system" in str(e)

# Provider-specific nested details must be preserved
assert e.provider_specific_fields is not None
assert e.provider_specific_fields["inner_error"]["code"] == "ResponsibleAIPolicyViolation"
assert e.provider_specific_fields["inner_error"]["revised_prompt"] == "revised"
assert e.provider_specific_fields["inner_error"]["content_filter_results"]["violence"]["filtered"] is True
Loading