Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,6 @@ public Optional<List<InternalNode>> getPartitionToNodeMap()
return partitionToNodeMap;
}

/**
 * Returns a copy of this scheme that uses the given partition count while
 * preserving the bucket-to-partition map, the split-to-bucket function,
 * and the partition-to-node map of this instance.
 */
public FaultTolerantPartitioningScheme withPartitionCount(int partitionCount)
{
    FaultTolerantPartitioningScheme copy = new FaultTolerantPartitioningScheme(
            partitionCount,
            bucketToPartitionMap,
            splitToBucketFunction,
            partitionToNodeMap);
    return copy;
}

@Override
public String toString()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.function.ToIntFunction;
import java.util.stream.IntStream;

import static com.google.common.base.Verify.verify;
import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION;
import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION;
import static java.util.Objects.requireNonNull;
Expand All @@ -43,7 +44,7 @@ public class FaultTolerantPartitioningSchemeFactory
private final Session session;
private final int maxPartitionCount;

private final Map<PartitioningHandle, FaultTolerantPartitioningScheme> cache = new HashMap<>();
private final Map<CacheKey, FaultTolerantPartitioningScheme> cache = new HashMap<>();

public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartitioningManager, Session session, int maxPartitionCount)
{
Expand All @@ -54,18 +55,20 @@ public FaultTolerantPartitioningSchemeFactory(NodePartitioningManager nodePartit

public FaultTolerantPartitioningScheme get(PartitioningHandle handle, Optional<Integer> partitionCount)
{
FaultTolerantPartitioningScheme result = cache.get(handle);
CacheKey cacheKey = new CacheKey(handle, partitionCount);
FaultTolerantPartitioningScheme result = cache.get(cacheKey);
if (result == null) {
// Avoid using computeIfAbsent as the "get" method is called recursively from the "create" method
result = create(handle, partitionCount);
cache.put(handle, result);
}
else if (partitionCount.isPresent()) {
// With runtime adaptive partitioning, it's no longer guaranteed that the same handle will always map to
// the same partition count. Therefore, use the supplied `partitionCount` as the source of truth.
result = result.withPartitionCount(partitionCount.get());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

withPartitionCount became unused, remove

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this change, we guaranteed that "if the value is taken from the cache, the partitionCount is made to be exactly as requested".
I don't know why we didn't guarantee that for new values.

maybe we could add this

if (partitionCount.isPresent()) {
  verify(result.getPartitionCount() == partitionCount.get(), "...");
}

however, create doesn't seem to guarantee this property -- return new FaultTolerantPartitioningScheme(1, ....

Thus this PR (1) fixes caching logic
and introduces a pretty subtle side-effect (maybe correct). If this side effect is intentional, it would be good to call it out (perhaps as separate commit)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am pretty sure this method should not be called with a non-empty partition count in the case where create would return a scheme with partitionCount == 1.
The verify seems like a good change.

if (partitionCount.isPresent()) {
verify(
result.getPartitionCount() == partitionCount.get(),
"expected partitionCount to be %s but got %s",
partitionCount.get(),
result.getPartitionCount());
}
cache.put(cacheKey, result);
}

return result;
}

Expand Down Expand Up @@ -143,4 +146,13 @@ private static FaultTolerantPartitioningScheme createArbitraryConnectorSpecificS
Optional.of(splitToBucket),
Optional.empty());
}

/**
 * Cache key pairing a partitioning handle with the optionally requested
 * partition count, so that requests for different partition counts of the
 * same handle do not collide in the cache.
 */
private record CacheKey(PartitioningHandle handle, Optional<Integer> partitionCount)
{
    // Canonical constructor: both components must be non-null.
    private CacheKey(PartitioningHandle handle, Optional<Integer> partitionCount)
    {
        this.handle = requireNonNull(handle, "handle is null");
        this.partitionCount = requireNonNull(partitionCount, "partitionCount is null");
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
import static com.google.common.collect.Iterables.getOnlyElement;
import static com.google.common.collect.MoreCollectors.onlyElement;
import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly;
import static io.trino.SystemSessionProperties.DETERMINE_PARTITION_COUNT_FOR_WRITE_ENABLED;
import static io.trino.SystemSessionProperties.SCALE_WRITERS;
import static io.trino.SystemSessionProperties.TASK_MAX_WRITER_COUNT;
import static io.trino.SystemSessionProperties.TASK_MIN_WRITER_COUNT;
Expand Down Expand Up @@ -6428,23 +6429,28 @@ private void testMergeUpdateWithVariousLayouts(int writers, String partitioning)
@Override
public void testMergeMultipleOperations()
{
testMergeMultipleOperations(1, "");
testMergeMultipleOperations(4, "");
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['customer'])");
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['customer'])");
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['purchase'])");
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['purchase'])");
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(customer, 3)'])");
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(customer, 3)'])");
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])");
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])");
testMergeMultipleOperations(1, "", false);
testMergeMultipleOperations(4, "", false);
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['customer'])", false);
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['customer'])", false);
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['purchase'])", false);
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['purchase'])", false);
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(customer, 3)'])", false);
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(customer, 3)'])", false);
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])", false);
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])", false);
testMergeMultipleOperations(1, "", true);
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['customer'])", true);
testMergeMultipleOperations(1, "WITH (partitioning = ARRAY['bucket(customer, 3)'])", true);
testMergeMultipleOperations(4, "WITH (partitioning = ARRAY['bucket(purchase, 4)'])", true);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A data provider would produce much more readable failure messages;
otherwise, when inspecting a CI error message, you have to hope the line numbers didn't change between CI and your local copy.

out of scope for this pr

}

public void testMergeMultipleOperations(int writers, String partitioning)
public void testMergeMultipleOperations(int writers, String partitioning, boolean determinePartitionCountForWrite)
{
Session session = Session.builder(getSession())
.setSystemProperty(TASK_MIN_WRITER_COUNT, String.valueOf(writers))
.setSystemProperty(TASK_MAX_WRITER_COUNT, String.valueOf(writers))
.setSystemProperty(DETERMINE_PARTITION_COUNT_FOR_WRITE_ENABLED, Boolean.toString(determinePartitionCountForWrite))
.build();

int targetCustomerCount = 32;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ public static Map<String, String> getExtraProperties()
.put("retry-policy", "TASK")
.put("retry-initial-delay", "50ms")
.put("retry-max-delay", "100ms")
.put("fault-tolerant-execution-min-partition-count", "4")
.put("fault-tolerant-execution-max-partition-count", "5")
.put("fault-tolerant-execution-min-partition-count-for-write", "4")
.put("fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-min", "5MB")
.put("fault-tolerant-execution-arbitrary-distribution-compute-task-target-size-max", "10MB")
.put("fault-tolerant-execution-arbitrary-distribution-write-task-target-size-min", "10MB")
Expand Down