diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningScheme.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningScheme.java
index 9f4ff0d00b7a..04bc2a290071 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningScheme.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningScheme.java
@@ -96,15 +96,6 @@ public Optional<List<InternalNode>> getPartitionToNodeMap()
         return partitionToNodeMap;
     }
 
-    public FaultTolerantPartitioningScheme withPartitionCount(int partitionCount)
-    {
-        return new FaultTolerantPartitioningScheme(
-                partitionCount,
-                this.bucketToPartitionMap,
-                this.splitToBucketFunction,
-                this.partitionToNodeMap);
-    }
-
     @Override
     public String toString()
     {
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningSchemeFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningSchemeFactory.java
index d81a4f3195ed..509a8c4af086 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningSchemeFactory.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/FaultTolerantPartitioningSchemeFactory.java
@@ -32,6 +32,7 @@
 import java.util.function.ToIntFunction;
 import java.util.stream.IntStream;
 
+import static com.google.common.base.Verify.verify;
 import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION;
 import static java.util.Objects.requireNonNull;
@@ -43,7 +44,7 @@ public class FaultTolerantPartitioningSchemeFactory
     private final Session session;
     private final int maxPartitionCount;
 
-    private final Map<PartitioningHandle, FaultTolerantPartitioningScheme> cache = new HashMap<>();
+    private final Map<CacheKey, FaultTolerantPartitioningScheme> cache = new HashMap<>();
 
     public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartitioningManager, Session session, int maxPartitionCount)
     {
@@ -54,18 +55,20 @@ public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartit
 
     public FaultTolerantPartitioningScheme get(PartitioningHandle handle, Optional<Integer> partitionCount)
     {
-        FaultTolerantPartitioningScheme result = cache.get(handle);
+        CacheKey cacheKey = new CacheKey(handle, partitionCount);
+        FaultTolerantPartitioningScheme result = cache.get(cacheKey);
         if (result == null) {
             // Avoid using computeIfAbsent as the "get" method is called recursively from the "create" method
             result = create(handle, partitionCount);
-            cache.put(handle, result);
-        }
-        else if (partitionCount.isPresent()) {
-            // With runtime adaptive partitioning, it's no longer guaranteed that the same handle will always map to
-            // the same partition count. Therefore, use the supplied `partitionCount` as the source of truth.
-            result = result.withPartitionCount(partitionCount.get());
+            if (partitionCount.isPresent()) {
+                verify(
+                        result.getPartitionCount() == partitionCount.get(),
+                        "expected partitionCount to be %s but got %s",
+                        partitionCount.get(),
+                        result.getPartitionCount());
+            }
+            cache.put(cacheKey, result);
         }
-
         return result;
     }
 
@@ -143,4 +146,13 @@ private static FaultTolerantPartitioningScheme createArbitraryConnectorSpecificS
                 Optional.of(splitToBucket),
                 Optional.empty());
     }
+
+    private record CacheKey(PartitioningHandle handle, Optional<Integer> partitionCount)
+    {
+        private CacheKey
+        {
+            requireNonNull(handle, "handle is null");
+            requireNonNull(partitionCount, "partitionCount is null");
+        }
+    }
 }
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java
index 95929f03fd16..0d6f2c444951 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java
@@ -100,6 +100,7 @@
 import static com.google.common.collect.Iterables.getOnlyElement;
 import static com.google.common.collect.MoreCollectors.onlyElement;
 import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly;
+import static io.trino.SystemSessionProperties.DETERMINE_PARTITION_COUNT_FOR_WRITE_ENABLED;
 import static io.trino.SystemSessionProperties.SCALE_WRITERS;
 import static io.trino.SystemSessionProperties.TASK_MAX_WRITER_COUNT;
 import static io.trino.SystemSessionProperties.TASK_MIN_WRITER_COUNT;
@@ -6428,23 +6429,28 @@ private void testMergeUpdateWithVariousLayouts(int writers, String partitioning)
     @Override
     public void testMergeMultipleOperations()
     {
-        testMergeMultipleOperations(1, "");
-        testMergeMultipleOperations(4, "");
-        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['customer'])");
-        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['customer'])");
-        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['purchase'])");
-        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['purchase'])");
-        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(customer, 3)'])");
-        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(customer, 3)'])");
-        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])");
-        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])");
+        testMergeMultipleOperations(1, "", false);
+        testMergeMultipleOperations(4, "", false);
+        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['customer'])", false);
+        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['customer'])", false);
+        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['purchase'])", false);
+        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['purchase'])", false);
+        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(customer, 3)'])", false);
+        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(customer, 3)'])", false);
+        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])", false);
+        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])", false);
+        testMergeMultipleOperations(1, "", true);
+        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['customer'])", true);
+        testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(customer, 3)'])", true);
+        testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])", true);
     }
 
-    public void testMergeMultipleOperations(int writers, String partitioning)
+    public void testMergeMultipleOperations(int writers, String partitioning, boolean determinePartitionCountForWrite)
     {
         Session session = Session.builder(getSession())
                 .setSystemProperty(TASK_MIN_WRITER_COUNT, String.valueOf(writers))
                 .setSystemProperty(TASK_MAX_WRITER_COUNT, String.valueOf(writers))
+                .setSystemProperty(DETERMINE_PARTITION_COUNT_FOR_WRITE_ENABLED, Boolean.toString(determinePartitionCountForWrite))
                 .build();
 
         int targetCustomerCount = 32;
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java b/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java
index 96f2adffd60b..00856e88b842 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java
@@ -27,7 +27,9 @@ public static Map<String, String> getExtraProperties()
                 .put("retry-policy", "TASK")
                 .put("retry-initial-delay", "50ms")
                 .put("retry-max-delay", "100ms")
+                .put("fault-tolerant-execution-min-partition-count", "4")
                 .put("fault-tolerant-execution-max-partition-count", "5")
+                .put("fault-tolerant-execution-min-partition-count-for-write", "4")
                 .put("fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-min", "5MB")
                 .put("fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-max", "10MB")
                 .put("fault-tolerant-execution-arbitrary-distribution-write-task-target-size-min", "10MB")
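
Note on the factory change: the cache is now keyed by the pair (partitioning handle, requested partition count) instead of by the handle alone, and a freshly created scheme is verified to already carry the requested count rather than being patched afterwards with withPartitionCount. The sketch below is a minimal, self-contained illustration of that pattern, assuming simplified stand-ins: Handle, Scheme, CompositeCacheKeySketch, and createScheme are hypothetical placeholders, not Trino classes or APIs.

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

public class CompositeCacheKeySketch
{
    // Hypothetical stand-ins for PartitioningHandle and FaultTolerantPartitioningScheme
    record Handle(String name) {}

    record Scheme(int partitionCount) {}

    // Cache key pairing the handle with the requested partition count, analogous to the new CacheKey record
    record CacheKey(Handle handle, Optional<Integer> partitionCount) {}

    private final Map<CacheKey, Scheme> cache = new HashMap<>();

    public Scheme get(Handle handle, Optional<Integer> partitionCount)
    {
        CacheKey cacheKey = new CacheKey(handle, partitionCount);
        Scheme result = cache.get(cacheKey);
        if (result == null) {
            result = createScheme(handle, partitionCount);
            // Mirrors the verify(...) added in the factory: a newly created scheme must already match the request
            if (partitionCount.isPresent() && result.partitionCount() != partitionCount.get()) {
                throw new IllegalStateException(
                        "expected partitionCount to be %s but got %s".formatted(partitionCount.get(), result.partitionCount()));
            }
            cache.put(cacheKey, result);
        }
        return result;
    }

    private Scheme createScheme(Handle handle, Optional<Integer> partitionCount)
    {
        // Stand-in for the factory's private create(...): honors the requested count, else uses an arbitrary default
        return new Scheme(partitionCount.orElse(5));
    }

    public static void main(String[] args)
    {
        CompositeCacheKeySketch factory = new CompositeCacheKeySketch();
        Handle handle = new Handle("hash-distribution");
        // The same handle requested with different partition counts now occupies distinct cache entries
        System.out.println(factory.get(handle, Optional.of(4)).partitionCount());  // prints 4
        System.out.println(factory.get(handle, Optional.of(10)).partitionCount()); // prints 10
    }
}

Keying the cache by the composite record means one handle requested with two different partition counts yields two independent cached schemes, which is what runtime adaptive partitioning requires. The extra properties added to FaultTolerantExecutionConnectorTestHelper pin the minimum partition counts (4) alongside the existing maximum (5), narrowing the range the scheduler can pick in fault-tolerant connector tests.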