Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(batch): add flag in SqsFifoProcessor to enable continuous message processing #3954

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from typing import List, Optional, Tuple
import logging
from typing import Dict, List, Optional, Tuple

from aws_lambda_powertools.utilities.batch import BatchProcessor, EventType
from aws_lambda_powertools.utilities.batch.types import BatchSqsTypeModel

logger = logging.getLogger(__name__)


class SQSFifoCircuitBreakerError(Exception):
rubenfonseca marked this conversation as resolved.
Show resolved Hide resolved
"""
Expand Down Expand Up @@ -57,7 +60,20 @@ def lambda_handler(event, context: LambdaContext):
None,
)

def __init__(self, model: Optional["BatchSqsTypeModel"] = None, skip_group_on_error: bool = False):
    """
    Initialize the SqsFifoProcessor.

    Parameters
    ----------
    model: Optional["BatchSqsTypeModel"]
        An optional model for batch processing.
    skip_group_on_error: bool
        Whether to skip the remaining messages of a MessageGroupId after one of
        its messages fails, instead of failing the whole rest of the batch.
        Default is False, which preserves the original short-circuit behavior.
    """
    # Stored before delegating so that process() can consult the flag.
    self._skip_group_on_error = skip_group_on_error
    super().__init__(EventType.SQS, model)

def process(self) -> List[Tuple]:
    """
    Process a batch of SQS FIFO records.

    By default, processing stops at the first failure and the remaining
    records are reported as failed items, preserving FIFO ordering.
    When ``skip_group_on_error`` is enabled, only subsequent records that
    share the failing record's MessageGroupId are skipped, so records
    from other group IDs keep being processed.

    Returns
    -------
    List[Tuple]
        One tuple per record; the first element is the processing status
        (e.g. "fail" for failed records).
    """
    result: List[Tuple] = []
    skip_messages_group_id: List = []

    for i, record in enumerate(self.records):
        # If we have failed messages and group skipping is disabled,
        # short circuit the process and fail the remaining messages.
        if self.fail_messages and not self._skip_group_on_error:
            logger.debug("Processing of remaining messages stopped because 'skip_group_on_error' is set to False")
            return self._short_circuit_processing(i, result)

        msg_id = record.get("messageId")

        # skip_group_on_error is True:
        # skip the current message if its group already contains a failed message.
        if msg_id in skip_messages_group_id:
            logger.debug(
                f"Skipping message with ID '{msg_id}' as it is part of a group containing failed messages.",
            )
            continue

        processed_message = self._process_record(record)

        # If the processed message failed and skip_group_on_error is True,
        # mark subsequent messages from the same MessageGroupId as skipped.
        if processed_message[0] == "fail" and self._skip_group_on_error:
            self._process_failed_subsequent_messages(record, i, skip_messages_group_id, result)

        # Append the processed message normally.
        result.append(processed_message)

    return result

def _process_failed_subsequent_messages(
    self,
    record: Dict,
    i: int,
    skip_messages_group_id: List,
    result: List[Tuple],
) -> None:
    """
    Mark every later record sharing the failed record's MessageGroupId as failed,
    recording each message ID so the main processing loop skips it.
    """
    failed_group_id = record.get("attributes", {}).get("MessageGroupId")

    for next_record in self.records[i + 1 :]:
        # Only records belonging to the same message group are affected.
        if next_record.get("attributes", {}).get("MessageGroupId") != failed_group_id:
            continue

        skip_messages_group_id.append(next_record.get("messageId"))
        converted_record = self._to_batch_type(
            record=next_record,
            event_type=self.event_type,
            model=self.model,
        )
        result.append(self.failure_handler(record=converted_record, exception=self.circuit_breaker_exc))

def _short_circuit_processing(self, first_failure_index: int, result: List[Tuple]) -> List[Tuple]:
"""
Starting from the first failure index, fail all the remaining messages, and append them to the result list.
Expand Down
40 changes: 37 additions & 3 deletions docs/utilities/batch.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,11 @@ Processing batches from SQS works in three stages:

#### FIFO queues

