diff --git a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
index 6800dff55ac..fac1a8d76b2 100644
--- a/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
+++ b/litellm/proxy/guardrails/guardrail_hooks/bedrock_guardrails.py
@@ -48,6 +48,7 @@
     BedrockContentItem,
     BedrockGuardrailOutput,
     BedrockGuardrailResponse,
+    BedrockGuardrailUsage,
     BedrockRequest,
     BedrockTextContent,
 )
@@ -122,6 +123,16 @@ def _redact_pii_matches(response_json: dict) -> dict:
 
 
 class BedrockGuardrail(CustomGuardrail, BaseAWSLLM):
+    BEDROCK_GUARDRAIL_MAX_CHARS = 25_000
+
+    @staticmethod
+    def _extract_content_text(item: BedrockContentItem) -> str:
+        text_obj = item.get("text")
+        if not isinstance(text_obj, dict):
+            return ""
+        text = text_obj.get("text") or ""
+        return text if isinstance(text, str) else ""
+
     def __init__(
         self,
         guardrailIdentifier: Optional[str] = None,
@@ -406,34 +417,152 @@ def _prepare_request(
 
         return prepped_request
 
-    async def make_bedrock_api_request(
+    @staticmethod
+    def _chunk_content_items(
+        content: List[BedrockContentItem],
+        max_chars: int = BEDROCK_GUARDRAIL_MAX_CHARS,
+    ) -> List[List[BedrockContentItem]]:
+        """Split content items into chunks that stay within the character limit.
+
+        Each chunk contains one or more content items whose combined text length
+        does not exceed *max_chars*. A text item that does not fit in the space
+        left in the current chunk is split across chunks at a character boundary.
+        """
+
+        if max_chars <= 0:
+            raise ValueError("max_chars must be a positive integer")
+
+        chunks: List[List[BedrockContentItem]] = []
+        current_chunk: List[BedrockContentItem] = []
+        current_chars = 0
+
+        for item in content:
+            text = BedrockGuardrail._extract_content_text(item)
+            if text == "":
+                current_chunk.append(item)
+                continue
+            text_len = len(text)
+
+            # If the single item fits in the current chunk, add it
+            if current_chars + text_len <= max_chars:
+                current_chunk.append(item)
+                current_chars += text_len
+                continue
+
+            # Item doesn't fit — slice it across chunk boundaries
+            offset = 0
+            while offset < text_len:
+                remaining_space = max_chars - current_chars
+                if remaining_space <= 0:
+                    chunks.append(current_chunk)
+                    current_chunk = []
+                    current_chars = 0
+                    remaining_space = max_chars
+
+                slice_end = min(offset + remaining_space, text_len)
+                chunk_text = text[offset:slice_end]
+                current_chunk.append(
+                    BedrockContentItem(text=BedrockTextContent(text=chunk_text))
+                )
+                current_chars += len(chunk_text)
+                offset = slice_end
+
+                if current_chars >= max_chars:
+                    chunks.append(current_chunk)
+                    current_chunk = []
+                    current_chars = 0
+
+        if current_chunk:
+            chunks.append(current_chunk)
+
+        return chunks if chunks else [content]
+
+    @staticmethod
+    def _merge_guardrail_responses(
+        responses: List[Tuple[BedrockGuardrailResponse, dict]],
+    ) -> Tuple[BedrockGuardrailResponse, dict]:
+        """Merge multiple guardrail responses — worst action wins."""
+        if len(responses) == 1:
+            return responses[0]
+
+        merged_response = BedrockGuardrailResponse()
+        all_assessments: list = []
+        all_outputs: list = []
+        merged_usage: dict = {}
+
+        worst_action = "NONE"
+        action_priority = {"NONE": 0, "GUARDRAIL_INTERVENED": 1}
+
+        for resp, _json_resp in responses:
+            action = resp.get("action", "NONE")
+            if action_priority.get(action, 0) > action_priority.get(worst_action, 0):
+                worst_action = action
+
+            assessments = resp.get("assessments") or []
+            all_assessments.extend(assessments)
+
+            outputs = resp.get("outputs")
+            if outputs is None:
+                outputs = resp.get("output")
+            if isinstance(outputs, list):
+                all_outputs.extend(outputs)
+            elif isinstance(outputs, dict):
+                all_outputs.append(outputs)
+            elif isinstance(outputs, str):
+                all_outputs.append({"text": outputs})
+
+            usage = resp.get("usage") or {}
+            for key, val in usage.items():
+                if isinstance(val, (int, float)):
+                    merged_usage[key] = merged_usage.get(key, 0) + val
+
+        merged_response["action"] = worst_action
+        if all_assessments:
+            merged_response["assessments"] = all_assessments
+        if all_outputs:
+            merged_response["outputs"] = all_outputs
+        if merged_usage:
+            # Ensure values are int for BedrockGuardrailUsage
+            merged_response["usage"] = BedrockGuardrailUsage(
+                **{k: int(v) for k, v in merged_usage.items()}
+            )
+
+        merged_json = dict(merged_response)
+
+        # Propagate AWS exception markers from individual chunk responses so
+        # that _determine_guardrail_status_from_json can detect them.
+        for _resp, _json_resp in responses:
+            for key in ("Output", "output"):
+                output_payload = _json_resp.get(key)
+                if isinstance(output_payload, dict):
+                    output_type = output_payload.get("__type", "")
+                    if isinstance(output_type, str) and "Exception" in output_type:
+                        merged_json[key] = output_payload
+                        break
+            else:
+                continue
+            break
+
+        return merged_response, merged_json
+
+    async def _make_single_bedrock_api_request(
         self,
+        bedrock_request_data: dict,
+        credentials: Any,
+        aws_region_name: str,
+        api_key: Optional[str],
         source: Literal["INPUT", "OUTPUT"],
-        messages: Optional[List[AllMessageValues]] = None,
-        response: Optional[Union[Any, litellm.ModelResponse]] = None,
-        request_data: Optional[dict] = None,
-    ) -> BedrockGuardrailResponse:
+        request_data: Optional[dict],
+        start_time: Any,
+    ) -> Tuple[BedrockGuardrailResponse, dict]:
+        """Execute a single ApplyGuardrail API call and return parsed response + raw JSON."""
         from datetime import datetime
 
-        start_time = datetime.now()
-        credentials, aws_region_name = self._load_credentials()
-        bedrock_request_data: dict = dict(
-            self.convert_to_bedrock_format(
-                source=source, messages=messages, response=response
-            )
-        )
-        bedrock_guardrail_response: BedrockGuardrailResponse = (
-            BedrockGuardrailResponse()
+        event_type = (
+            GuardrailEventHooks.pre_call
+            if source == "INPUT"
+            else GuardrailEventHooks.post_call
         )
-        api_key: Optional[str] = None
-        if request_data:
-            bedrock_request_data.update(
-                self.get_guardrail_dynamic_request_body_params(
-                    request_data=request_data
-                )
-            )
-            if request_data.get("api_key") is not None:
-                api_key = request_data["api_key"]
 
         prepared_request = self._prepare_request(
             credentials=credentials,
@@ -442,17 +571,12 @@ async def make_bedrock_api_request(
             aws_region_name=aws_region_name,
             api_key=api_key,
         )
+
         verbose_proxy_logger.debug(
-            "Bedrock AI request body: %s, url %s, headers: %s",
+            "Bedrock guardrail request: body=%s, url=%s, headers=%s",
             bedrock_request_data,
             prepared_request.url,
-            prepared_request.headers,
-        )
-
-        event_type = (
-            GuardrailEventHooks.pre_call
-            if source == "INPUT"
-            else GuardrailEventHooks.post_call
+            {k: v for k, v in prepared_request.headers.items() if k.lower() not in ("authorization", "x-amz-security-token")},
         )
 
         try:
@@ -462,11 +586,8 @@ async def make_bedrock_api_request(
                 headers=prepared_request.headers,  # type: ignore
             )
         except HTTPException:
-            # Propagate HTTPException (e.g. from non-200 path) as-is
             raise
         except Exception as e:
-            # If this is an HTTP error with a response body (e.g. httpx.HTTPStatusError),
-            # extract the AWS error message and propagate it
             response = getattr(e, "response", None)
             if isinstance(response, httpx.Response):
                 try:
@@ -488,7 +609,6 @@ async def make_bedrock_api_request(
                     ) from e
                 except HTTPException:
                     raise
-            # Endpoint down, timeout, or other HTTP/network errors
             verbose_proxy_logger.error(
                 "Bedrock AI: failed to make guardrail request: %s", str(e)
             )
@@ -504,46 +624,172 @@ async def make_bedrock_api_request(
             )
             raise
 
-        #########################################################
-        # Add guardrail information to request trace
-        #########################################################
-        self.add_standard_logging_guardrail_information_to_request_data(
-            guardrail_provider=self.guardrail_provider,
-            guardrail_json_response=httpx_response.json(),
-            request_data=request_data or {},
-            guardrail_status=self._get_bedrock_guardrail_response_status(
-                response=httpx_response
-            ),
-            start_time=start_time.timestamp(),
-            end_time=datetime.now().timestamp(),
-            duration=(datetime.now() - start_time).total_seconds(),
-            event_type=event_type,
-        )
-        #########################################################
         if httpx_response.status_code == 200:
-            # check if the response was flagged
-            _json_response = httpx_response.json()
-            redacted_response = _redact_pii_matches(_json_response)
-            verbose_proxy_logger.debug("Bedrock AI response : %s", redacted_response)
-            bedrock_guardrail_response = BedrockGuardrailResponse(**_json_response)
-            if self._should_raise_guardrail_blocked_exception(
-                bedrock_guardrail_response
-            ):
-                raise self._get_http_exception_for_blocked_guardrail(
-                    bedrock_guardrail_response
-                )
+            json_response = httpx_response.json()
+            guardrail_response = BedrockGuardrailResponse(**json_response)
+            return guardrail_response, json_response
         else:
-            status_code, detail_message = self._parse_bedrock_guardrail_error_response(
-                httpx_response
+            status_code, detail_message = (
+                self._parse_bedrock_guardrail_error_response(httpx_response)
             )
             verbose_proxy_logger.error(
                 "Bedrock AI: error in response. Status code: %s, response: %s",
                 httpx_response.status_code,
                 httpx_response.text,
             )
+            self.add_standard_logging_guardrail_information_to_request_data(
+                guardrail_provider=self.guardrail_provider,
+                guardrail_json_response={"error": detail_message},
+                request_data=request_data or {},
+                guardrail_status="guardrail_failed_to_respond",
+                start_time=start_time.timestamp(),
+                end_time=datetime.now().timestamp(),
+                duration=(datetime.now() - start_time).total_seconds(),
+                event_type=event_type,
+            )
             raise HTTPException(status_code=status_code, detail=detail_message)
 
-        return bedrock_guardrail_response
+    async def make_bedrock_api_request(
+        self,
+        source: Literal["INPUT", "OUTPUT"],
+        messages: Optional[List[AllMessageValues]] = None,
+        response: Optional[Union[Any, litellm.ModelResponse]] = None,
+        request_data: Optional[dict] = None,
+    ) -> BedrockGuardrailResponse:
+        from datetime import datetime
+
+        start_time = datetime.now()
+        credentials, aws_region_name = self._load_credentials()
+        bedrock_request_data: dict = dict(
+            self.convert_to_bedrock_format(
+                source=source, messages=messages, response=response
+            )
+        )
+        api_key: Optional[str] = None
+        if request_data:
+            bedrock_request_data.update(
+                self.get_guardrail_dynamic_request_body_params(
+                    request_data=request_data
+                )
+            )
+            if request_data.get("api_key") is not None:
+                api_key = request_data["api_key"]
+
+        # Compute endpoint URL for top-level debug logging
+        _aws_bedrock_endpoint = self.optional_params.get(
+            "aws_bedrock_runtime_endpoint", None
+        )
+        _, _endpoint_base = self.get_runtime_endpoint(
+            api_base=None,
+            aws_bedrock_runtime_endpoint=_aws_bedrock_endpoint,
+            aws_region_name=aws_region_name,
+        )
+        guardrail_url = f"{_endpoint_base}/guardrail/{self.guardrailIdentifier}/version/{self.guardrailVersion}/apply"
+
+        verbose_proxy_logger.debug(
+            "Bedrock AI request: body=%s, url=%s, headers=%s, guardrail=%s/%s, region=%s",
+            bedrock_request_data,
+            guardrail_url,
+            {"Content-Type": "application/json"},
+            self.guardrailIdentifier,
+            self.guardrailVersion,
+            aws_region_name,
+        )
+
+        event_type = (
+            GuardrailEventHooks.pre_call
+            if source == "INPUT"
+            else GuardrailEventHooks.post_call
+        )
+
+        # --- chunking logic ---
+        content = bedrock_request_data.get("content", [])
+        total_chars = sum(len(self._extract_content_text(item)) for item in content)
+
+        if total_chars > self.BEDROCK_GUARDRAIL_MAX_CHARS:
+            chunks = self._chunk_content_items(
+                content, max_chars=self.BEDROCK_GUARDRAIL_MAX_CHARS
+            )
+            responses: List[Tuple[BedrockGuardrailResponse, dict]] = []
+            for chunk in chunks:
+                chunk_request = dict(bedrock_request_data)
+                chunk_request["content"] = chunk
+                resp, json_resp = await self._make_single_bedrock_api_request(
+                    bedrock_request_data=chunk_request,
+                    credentials=credentials,
+                    aws_region_name=aws_region_name,
+                    api_key=api_key,
+                    source=source,
+                    request_data=request_data,
+                    start_time=start_time,
+                )
+                responses.append((resp, json_resp))
+                # Short-circuit: if any chunk is blocked, stop processing
+                if self._should_raise_guardrail_blocked_exception(resp):
+                    break
+
+            bedrock_guardrail_response, merged_json = (
+                self._merge_guardrail_responses(responses)
+            )
+
+            redacted_response = _redact_pii_matches(merged_json)
+            verbose_proxy_logger.debug("Bedrock AI response : %s", redacted_response)
+
+            self.add_standard_logging_guardrail_information_to_request_data(
+                guardrail_provider=self.guardrail_provider,
+                guardrail_json_response=merged_json,
+                request_data=request_data or {},
+                guardrail_status=self._determine_guardrail_status_from_json(
+                    json_response=merged_json,
+                    guardrail_response=bedrock_guardrail_response,
+                ),
+                start_time=start_time.timestamp(),
+                end_time=datetime.now().timestamp(),
+                duration=(datetime.now() - start_time).total_seconds(),
+                event_type=event_type,
+            )
+
+            if self._should_raise_guardrail_blocked_exception(
+                bedrock_guardrail_response
+            ):
+                raise self._get_http_exception_for_blocked_guardrail(
+                    bedrock_guardrail_response
+                )
+
+            return bedrock_guardrail_response
+
+        # --- single request path (content within limit) ---
+        resp, _json_response = await self._make_single_bedrock_api_request(
+            bedrock_request_data=bedrock_request_data,
+            credentials=credentials,
+            aws_region_name=aws_region_name,
+            api_key=api_key,
+            source=source,
+            request_data=request_data,
+            start_time=start_time,
+        )
+
+        redacted_response = _redact_pii_matches(_json_response)
+        verbose_proxy_logger.debug("Bedrock AI response : %s", redacted_response)
+
+        self.add_standard_logging_guardrail_information_to_request_data(
+            guardrail_provider=self.guardrail_provider,
+            guardrail_json_response=_json_response,
+            request_data=request_data or {},
+            guardrail_status=self._determine_guardrail_status_from_json(
+                json_response=_json_response,
+                guardrail_response=resp,
+            ),
+            start_time=start_time.timestamp(),
+            end_time=datetime.now().timestamp(),
+            duration=(datetime.now() - start_time).total_seconds(),
+            event_type=event_type,
+        )
+
+        if self._should_raise_guardrail_blocked_exception(resp):
+            raise self._get_http_exception_for_blocked_guardrail(resp)
+
+        return resp
 
     def _check_bedrock_response_for_exception(self, response) -> bool:
         """
@@ -578,34 +824,26 @@ def _check_bedrock_response_for_exception(self, response) -> bool:
 
         return "Exception" in payload.get("Output", {}).get("__type", "")
 
-    def _get_bedrock_guardrail_response_status(
-        self, response: httpx.Response
+    def _determine_guardrail_status_from_json(
+        self,
+        json_response: dict,
+        guardrail_response: BedrockGuardrailResponse,
     ) -> GuardrailStatus:
-        """
-        Get the status of the bedrock guardrail response.
+        """Determine guardrail status from already-parsed response data.
 
-        Returns:
-            "success": Content allowed through with no violations
-            "guardrail_intervened": Content blocked due to policy violations
-            "guardrail_failed_to_respond": Technical error or API failure
+        Checks for the AWS exception marker in the JSON body before falling
+        back to action-based classification.
         """
-        if response.status_code == 200:
-            if self._check_bedrock_response_for_exception(response):
+        output_payload = json_response.get("Output")
+        if output_payload is None:
+            output_payload = json_response.get("output")
+        if isinstance(output_payload, dict):
+            output_type = output_payload.get("__type", "")
+            if isinstance(output_type, str) and "Exception" in output_type:
                 return "guardrail_failed_to_respond"
-
-            # Check if the guardrail would block content
-            try:
-                _json_response = response.json()
-                bedrock_guardrail_response = BedrockGuardrailResponse(**_json_response)
-                if self._should_raise_guardrail_blocked_exception(
-                    bedrock_guardrail_response
-                ):
-                    return "guardrail_intervened"
-            except Exception:
-                pass
-
-            return "success"
-        return "guardrail_failed_to_respond"
+        if self._should_raise_guardrail_blocked_exception(guardrail_response):
+            return "guardrail_intervened"
+        return "success"
 
     def _parse_bedrock_guardrail_error_response(
         self, response: httpx.Response
@@ -1459,4 +1697,6 @@ async def apply_guardrail(
             verbose_proxy_logger.error(
                 "Bedrock Guardrail: Failed to apply guardrail: %s", str(e)
             )
-            raise Exception(f"Bedrock guardrail failed: {str(e)}")
+            raise HTTPException(
+                status_code=500, detail=f"Bedrock guardrail failed: {str(e)}"
+            ) from e
diff --git a/tests/litellm/proxy/guardrails/test_bedrock_guardrail_chunking.py b/tests/litellm/proxy/guardrails/test_bedrock_guardrail_chunking.py
new file mode 100644
index 00000000000..5c4bc41c04e
--- /dev/null
+++ b/tests/litellm/proxy/guardrails/test_bedrock_guardrail_chunking.py
@@ -0,0 +1,310 @@
+"""Tests for Bedrock guardrail chunking/batching logic."""
+
+from typing import List
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+from fastapi import HTTPException
+
+from litellm.types.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+    BedrockContentItem,
+    BedrockGuardrailResponse,
+    BedrockGuardrailUsage,
+    BedrockTextContent,
+)
+
+
+def _make_item(text: str) -> BedrockContentItem:
+    return BedrockContentItem(text=BedrockTextContent(text=text))
+
+
+def _items_text(items: List[BedrockContentItem]) -> str:
+    def _item_text(item: BedrockContentItem) -> str:
+        text_obj = item.get("text")
+        if not isinstance(text_obj, dict):
+            return ""
+        text = text_obj.get("text", "")
+        return text if isinstance(text, str) else ""
+
+    return "".join(_item_text(item) for item in items)
+
+
+# ---------------------------------------------------------------------------
+# _chunk_content_items
+# ---------------------------------------------------------------------------
+
+
+class TestChunkContentItems:
+    @staticmethod
+    def _chunk(content, max_chars=25_000):
+        from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+            BedrockGuardrail,
+        )
+        return BedrockGuardrail._chunk_content_items(content, max_chars=max_chars)
+
+    def test_under_limit_returns_single_chunk(self):
+        items = [_make_item("hello")]
+        chunks = self._chunk(items, max_chars=100)
+        assert len(chunks) == 1
+        assert chunks[0] == items
+
+    def test_exact_limit_returns_single_chunk(self):
+        items = [_make_item("a" * 100)]
+        chunks = self._chunk(items, max_chars=100)
+        assert len(chunks) == 1
+
+    def test_multiple_items_split_across_chunks(self):
+        items = [_make_item("a" * 60), _make_item("b" * 60)]
+        chunks = self._chunk(items, max_chars=100)
+        assert len(chunks) == 2
+        # First chunk is filled to the limit (60 a's + 40 b's); the rest spills over
+        combined = "".join(_items_text(c) for c in chunks)
+        assert combined == "a" * 60 + "b" * 60
+
+    def test_single_large_item_split_mid_text(self):
+        text = "x" * 250
+        items = [_make_item(text)]
+        chunks = self._chunk(items, max_chars=100)
+        assert len(chunks) == 3
+        # All text must be preserved
+        combined = "".join(_items_text(c) for c in chunks)
+        assert combined == text
+
+    def test_empty_content(self):
+        chunks = self._chunk([], max_chars=100)
+        assert len(chunks) == 1
+        assert chunks[0] == []
+
+    def test_items_without_text_key_kept(self):
+        items: List[BedrockContentItem] = [BedrockContentItem()]  # no text key
+        chunks = self._chunk(items, max_chars=100)
+        assert len(chunks) == 1
+        assert len(chunks[0]) == 1
+
+    def test_items_with_none_text_value_are_supported(self):
+        items: List[BedrockContentItem] = [{"text": None}]  # type: ignore[list-item]
+        chunks = self._chunk(items, max_chars=100)
+        assert len(chunks) == 1
+        assert chunks[0] == items
+
+    def test_mixed_small_and_large(self):
+        items = [_make_item("a" * 40), _make_item("b" * 80), _make_item("c" * 20)]
+        chunks = self._chunk(items, max_chars=100)
+        combined = "".join(_items_text(c) for c in chunks)
+        assert combined == "a" * 40 + "b" * 80 + "c" * 20
+
+    def test_default_max_chars(self):
+        """Verify that the default max_chars is 25_000."""
+        from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+            BedrockGuardrail,
+        )
+        items = [_make_item("a" * 10)]
+        chunks = BedrockGuardrail._chunk_content_items(items)
+        assert len(chunks) == 1
+
+    def test_zero_max_chars_raises(self):
+        """max_chars=0 must raise ValueError to prevent infinite loop."""
+        with pytest.raises(ValueError):
+            self._chunk([_make_item("a")], max_chars=0)
+
+    def test_negative_max_chars_raises(self):
+        """max_chars < 0 must raise ValueError to prevent infinite loop."""
+        with pytest.raises(ValueError):
+            self._chunk([_make_item("a")], max_chars=-1)
+
+
+# ---------------------------------------------------------------------------
+# _merge_guardrail_responses
+# ---------------------------------------------------------------------------
+
+
+class TestMergeGuardrailResponses:
+    @staticmethod
+    def _merge(responses):
+        from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+            BedrockGuardrail,
+        )
+        return BedrockGuardrail._merge_guardrail_responses(responses)
+
+    def test_single_response_passthrough(self):
+        resp = BedrockGuardrailResponse(action="NONE")
+        json_resp = dict(resp)
+        merged, merged_json = self._merge([(resp, json_resp)])
+        assert merged is resp
+
+    def test_worst_action_wins(self):
+        r1 = BedrockGuardrailResponse(action="NONE")
+        r2 = BedrockGuardrailResponse(action="GUARDRAIL_INTERVENED")
+        merged, _ = self._merge([(r1, dict(r1)), (r2, dict(r2))])
+        assert merged["action"] == "GUARDRAIL_INTERVENED"
+
+    def test_usage_summed(self):
+        r1 = BedrockGuardrailResponse(
+            action="NONE",
+            usage=BedrockGuardrailUsage(topicPolicyUnits=5, contentPolicyUnits=3),
+        )
+        r2 = BedrockGuardrailResponse(
+            action="NONE",
+            usage=BedrockGuardrailUsage(topicPolicyUnits=10, contentPolicyUnits=7),
+        )
+        merged, _ = self._merge([(r1, dict(r1)), (r2, dict(r2))])
+        assert merged["usage"]["topicPolicyUnits"] == 15
+        assert merged["usage"]["contentPolicyUnits"] == 10
+
+    def test_assessments_concatenated(self):
+        r1 = BedrockGuardrailResponse(action="NONE", assessments=[{"a": 1}])
+        r2 = BedrockGuardrailResponse(action="NONE", assessments=[{"b": 2}])
+        merged, _ = self._merge([(r1, dict(r1)), (r2, dict(r2))])
+        assert len(merged["assessments"]) == 2
+
+    def test_outputs_concatenated(self):
+        r1 = BedrockGuardrailResponse(action="NONE", outputs=[{"text": "x"}])
+        r2 = BedrockGuardrailResponse(action="NONE", outputs=[{"text": "y"}])
+        merged, _ = self._merge([(r1, dict(r1)), (r2, dict(r2))])
+        assert len(merged["outputs"]) == 2
+
+    def test_output_dict_is_normalized_to_single_output_entry(self):
+        r1 = BedrockGuardrailResponse(action="NONE")
+        r1["output"] = {"text": "x"}  # type: ignore[typeddict-item]
+        r2 = BedrockGuardrailResponse(action="NONE", outputs=[{"text": "y"}])
+        merged, _ = self._merge([(r1, dict(r1)), (r2, dict(r2))])
+        assert merged["outputs"] == [{"text": "x"}, {"text": "y"}]
+
+    def test_exception_marker_propagated_to_merged_json(self):
+        """AWS exception markers in chunk responses must survive merge."""
+        r1 = BedrockGuardrailResponse(action="NONE")
+        j1 = dict(r1)
+        r2 = BedrockGuardrailResponse(action="NONE")
+        j2 = dict(r2)
+        j2["Output"] = {"__type": "SomeException", "message": "error"}
+        _, merged_json = self._merge([(r1, j1), (r2, j2)])
+        assert "Output" in merged_json
+        assert "Exception" in merged_json["Output"]["__type"]
+
+    def test_exception_marker_lowercase_output_propagated(self):
+        """Lowercase 'output' exception markers must also be propagated."""
+        r1 = BedrockGuardrailResponse(action="NONE")
+        j1 = dict(r1)
+        j1["output"] = {"__type": "ThrottlingException"}
+        r2 = BedrockGuardrailResponse(action="NONE")
+        j2 = dict(r2)
+        _, merged_json = self._merge([(r1, j1), (r2, j2)])
+        assert "output" in merged_json
+        assert "Exception" in merged_json["output"]["__type"]
+
+
+# ---------------------------------------------------------------------------
+# make_bedrock_api_request — integration with chunking
+# ---------------------------------------------------------------------------
+
+
+class TestMakeBedrockApiRequestChunking:
+    """Verify that make_bedrock_api_request chunks large content."""
+
+    @pytest.fixture()
+    def guardrail(self):
+        with patch(
+            "litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails.get_async_httpx_client"
+        ):
+            from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+                BedrockGuardrail,
+            )
+            g = BedrockGuardrail(
+                guardrailIdentifier="test-id",
+                guardrailVersion="1",
+            )
+        g._load_credentials = MagicMock(return_value=(MagicMock(), "us-east-1"))
+        return g
+
+    @pytest.mark.asyncio
+    async def test_small_content_no_chunking(self, guardrail):
+        """Content under limit should call _make_single_bedrock_api_request once."""
+        small_content = [_make_item("a" * 100)]
+        guardrail.convert_to_bedrock_format = MagicMock(
+            return_value={"source": "INPUT", "content": small_content}
+        )
+        guardrail.get_guardrail_dynamic_request_body_params = MagicMock(
+            return_value={}
+        )
+
+        ok_response = BedrockGuardrailResponse(action="NONE")
+        guardrail._make_single_bedrock_api_request = AsyncMock(
+            return_value=(ok_response, dict(ok_response))
+        )
+
+        result = await guardrail.make_bedrock_api_request(source="INPUT")
+        assert guardrail._make_single_bedrock_api_request.call_count == 1
+        assert result.get("action") == "NONE"
+
+    @pytest.mark.asyncio
+    async def test_large_content_triggers_chunking(self, guardrail):
+        """Content over 25k chars should be chunked into multiple calls."""
+        large_content = [_make_item("a" * 30_000)]
+        guardrail.convert_to_bedrock_format = MagicMock(
+            return_value={"source": "INPUT", "content": large_content}
+        )
+        guardrail.get_guardrail_dynamic_request_body_params = MagicMock(
+            return_value={}
+        )
+
+        ok_response = BedrockGuardrailResponse(
+            action="NONE",
+            usage=BedrockGuardrailUsage(topicPolicyUnits=1),
+        )
+        guardrail._make_single_bedrock_api_request = AsyncMock(
+            return_value=(ok_response, dict(ok_response))
+        )
+
+        result = await guardrail.make_bedrock_api_request(source="INPUT")
+        assert guardrail._make_single_bedrock_api_request.call_count == 2
+        assert result.get("usage", {}).get("topicPolicyUnits") == 2
+
+    @pytest.mark.asyncio
+    async def test_chunking_short_circuits_on_block(self, guardrail):
+        """If any chunk is blocked, remaining chunks should be skipped."""
+        large_content = [_make_item("a" * 60_000)]
+        guardrail.convert_to_bedrock_format = MagicMock(
+            return_value={"source": "INPUT", "content": large_content}
+        )
+        guardrail.get_guardrail_dynamic_request_body_params = MagicMock(
+            return_value={}
+        )
+
+        blocked_response = BedrockGuardrailResponse(
+            action="GUARDRAIL_INTERVENED",
+            assessments=[
+                {"topicPolicy": {"topics": [{"action": "BLOCKED"}]}}
+            ],
+            outputs=[{"text": "blocked"}],
+        )
+        guardrail._make_single_bedrock_api_request = AsyncMock(
+            return_value=(blocked_response, dict(blocked_response))
+        )
+
+        with pytest.raises(HTTPException) as exc_info:
+            await guardrail.make_bedrock_api_request(source="INPUT")
+        assert exc_info.value.status_code == 400
+
+        # Should have stopped after first chunk's BLOCKED result
+        assert guardrail._make_single_bedrock_api_request.call_count == 1
+
+
+class TestGuardrailStatusDetection:
+    def test_lowercase_output_exception_status_is_failed(self):
+        with patch(
+            "litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails.get_async_httpx_client"
+        ):
+            from litellm.proxy.guardrails.guardrail_hooks.bedrock_guardrails import (
+                BedrockGuardrail,
+            )
+
+            guardrail = BedrockGuardrail(
+                guardrailIdentifier="test-id",
+                guardrailVersion="1",
+            )
+
+        status = guardrail._determine_guardrail_status_from_json(
+            json_response={"output": {"__type": "SomeException"}},
+            guardrail_response=BedrockGuardrailResponse(action="NONE"),
+        )
+        assert status == "guardrail_failed_to_respond"