From f1e48a9ec2e367d24f4c4306a026e8006af1edb4 Mon Sep 17 00:00:00 2001 From: Bert Blommers Date: Sun, 10 Nov 2024 13:38:03 -0100 Subject: [PATCH] DynamoDB: scan() now supports parallelization using the Segment/TotalSegments parameters (#8303) --- moto/dynamodb/models/__init__.py | 2 + moto/dynamodb/models/dynamo_type.py | 31 ++++- moto/dynamodb/models/table.py | 6 +- moto/dynamodb/responses.py | 31 ++++- .../exceptions/test_dynamodb_exceptions.py | 30 ++++- tests/test_dynamodb/test_dynamodb_scan.py | 125 ++++++++++++++++++ 6 files changed, 211 insertions(+), 14 deletions(-) diff --git a/moto/dynamodb/models/__init__.py b/moto/dynamodb/models/__init__.py index b6c823cd9925..d3d1a71e81a8 100644 --- a/moto/dynamodb/models/__init__.py +++ b/moto/dynamodb/models/__init__.py @@ -401,6 +401,7 @@ def scan( index_name: str, consistent_read: bool, projection_expression: Optional[List[List[str]]], + segments: Union[Tuple[None, None], Tuple[int, int]], ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: table = self.get_table(table_name) @@ -421,6 +422,7 @@ def scan( index_name, consistent_read, projection_expression, + segments=segments, ) def update_item( diff --git a/moto/dynamodb/models/dynamo_type.py b/moto/dynamodb/models/dynamo_type.py index 62f87356f87a..3aa92fd53804 100644 --- a/moto/dynamodb/models/dynamo_type.py +++ b/moto/dynamodb/models/dynamo_type.py @@ -1,7 +1,7 @@ import base64 import copy from decimal import Decimal -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from boto3.dynamodb.types import TypeDeserializer, TypeSerializer from botocore.utils import merge_dicts @@ -12,6 +12,7 @@ IncorrectDataType, ItemSizeTooLarge, ) +from moto.utilities.utils import md5_hash from .utilities import bytesize, find_nested_key @@ -455,3 +456,31 @@ def project(self, projection_expressions: List[List[str]]) -> "Item": # We need to convert that into DynamoDB dictionary ({'M': {'key': {'S': 'value'}}}) attrs=serializer.serialize(result)["M"], ) + + def is_within_segment( + self, segments: Union[Tuple[None, None], Tuple[int, int]] + ) -> bool: + """ + Segments can be either (x, y) or (None, None) + None, None => the user requested the entire table, so the item always falls within that + x, y => the user requested segment x out of y + + Segment membership is computed based on the value of the hash key + """ + if segments == (None, None): + return True + + segment, total_segments = segments + # Creates a reproducible hash number for this item (between 0 and 256) + # Note that we can't use the builtin hash() method, as that is not deterministic between executions + # + # Using a hash based on the hash key ensures parity with how AWS seems to behave: + # - Items are not divided equally between segment + # - Items always fall in the same segment, regardless of how often you call `scan()` + # - Items with the same hash key but different range keys always fall in the same segment + # - Items with different hash keys may be part of different segments + # + item_hash = md5_hash(self.hash_key.value.encode("utf8")).digest()[0] + # Modulo ensures that we always get a number between 0 and (total_segments) + item_segment = item_hash % total_segments + return segment == item_segment diff --git a/moto/dynamodb/models/table.py b/moto/dynamodb/models/table.py index bd074fe05103..73d382609373 100644 --- a/moto/dynamodb/models/table.py +++ b/moto/dynamodb/models/table.py @@ -1,6 +1,6 @@ import copy from collections import defaultdict -from typing import Any, Dict, List, Optional, Sequence, Tuple +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from moto.core.common_models import BaseModel, CloudFormationModel from moto.core.utils import unix_time, unix_time_millis, utcnow @@ -897,6 +897,7 @@ def scan( index_name: Optional[str] = None, consistent_read: bool = False, projection_expression: Optional[List[List[str]]] = None, + segments: Union[Tuple[None, None], Tuple[int, int]] = (None, None), ) -> Tuple[List[Item], int, Optional[Dict[str, Any]]]: results: List[Item] = [] result_size = 0 @@ -942,6 +943,9 @@ def scan( last_evaluated_key = None processing_previous_page = exclusive_start_key is not None for item in items: + if not item.is_within_segment(segments): + continue + # Cycle through the previous page of results # When we encounter our start key, we know we've reached the end of the previous page if processing_previous_page: diff --git a/moto/dynamodb/responses.py b/moto/dynamodb/responses.py index cb1f777f2851..c79331afcbe4 100644 --- a/moto/dynamodb/responses.py +++ b/moto/dynamodb/responses.py @@ -829,6 +829,24 @@ def scan(self) -> str: limit = self.body.get("Limit") index_name = self.body.get("IndexName") consistent_read = self.body.get("ConsistentRead", False) + segment = self.body.get("Segment") + total_segments = self.body.get("TotalSegments") + if segment is not None and total_segments is None: + raise MockValidationException( + "The TotalSegments parameter is required but was not present in the request when Segment parameter is present" + ) + if total_segments is not None and segment is None: + raise MockValidationException( + "The Segment parameter is required but was not present in the request when parameter TotalSegments is present" + ) + if ( + segment is not None + and total_segments is not None + and segment >= total_segments + ): + raise MockValidationException( + f"The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: {segment} is not less than TotalSegments: {total_segments}" + ) projection_expressions = self._adjust_projection_expression( projection_expression, expression_attribute_names @@ -840,12 +858,13 @@ def scan(self) -> str: filters, limit, exclusive_start_key, - filter_expression, - expression_attribute_names, - expression_attribute_values, - index_name, - consistent_read, - projection_expressions, + filter_expression=filter_expression, + expr_names=expression_attribute_names, + expr_values=expression_attribute_values, + index_name=index_name, + consistent_read=consistent_read, + projection_expression=projection_expressions, + segments=(segment, total_segments), ) except ValueError as err: raise MockValidationException(f"Bad Filter Expression: {err}") diff --git a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py index e0bb4c38d105..8861f36e7694 100644 --- a/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py +++ b/tests/test_dynamodb/exceptions/test_dynamodb_exceptions.py @@ -33,28 +33,40 @@ class BaseTest: @classmethod - def setup_class(cls): + def setup_class(cls, add_range=False): if not allow_aws_request(): cls.mock = mock_aws() cls.mock.start() cls.client = boto3.client("dynamodb", region_name="us-east-1") cls.table_name = "T" + str(uuid4())[0:6] + cls.has_range_key = add_range dynamodb = boto3.resource("dynamodb", region_name="us-east-1") # Create the DynamoDB table. + schema = [{"AttributeName": "pk", "KeyType": "HASH"}] + defs = [{"AttributeName": "pk", "AttributeType": "S"}] + if add_range: + schema.append({"AttributeName": "rk", "KeyType": "RANGE"}) + defs.append({"AttributeName": "rk", "AttributeType": "S"}) dynamodb.create_table( TableName=cls.table_name, - KeySchema=[{"AttributeName": "pk", "KeyType": "HASH"}], - AttributeDefinitions=[{"AttributeName": "pk", "AttributeType": "S"}], + KeySchema=schema, + AttributeDefinitions=defs, ProvisionedThroughput={"ReadCapacityUnits": 5, "WriteCapacityUnits": 5}, ) waiter = cls.client.get_waiter("table_exists") waiter.wait(TableName=cls.table_name) cls.table = dynamodb.Table(cls.table_name) - cls.table.put_item( - Item={"pk": "the-key", "subject": "123", "body": "some test msg"} - ) + + def setup_method(self): + # Empty table between runs + items = self.table.scan()["Items"] + for item in items: + if self.has_range_key: + self.table.delete_item(Key={"pk": item["pk"], "rk": item["rk"]}) + else: + self.table.delete_item(Key={"pk": item["pk"]}) @classmethod def teardown_class(cls): @@ -1296,6 +1308,12 @@ def test_query_with_missing_expression_attribute(): @pytest.mark.aws_verified class TestReturnValuesOnConditionCheckFailure(BaseTest): + def setup_method(self): + super().setup_method() + self.table.put_item( + Item={"pk": "the-key", "subject": "123", "body": "some test msg"} + ) + def test_put_item_does_not_return_old_item(self): with pytest.raises(ClientError) as exc: self.table.put_item( diff --git a/tests/test_dynamodb/test_dynamodb_scan.py b/tests/test_dynamodb/test_dynamodb_scan.py index 1f070a9b4140..0e9e77663306 100644 --- a/tests/test_dynamodb/test_dynamodb_scan.py +++ b/tests/test_dynamodb/test_dynamodb_scan.py @@ -7,6 +7,7 @@ from botocore.exceptions import ClientError from moto import mock_aws +from tests.test_dynamodb.exceptions.test_dynamodb_exceptions import BaseTest from . import dynamodb_aws_verified @@ -729,3 +730,127 @@ def test_scan_with_scanfilter(self): "Items" ] assert items == [{"partitionKey": "pk-1"}] + + +@pytest.mark.aws_verified +class TestParallelScan(BaseTest): + @staticmethod + def setup_class(cls): # pylint: disable=arguments-renamed + super().setup_class(add_range=True) + + def test_segment_only(self): + with pytest.raises(ClientError) as exc: + self.table.scan(Segment=1) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "The TotalSegments parameter is required but was not present in the request when Segment parameter is present" + ) + + def test_total_segments_only(self): + with pytest.raises(ClientError) as exc: + self.table.scan(TotalSegments=1) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "The Segment parameter is required but was not present in the request when parameter TotalSegments is present" + ) + + def test_parallelize_all_different_hash_keys(self): + for i in range(10): + self.table.put_item(Item={"pk": f"item{i}", "rk": "sth"}) + + resp1 = self.table.scan(Segment=0, TotalSegments=3)["Items"] + resp2 = self.table.scan(Segment=1, TotalSegments=3)["Items"] + resp3 = self.table.scan(Segment=2, TotalSegments=3)["Items"] + + assert len(resp1) + len(resp2) + len(resp3) == 10 + + def test_parallelize_different_hash_key_per_segment(self): + for i in range(3): + for j in range(4): + self.table.put_item(Item={"pk": f"item{i}", "rk": f"rk{j}"}) + + resp1 = self.table.scan(Segment=0, TotalSegments=3)["Items"] + resp2 = self.table.scan(Segment=1, TotalSegments=3)["Items"] + resp3 = self.table.scan(Segment=2, TotalSegments=3)["Items"] + + assert len(resp1) + len(resp2) + len(resp3) == 12 + + def test_scan_using_filter_expression(self): + # AWS seems to return all data in Segment 1 + for i in range(10): + self.table.put_item(Item={"pk": "item", "rk": f"range{i}"}) + for i in range(10): + self.table.put_item(Item={"pk": "n/a", "rk": f"range{i}"}) + for i in range(20, 10, -1): + self.table.put_item(Item={"pk": "item", "rk": f"range{i}"}) + + resp1 = self.table.scan( + FilterExpression=Attr("pk").eq("item"), Segment=0, TotalSegments=3 + )["Items"] + resp2 = self.table.scan( + FilterExpression=Attr("pk").eq("item"), Segment=1, TotalSegments=3 + )["Items"] + resp3 = self.table.scan( + FilterExpression=Attr("pk").eq("item"), Segment=2, TotalSegments=3 + )["Items"] + + assert len(resp1) + len(resp2) + len(resp3) == 20 + + def test_scan_single_hash_key(self): + # AWS seems to return all data in Segment 1 + for i in range(10): + self.table.put_item(Item={"pk": "item", "rk": f"range{i}"}) + for i in range(20, 10, -1): + self.table.put_item(Item={"pk": "item", "rk": f"range{i}"}) + + resp1 = self.table.scan(Segment=0, TotalSegments=3)["Items"] + resp2 = self.table.scan(Segment=1, TotalSegments=3)["Items"] + resp3 = self.table.scan(Segment=2, TotalSegments=3)["Items"] + + assert len(resp1) + len(resp2) + len(resp3) == 20 + + def test_pagination(self): + for i in range(50): + self.table.put_item(Item={"pk": "item", "rk": f"range{i}"}) + + resp1 = self.table.scan(Segment=0, TotalSegments=3, Limit=10) + resp2 = self.table.scan(Segment=1, TotalSegments=3, Limit=10) + resp3 = self.table.scan(Segment=2, TotalSegments=3, Limit=10) + + first_pass = len(resp1["Items"]) + len(resp2["Items"]) + len(resp3["Items"]) + assert first_pass <= 30 + + second_pass = 0 + if "LastEvaluatedKey" in resp1: + resp = self.table.scan( + Segment=0, TotalSegments=3, ExclusiveStartKey=resp1["LastEvaluatedKey"] + ) + second_pass += len(resp["Items"]) + + if "LastEvaluatedKey" in resp2: + resp = self.table.scan( + Segment=1, TotalSegments=3, ExclusiveStartKey=resp2["LastEvaluatedKey"] + ) + second_pass += len(resp["Items"]) + + if "LastEvaluatedKey" in resp3: + resp = self.table.scan( + Segment=2, TotalSegments=3, ExclusiveStartKey=resp3["LastEvaluatedKey"] + ) + second_pass += len(resp["Items"]) + + assert first_pass + second_pass == 50 + + def test_segment_larger_than_total_segments(self): + with pytest.raises(ClientError) as exc: + self.table.scan(Segment=3, TotalSegments=3) + err = exc.value.response["Error"] + assert err["Code"] == "ValidationException" + assert ( + err["Message"] + == "The Segment parameter is zero-based and must be less than parameter TotalSegments: Segment: 3 is not less than TotalSegments: 3" + )