When using [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank" rel="nofollow"}, we will stop processing messages after the first failure, and return all failed and unprocessed messages in `batchItemFailures`.
This helps preserve the ordering of messages in your queue.
When working with [SQS FIFO queues](https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/FIFO-queues.html){target="_blank"}, you should know that a batch may include messages from different group IDs.
leandrodamascena marked this conversation as resolved.
Show resolved Hide resolved

By default, we will stop processing at the first failure and mark unprocessed messages as failed to preserve ordering. However, this behavior may not be optimal for customers who wish to proceed with processing messages from a different group ID.

Enable the `skip_group_on_error` option for seamless processing of messages from various group IDs. This setup ensures that messages from a failed group ID are sent back to SQS, enabling uninterrupted processing of messages from the subsequent group ID.

=== "Recommended"

Expand All @@ -164,6 +167,12 @@ This helps preserve the ordering of messages in your queue.
--8<-- "examples/batch_processing/src/getting_started_sqs_fifo_decorator.py"
```

=== "Enabling skip_group_on_error flag"

```python hl_lines="2-6 9 23"
--8<-- "examples/batch_processing/src/getting_started_sqs_fifo_skip_on_error.py"
```

### Processing messages from Kinesis

Processing batches from Kinesis works in three stages:
Expand Down Expand Up @@ -311,7 +320,7 @@ sequenceDiagram

> Read more about [Batch Failure Reporting feature in AWS Lambda](https://docs.aws.amazon.com/lambda/latest/dg/with-sqs.html#services-sqs-batchfailurereporting){target="_blank"}.

Sequence diagram to explain how [`SqsFifoPartialProcessor` works](#fifo-queues) with SQS FIFO queues.
Sequence diagram to explain how [`SqsFifoPartialProcessor` works](#fifo-queues) with SQS FIFO queues when the `skip_group_on_error` flag is disabled (default behavior).

<center>
```mermaid
Expand All @@ -335,6 +344,31 @@ sequenceDiagram
<i>SQS FIFO mechanism with Batch Item Failures</i>
</center>

Sequence diagram to explain how [`SqsFifoPartialProcessor` works](#fifo-queues) with SQS FIFO queues when the `skip_group_on_error` flag is enabled.

<center>
```mermaid
sequenceDiagram
autonumber
participant SQS queue
participant Lambda service
participant Lambda function
Lambda service->>SQS queue: Poll
Lambda service->>Lambda function: Invoke (batch event)
activate Lambda function
Lambda function-->Lambda function: Process 2 out of 10 batch items
Lambda function--xLambda function: Fail on 3rd batch item
Lambda function-->Lambda function: Process messages from another MessageGroupID
Lambda function->>Lambda service: Report 3rd batch item and all messages within the same MessageGroupID as failure
deactivate Lambda function
activate SQS queue
Lambda service->>SQS queue: Delete successful messages processed
SQS queue-->>SQS queue: Failed messages return
deactivate SQS queue
```
<i>SQS FIFO mechanism with Batch Item Failures and the skip_group_on_error flag enabled</i>
</center>

#### Kinesis and DynamoDB Streams

> Read more about [Batch Failure Reporting feature](https://docs.aws.amazon.com/lambda/latest/dg/with-kinesis.html#services-kinesis-batchfailurereporting){target="_blank"}.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.batch import (
    SqsFifoPartialProcessor,
    process_partial_response,
)
from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord
from aws_lambda_powertools.utilities.typing import LambdaContext

# skip_group_on_error=True: when a record fails, only later records with the
# same MessageGroupId are marked as failed; other groups keep processing.
processor = SqsFifoPartialProcessor(skip_group_on_error=True)
tracer = Tracer()
logger = Logger()


@tracer.capture_method
def record_handler(record: SQSRecord):
    """Handle a single SQS record; an exception here marks the record as failed."""
    payload: str = record.json_body  # if json string data, otherwise record.body for str
    logger.info(payload)


@logger.inject_lambda_context
@tracer.capture_lambda_handler
def lambda_handler(event, context: LambdaContext):
    """Lambda entry point: delegate batch handling to the FIFO processor."""
    return process_partial_response(event=event, record_handler=record_handler, processor=processor, context=context)
119 changes: 111 additions & 8 deletions tests/functional/test_utilities_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from aws_lambda_powertools.utilities.parser.models import (
DynamoDBStreamChangedRecordModel,
DynamoDBStreamRecordModel,
SqsRecordModel,
)
from aws_lambda_powertools.utilities.parser.types import Literal
from tests.functional.batch.sample_models import (
Expand All @@ -38,6 +39,32 @@
from tests.functional.utils import b64_to_str, str_to_b64


@pytest.fixture(scope="module")
def sqs_event_fifo_factory() -> Callable:
    """Return a factory building a single raw SQS FIFO record dict."""

    def build_record(body: str, message_group_id: str = ""):
        attributes = {
            "ApproximateReceiveCount": "1",
            "SentTimestamp": "1703675223472",
            "SequenceNumber": "18882884930918384133",
            "MessageGroupId": message_group_id,
            "SenderId": "SenderId",
            "MessageDeduplicationId": "1eea03c3f7e782c7bdc2f2a917f40389314733ff39f5ab16219580c0109ade98",
            "ApproximateFirstReceiveTimestamp": "1703675223484",
        }
        return {
            "messageId": f"{uuid.uuid4()}",
            "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a",
            "body": body,
            "attributes": attributes,
            "messageAttributes": {},
            "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3",
            "eventSource": "aws:sqs",
            "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue",
            "awsRegion": "us-east-1",
        }

    return build_record


@pytest.fixture(scope="module")
def sqs_event_factory() -> Callable:
def factory(body: str):
Expand All @@ -48,7 +75,7 @@ def factory(body: str):
"attributes": {
"ApproximateReceiveCount": "1",
"SentTimestamp": "1545082649183",
"SenderId": "AIDAIENQZJOLO23YVJ4VO",
"SenderId": "SenderId",
"ApproximateFirstReceiveTimestamp": "1545082649185",
},
"messageAttributes": {},
Expand Down Expand Up @@ -660,10 +687,10 @@ def lambda_handler(event, context):
assert "All records failed processing. " in str(e.value)


def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_factory, record_handler):
def test_sqs_fifo_batch_processor_middleware_success_only(sqs_event_fifo_factory, record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_factory("success"))
second_record = SQSRecord(sqs_event_factory("success"))
first_record = SQSRecord(sqs_event_fifo_factory("success"))
second_record = SQSRecord(sqs_event_fifo_factory("success"))
event = {"Records": [first_record.raw_event, second_record.raw_event]}

processor = SqsFifoPartialProcessor()
Expand All @@ -679,12 +706,12 @@ def lambda_handler(event, context):
assert result["batchItemFailures"] == []


def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_factory, record_handler):
def test_sqs_fifo_batch_processor_middleware_with_failure(sqs_event_fifo_factory, record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_factory("success"))
second_record = SQSRecord(sqs_event_factory("fail"))
first_record = SQSRecord(sqs_event_fifo_factory("success"))
second_record = SQSRecord(sqs_event_fifo_factory("fail"))
# this would normally succeed, but since it's a FIFO queue, it will be marked as failure
third_record = SQSRecord(sqs_event_factory("success"))
third_record = SQSRecord(sqs_event_fifo_factory("success"))
event = {"Records": [first_record.raw_event, second_record.raw_event, third_record.raw_event]}

processor = SqsFifoPartialProcessor()
Expand All @@ -702,6 +729,82 @@ def lambda_handler(event, context):
assert result["batchItemFailures"][1]["itemIdentifier"] == third_record.message_id


def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error(sqs_event_fifo_factory, record_handler):
    # GIVEN a batch of 5 records spanning 3 different MessageGroupIds
    first_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    second_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    third_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
    fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2"))
    fifth_record = SQSRecord(sqs_event_fifo_factory("fail", "3"))
    records = [first_record, second_record, third_record, fourth_record, fifth_record]
    event = {"Records": [record.raw_event for record in records]}

    # WHEN the FIFO processor is set to continue processing after a group fails
    processor = SqsFifoPartialProcessor(skip_group_on_error=True)

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    result = lambda_handler(event, {})

    # THEN the failures are the failing record, the rest of its group (group 2),
    # and the failing record of group 3 — group 1 succeeds untouched
    failed_ids = [item["itemIdentifier"] for item in result["batchItemFailures"]]
    assert failed_ids == [third_record.message_id, fourth_record.message_id, fifth_record.message_id]


def test_sqs_fifo_batch_processor_middleware_with_skip_group_on_error_and_model(sqs_event_fifo_factory, record_handler):
    # GIVEN a batch of 5 records spanning 3 different MessageGroupIds
    first_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    second_record = SQSRecord(sqs_event_fifo_factory("success", "1"))
    third_record = SQSRecord(sqs_event_fifo_factory("fail", "2"))
    fourth_record = SQSRecord(sqs_event_fifo_factory("success", "2"))
    fifth_record = SQSRecord(sqs_event_fifo_factory("fail", "3"))
    records = [first_record, second_record, third_record, fourth_record, fifth_record]
    event = {"Records": [record.raw_event for record in records]}

    class OrderSqsRecord(SqsRecordModel):
        receiptHandle: str

    # WHEN the FIFO processor continues after group errors AND parses records
    # with a Pydantic model, it must still resolve each record's MessageGroupId
    processor = SqsFifoPartialProcessor(skip_group_on_error=True, model=OrderSqsRecord)

    # Shadows the fixture handler on purpose: the model variant receives
    # OrderSqsRecord instances, not raw dicts
    def record_handler(record: OrderSqsRecord):
        if record.body == "fail":
            raise ValueError("blah")

    @batch_processor(record_handler=record_handler, processor=processor)
    def lambda_handler(event, context):
        return processor.response()

    result = lambda_handler(event, {})

    # THEN the failures are the failing record, the rest of its group (group 2),
    # and the failing record of group 3 — group 1 succeeds untouched
    failed_ids = [item["itemIdentifier"] for item in result["batchItemFailures"]]
    assert failed_ids == [third_record.message_id, fourth_record.message_id, fifth_record.message_id]


def test_async_batch_processor_middleware_success_only(sqs_event_factory, async_record_handler):
# GIVEN
first_record = SQSRecord(sqs_event_factory("success"))
Expand Down