diff --git a/aws_lambda_powertools/utilities/batch/exceptions.py b/aws_lambda_powertools/utilities/batch/exceptions.py
index a3eefbb9cea..3f4075c7d2f 100644
--- a/aws_lambda_powertools/utilities/batch/exceptions.py
+++ b/aws_lambda_powertools/utilities/batch/exceptions.py
@@ -36,3 +36,19 @@ def __init__(self, msg="", child_exceptions: List[ExceptionInfo] | None = None):
     def __str__(self):
         parent_exception_str = super(BatchProcessingError, self).__str__()
         return self.format_exceptions(parent_exception_str)
+
+
+class SQSFifoCircuitBreakerError(Exception):
+    """
+    Signals a record not processed due to the SQS FIFO processing being interrupted
+    """
+
+    pass
+
+
+class SQSFifoMessageGroupCircuitBreakerError(Exception):
+    """
+    Signals a record not processed due to the SQS FIFO message group processing being interrupted
+    """
+
+    pass
diff --git a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
index d48749a137e..e54389718bc 100644
--- a/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
+++ b/aws_lambda_powertools/utilities/batch/sqs_fifo_partial_processor.py
@@ -1,15 +1,14 @@
-from typing import List, Optional, Tuple
-
-from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType
+import logging
+from typing import Optional, Set
+
+from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType, ExceptionInfo, FailureResponse
+from aws_lambda_powertools.utilities.batch.exceptions import (
+    SQSFifoCircuitBreakerError,
+    SQSFifoMessageGroupCircuitBreakerError,
+)
 from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel
 
-
-class SQSFifoCircuitBreakerError(Exception):
-    """
-    Signals a record not processed due to the SQS FIFO processing being interrupted
-    """
-
-    pass
+logger = logging.getLogger(__name__)
 
 
 class SqsFifoPartialProcessor(BatchProcessor):
@@ -57,36 +56,59 @@ def lambda_handler(event, context: LambdaContext):
         None,
     )
 
-    def __init__(self, model: Optional["BatchSqsTypeModel"] = None):
-        super().__init__(EventType.SQS, model)
+    group_circuit_breaker_exc = (
+        SQSFifoMessageGroupCircuitBreakerError,
+        SQSFifoMessageGroupCircuitBreakerError("A previous record from this message group failed processing"),
+        None,
+    )
 
-    def process(self) -> List[Tuple]:
+    def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
         """
-        Call instance's handler for each record. When the first failed message is detected,
-        the process is short-circuited, and the remaining messages are reported as failed items.
+        Initialize the SqsFifoPartialProcessor.
+
+        Parameters
+        ----------
+        model: Optional["BatchSqsTypeModel"]
+            An optional model for batch processing.
+        skip_group_on_error: bool
+            Whether to skip only the messages from a MessageGroupID that has already had a processing failure,
+            instead of short-circuiting the entire remaining batch. Default is False.
+
         """
-        result: List[Tuple] = []
+        self._skip_group_on_error: bool = skip_group_on_error
+        self._current_group_id = None
+        self._failed_group_ids: Set[str] = set()
+        super().__init__(EventType.SQS, model)
 
-        for i, record in enumerate(self.records):
-            # If we have failed messages, it means that the last message failed.
-            # We then short circuit the process, failing the remaining messages
-            if self.fail_messages:
-                return self._short_circuit_processing(i, result)
+    def _process_record(self, record):
+        self._current_group_id = record.get("attributes", {}).get("MessageGroupId")
 
-            # Otherwise, process the message normally
-            result.append(self._process_record(record))
+        # Short-circuit the record if:
+        # - there are failed messages and `skip_group_on_error` is off, OR
+        # - `skip_group_on_error` is on and the current message belongs to a failed group.
+        fail_entire_batch = bool(self.fail_messages) and not self._skip_group_on_error
+        fail_group_id = self._skip_group_on_error and self._current_group_id in self._failed_group_ids
+        if fail_entire_batch or fail_group_id:
+            return self.failure_handler(
+                record=self._to_batch_type(record, event_type=self.event_type, model=self.model),
+                exception=self.group_circuit_breaker_exc if self._skip_group_on_error else self.circuit_breaker_exc,
+            )
 
-        return result
+        return super()._process_record(record)
 
-    def _short_circuit_processing(self, first_failure_index: int, result: List[Tuple]) -> List[Tuple]:
-        """
-        Starting from the first failure index, fail all the remaining messages, and append them to the result list.
-        """
-        remaining_records = self.records[first_failure_index:]
-        for remaining_record in remaining_records:
-            data = self._to_batch_type(record=remaining_record, event_type=self.event_type, model=self.model)
-            result.append(self.failure_handler(record=data, exception=self.circuit_breaker_exc))
-        return result
+    def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse:
+        # When failing a message with `skip_group_on_error` on, store the failed group ID
+        # so that future messages with the same group ID are failed automatically.
+        if self._skip_group_on_error and self._current_group_id:
+            self._failed_group_ids.add(self._current_group_id)
+
+        return super().failure_handler(record, exception)
+
+    def _clean(self):
+        self._failed_group_ids.clear()
+        self._current_group_id = None
+
+        super()._clean()
 
     async def _async_process_record(self, record: dict):
         raise NotImplementedError()
diff --git a/docs/utilities/batch.md b/docs/utilities/batch.md
index ada05766ab4..e5241d516e8 100644
--- a/docs/utilities/batch.md
+++ b/docs/utilities/batch.md
@@ -141,8 +141,11 @@ Processing batches from SQS works in three stages:
 
 #### FIFO queues
 
-When using [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank" rel="nofollow"}, we will stop processing messages after the first failure, and return all failed and unprocessed messages in `batchItemFailures`.
-This helps preserve the ordering of messages in your queue.
+When working with [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank"}, a batch may include messages from different group IDs.
+
+By default, we stop processing at the first failure and mark all unprocessed messages as failed to preserve ordering. However, this may not be ideal when you want to keep processing messages from other group IDs.
+
+Enable the `skip_group_on_error` option to continue processing messages from other group IDs after a failure. Messages from a failed group ID are sent back to SQS, while messages from unaffected group IDs keep being processed.
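To see the control flow above outside of a diff, here is a minimal, self-contained sketch of the two short-circuit modes. The function `process_fifo_batch`, the `handler`, and the sample records are illustrative stand-ins, not part of the Powertools API:

```python
from typing import Any, Callable, Dict, List, Set


def process_fifo_batch(
    records: List[Dict[str, Any]],
    record_handler: Callable[[Dict[str, Any]], None],
    skip_group_on_error: bool = False,
) -> Dict[str, List[Dict[str, str]]]:
    """Illustrative short-circuit logic for an SQS FIFO batch."""
    failures: List[Dict[str, str]] = []
    failed_groups: Set[str] = set()

    for record in records:
        group_id = record.get("attributes", {}).get("MessageGroupId", "")

        # Default mode: any earlier failure fails every remaining record.
        fail_entire_batch = bool(failures) and not skip_group_on_error
        # skip_group_on_error mode: only records whose group already failed are skipped.
        fail_group = skip_group_on_error and group_id in failed_groups

        if fail_entire_batch or fail_group:
            failures.append({"itemIdentifier": record["messageId"]})
            continue

        try:
            record_handler(record)
        except Exception:
            failed_groups.add(group_id)
            failures.append({"itemIdentifier": record["messageId"]})

    return {"batchItemFailures": failures}


def handler(record: Dict[str, Any]) -> None:
    if record["body"] == "fail":
        raise ValueError("boom")


batch = [
    {"messageId": "1", "body": "ok", "attributes": {"MessageGroupId": "A"}},
    {"messageId": "2", "body": "fail", "attributes": {"MessageGroupId": "A"}},
    {"messageId": "3", "body": "ok", "attributes": {"MessageGroupId": "A"}},
    {"messageId": "4", "body": "ok", "attributes": {"MessageGroupId": "B"}},
]

# Fails records 2 and 3 (group A), but still processes record 4 (group B).
print(process_fifo_batch(batch, handler, skip_group_on_error=True))
```

Keeping failed group IDs in a set makes the per-record membership check constant-time, mirroring the `_failed_group_ids` set used in the diff above.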
=== "Recommended" @@ -164,6 +167,12 @@ This helps preserve the ordering of messages in your queue. --8<-- "examples/batch_processing/src/getting_started_sqs_fifo_decorator.py" ``` +=== "Enabling skip_group_on_error flag" + + ```python hl_lines="2-6 9 23" + --8<-- "examples/batch_processing/src/getting_started_sqs_fifo_skip_on_error.py" + ``` + ### Processing messages from Kinesis Processing batches from Kinesis works in three stages: @@ -311,7 +320,7 @@ sequenceDiagram > Read more about [Batch Failure Reporting feature in AWS Lambda](https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html#services-sqs-batchfailurereporting){target="_blank"}. -Sequence diagram to explain how [`SqsFifoPartialProcessor` works](#fifo-queues) with SQS FIFO queues. +Sequence diagram to explain how [`SqsFifoPartialProcessor` works](#fifo-queues) with SQS FIFO queues without `skip_group_on_error` flag.
 ```mermaid
@@ -335,6 +344,31 @@ sequenceDiagram
 SQS FIFO mechanism with Batch Item Failures
 
+Sequence diagram to explain how [`SqsFifoPartialProcessor` works](#fifo-queues) with SQS FIFO queues when the `skip_group_on_error` flag is enabled.
+
+
+```mermaid
+sequenceDiagram
+    autonumber
+    participant SQS queue
+    participant Lambda service
+    participant Lambda function
+    Lambda service->>SQS queue: Poll
+    Lambda service->>Lambda function: Invoke (batch event)
+    activate Lambda function
+    Lambda function-->Lambda function: Process 2 out of 10 batch items
+    Lambda function--xLambda function: Fail on 3rd batch item
+    Lambda function-->Lambda function: Process messages from another MessageGroupID
+    Lambda function->>Lambda service: Report 3rd batch item and all messages within the same MessageGroupID as failure
+    deactivate Lambda function
+    activate SQS queue
+    Lambda service->>SQS queue: Delete successfully processed messages
+    SQS queue-->>SQS queue: Failed messages return
+    deactivate SQS queue
+```
+SQS FIFO mechanism with Batch Item Failures and `skip_group_on_error` enabled
+
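To make the reporting step in the diagram concrete, this is the shape of the partial-failure response the processor hands back to the Lambda service; the message IDs below are placeholders. Any record absent from `batchItemFailures` counts as a success and is deleted from the queue:

```python
# Placeholder IDs, mirroring the diagram: the 3rd record raised an exception
# and the 4th shares its MessageGroupId, so both are reported back to Lambda.
response = {
    "batchItemFailures": [
        {"itemIdentifier": "message-3-id"},  # the record that failed
        {"itemIdentifier": "message-4-id"},  # skipped: same MessageGroupId
    ],
}
```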
+
 #### Kinesis and DynamoDB Streams
 
 > Read more about [Batch Failure Reporting feature](https://docs.aws.amazon.com/lambda/latest/dg/with-kinesis.html#services-kinesis-batchfailurereporting){target="_blank"}.
diff --git a/examples/batch_processing/src/getting_started_sqs_fifo_skip_on_error.py b/examples/batch_processing/src/getting_started_sqs_fifo_skip_on_error.py
new file mode 100644
index 00000000000..83015483d1f
--- /dev/null
+++ b/examples/batch_processing/src/getting_started_sqs_fifo_skip_on_error.py
@@ -0,0 +1,23 @@
+from aws_lambda_powertools import Logger, Tracer
+from aws_lambda_powertools.utilities.batch import (
+    SqsFifoPartialProcessor,
+    process_partial_response,
+)
+from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
+from aws_lambda_powertools.utilities.typing import LambdaContext
+
+processor = SqsFifoPartialProcessor(skip_group_on_error=True)
+tracer = Tracer()
+logger = Logger()
+
+
+@tracer.capture_method
+def record_handler(record: SQSRecord):
+    payload: str = record.json_body  # if the body is a JSON string; otherwise use record.body
+    logger.info(payload)
+
+
+@logger.inject_lambda_context
+@tracer.capture_lambda_handler
+def lambda_handler(event, context: LambdaContext):
+    return process_partial_response(event=event, record_handler=record_handler, processor=processor, context=context)
diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py
index e146d65744f..8ea2fac7bc5 100644
--- a/tests/functional/test_utilities_batch.py
+++ b/tests/functional/test_utilities_batch.py
@@ -28,6 +28,7 @@
 from aws_lambda_powertools.utilities.parser.models import (
     DynamoDBStreamChangedRecordModel,
     DynamoDBStreamRecordModel,
+    SqsRecordModel,
 )
 from aws_lambda_powertools.utilities.parser.types import Literal
 from tests.functional.batch.sample_models import (
@@ -38,6 +39,32 @@
 from tests.functional.utils import b64_to_str, str_to_b64
 
 
+@pytest.fixture(scope="module")
+def sqs_event_fifo_factory() -> Callable:
+    def factory(body: str, message_group_id: str = ""):
+        return {
+            "messageId": f"{uuid.uuid4()}",
+            "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
+            "body": body,
+            "attributes": {
+                "ApproximateReceiveCount": "1",
+                "SentTimestamp": "1703675223472",
+                "SequenceNumber": "18882884930918384133",
+                "MessageGroupId": message_group_id,
+                "SenderId": "SenderId",
+                "MessageDeduplicationId": "1eea03c3f7e782c7bdc2f2a917f40389314733ff39f5ab16219580c0109ade98",
+                "ApproximateFirstReceiveTimestamp": "1703675223484",
+            },
+            "messageAttributes": {},
+            "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
+            "eventSource": "aws:sqs",
+            "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
+            "awsRegion": "us-east-1",
+        }
+
+    return factory
+
+
 @pytest.fixture(scope="module")
 def sqs_event_factory() -> Callable:
     def factory(body: str):
@@ -48,7 +75,7 @@ def factory(body: str):
             "attributes": {
                 "ApproximateReceiveCount": "1",
                 "SentTimestamp": "1545082649183",
-                "SenderId": "AIDAIENQZJOLO23YVJ4VO",
+                "SenderId": "SenderId",
                 "ApproximateFirstReceiveTimestamp": "1545082649185",
             },
             "messageAttributes": {},
@@ -660,10 +687,10 @@ def lambda_handler(event, context):
     assert "All records failed processing. " in str(e.value)
" in str(e.value) -def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_factory, record_handler): +def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_fifo_factory, record_handler): # GIVEN - first_record = SQSRecord(sqs_event_factory("success")) - second_record = SQSRecord(sqs_event_factory("success")) + first_record = SQSRecord(sqs_event_fifo_factory("success")) + second_record = SQSRecord(sqs_event_fifo_factory("success")) event = {"Records": [first_record.raw_event, second_record.raw_event]} processor = SqsFifoPartialProcessor() @@ -679,12 +706,12 @@ def lambda_handler(event, context): assert result["batchItemFailures"] == [] -def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_factory, record_handler): +def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_fifo_factory, record_handler): # GIVEN - first_record = SQSRecord(sqs_event_factory("success")) - second_record = SQSRecord(sqs_event_factory("fail")) + first_record = SQSRecord(sqs_event_fifo_factory("success")) + second_record = SQSRecord(sqs_event_fifo_factory("fail")) # this would normally succeed, but since it's a FIFO queue, it will be marked as failure - third_record = SQSRecord(sqs_event_factory("success")) + third_record = SQSRecord(sqs_event_fifo_factory("success")) event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]} processor = SqsFifoPartialProcessor() @@ -702,6 +729,120 @@ def lambda_handler(event, context): assert result["batchItemFailures"][1]["itemIdentifier"] == third_record.message_id +def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error(sqs_event_fifo_factory, record_handler): + # GIVEN a batch of 5 records with 3 different MessageGroupID + first_record = SQSRecord(sqs_event_fifo_factory("success", "1")) + second_record = SQSRecord(sqs_event_fifo_factory("success", "1")) + third_record = SQSRecord(sqs_event_fifo_factory("fail", "2")) + fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2")) + fifth_record = SQSRecord(sqs_event_fifo_factory("fail", "3")) + event = { + "Records": [ + first_record.raw_event, + second_record.raw_event, + third_record.raw_event, + fourth_record.raw_event, + fifth_record.raw_event, + ], + } + + # WHEN the FIFO processor is set to continue processing even after encountering errors in specific MessageGroupID + processor = SqsFifoPartialProcessor(skip_group_on_error=True) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN + result = lambda_handler(event, {}) + + # THEN only failed messages should originate from MessageGroupID 3 + assert len(result["batchItemFailures"]) == 3 + assert result["batchItemFailures"][0]["itemIdentifier"] == third_record.message_id + assert result["batchItemFailures"][1]["itemIdentifier"] == fourth_record.message_id + assert result["batchItemFailures"][2]["itemIdentifier"] == fifth_record.message_id + + +def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error_first_message_fail( + sqs_event_fifo_factory, + record_handler, +): + # GIVEN a batch of 5 records with 3 different MessageGroupID + first_record = SQSRecord(sqs_event_fifo_factory("fail", "1")) + second_record = SQSRecord(sqs_event_fifo_factory("success", "1")) + third_record = SQSRecord(sqs_event_fifo_factory("fail", "2")) + fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2")) + fifth_record = SQSRecord(sqs_event_fifo_factory("success", "3")) + event = 
+        "Records": [
+            first_record.raw_event,
+            second_record.raw_event,
+            third_record.raw_event,
+            fourth_record.raw_event,
+            fifth_record.raw_event,
+        ],
+    }
+
+    # WHEN the FIFO processor is set to continue processing even after encountering errors in a specific MessageGroupID
+    processor = SqsFifoPartialProcessor(skip_group_on_error=True)
+
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    # WHEN the handler is invoked
+    result = lambda_handler(event, {})
+
+    # THEN messages from groups 1 and 2 should fail, but not group 3
+    assert len(result["batchItemFailures"]) == 4
+    assert result["batchItemFailures"][0]["itemIdentifier"] == first_record.message_id
+    assert result["batchItemFailures"][1]["itemIdentifier"] == second_record.message_id
+    assert result["batchItemFailures"][2]["itemIdentifier"] == third_record.message_id
+    assert result["batchItemFailures"][3]["itemIdentifier"] == fourth_record.message_id
+
+
+def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error_and_model(sqs_event_fifo_factory, record_handler):
+    # GIVEN a batch of 5 records with 3 different MessageGroupIDs
+    first_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
+    second_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
+    third_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
+    fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2"))
+    fifth_record = SQSRecord(sqs_event_fifo_factory("fail", "3"))
+    event = {
+        "Records": [
+            first_record.raw_event,
+            second_record.raw_event,
+            third_record.raw_event,
+            fourth_record.raw_event,
+            fifth_record.raw_event,
+        ],
+    }
+
+    class OrderSqsRecord(SqsRecordModel):
+        receiptHandle: str
+
+    # WHEN the FIFO processor is set to continue processing even after encountering errors in a specific MessageGroupID
+    # AND the processor uses a Pydantic model, so we must still be able to access the MessageGroupId attribute
+    processor = SqsFifoPartialProcessor(skip_group_on_error=True, model=OrderSqsRecord)
+
+    def record_handler(record: OrderSqsRecord):
+        if record.body == "fail":
+            raise ValueError("blah")
+
+    @batch_processor(record_handler=record_handler, processor=processor)
+    def lambda_handler(event, context):
+        return processor.response()
+
+    # WHEN the handler is invoked
+    result = lambda_handler(event, {})
+
+    # THEN the failed messages should be the ones from MessageGroupIDs 2 and 3
+    assert len(result["batchItemFailures"]) == 3
+    assert result["batchItemFailures"][0]["itemIdentifier"] == third_record.message_id
+    assert result["batchItemFailures"][1]["itemIdentifier"] == fourth_record.message_id
+    assert result["batchItemFailures"][2]["itemIdentifier"] == fifth_record.message_id
+
+
 def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler):
     # GIVEN
     first_record = SQSRecord(sqs_event_factory("success"))