diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java index 6a767e7ff354..b917d2448040 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ArbitraryDistributionSplitAssigner.java @@ -60,6 +60,7 @@ class ArbitraryDistributionSplitAssigner private int nextPartitionId; private int adaptiveCounter; private long targetPartitionSizeInBytes; + private long roundedTargetPartitionSizeInBytes; private final List allAssignments = new ArrayList<>(); private final Map, PartitionAssignment> openAssignments = new HashMap<>(); @@ -94,6 +95,7 @@ class ArbitraryDistributionSplitAssigner this.maxTaskSplitCount = maxTaskSplitCount; this.targetPartitionSizeInBytes = minTargetPartitionSizeInBytes; + this.roundedTargetPartitionSizeInBytes = minTargetPartitionSizeInBytes; } @Override @@ -200,7 +202,7 @@ private AssignmentResult assignPartitionedSplits(PlanNodeId planNodeId, List hostRequirement = getHostRequirement(split); PartitionAssignment partitionAssignment = openAssignments.get(hostRequirement); long splitSizeInBytes = getSplitSizeInBytes(split); - if (partitionAssignment != null && ((partitionAssignment.getAssignedDataSizeInBytes() + splitSizeInBytes > targetPartitionSizeInBytes) + if (partitionAssignment != null && ((partitionAssignment.getAssignedDataSizeInBytes() + splitSizeInBytes > roundedTargetPartitionSizeInBytes) || (partitionAssignment.getAssignedSplitCount() + 1 > maxTaskSplitCount))) { partitionAssignment.setFull(true); for (PlanNodeId partitionedSourceNodeId : partitionedSources) { @@ -221,7 +223,8 @@ private AssignmentResult assignPartitionedSplits(PlanNodeId planNodeId, List= adaptiveGrowthPeriod) { targetPartitionSizeInBytes = (long) min(maxTargetPartitionSizeInBytes, ceil(targetPartitionSizeInBytes * adaptiveGrowthFactor)); // round to a multiple of minTargetPartitionSizeInBytes so work will be evenly distributed among drivers of a task - targetPartitionSizeInBytes = (targetPartitionSizeInBytes + minTargetPartitionSizeInBytes - 1) / minTargetPartitionSizeInBytes * minTargetPartitionSizeInBytes; + roundedTargetPartitionSizeInBytes = round(targetPartitionSizeInBytes * 1.0 / minTargetPartitionSizeInBytes) * minTargetPartitionSizeInBytes; + verify(roundedTargetPartitionSizeInBytes > 0, "roundedTargetPartitionSizeInBytes %s not positive", roundedTargetPartitionSizeInBytes); adaptiveCounter = 0; } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java index 2aa6744174b9..cff502fedf13 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestArbitraryDistributionSplitAssigner.java @@ -512,6 +512,86 @@ public void testAdaptiveTaskSizing() .build()); } + @Test + public void testAdaptiveTaskSizingRounding() + { + Set partitionedSources = ImmutableSet.of(PARTITIONED_1); + List batches = ImmutableList.of( + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(1), createSplit(2), createSplit(3)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(4), createSplit(5), createSplit(6)), false), + new SplitBatch(PARTITIONED_1, ImmutableList.of(createSplit(7), createSplit(8), createSplit(9)), true)); + SplitAssigner splitAssigner = new ArbitraryDistributionSplitAssigner( + Optional.of(TEST_CATALOG_HANDLE), + partitionedSources, + ImmutableSet.of(), + 1, + 1.3, + 100, + 400, + 100, + 5); + SplitAssignerTester tester = new SplitAssignerTester(); + for (SplitBatch batch : batches) { + PlanNodeId planNodeId = batch.getPlanNodeId(); + List splits = batch.getSplits(); + boolean noMoreSplits = batch.isNoMoreSplits(); + tester.update(splitAssigner.assign(planNodeId, createSplitsMultimap(splits), noMoreSplits)); + tester.checkContainsSplits(planNodeId, splits, false); + } + tester.update(splitAssigner.finish()); + List taskDescriptors = tester.getTaskDescriptors().orElseThrow(); + assertThat(taskDescriptors).hasSize(5); + + // target size 100, round to 100 + TaskDescriptor taskDescriptor0 = taskDescriptors.get(0); + assertTaskDescriptor( + taskDescriptor0, + taskDescriptor0.getPartitionId(), + ImmutableListMultimap.builder() + .put(PARTITIONED_1, createSplit(1)) + .build()); + + // target size 130, round to 100 + TaskDescriptor taskDescriptor1 = taskDescriptors.get(1); + assertTaskDescriptor( + taskDescriptor1, + taskDescriptor1.getPartitionId(), + ImmutableListMultimap.builder() + .put(PARTITIONED_1, createSplit(2)) + .build()); + + // target size 169, round to 200 + TaskDescriptor taskDescriptor2 = taskDescriptors.get(2); + assertTaskDescriptor( + taskDescriptor2, + taskDescriptor2.getPartitionId(), + ImmutableListMultimap.builder() + .put(PARTITIONED_1, createSplit(3)) + .put(PARTITIONED_1, createSplit(4)) + .build()); + + // target size 220, round to 200 + TaskDescriptor taskDescriptor3 = taskDescriptors.get(3); + assertTaskDescriptor( + taskDescriptor3, + taskDescriptor3.getPartitionId(), + ImmutableListMultimap.builder() + .put(PARTITIONED_1, createSplit(5)) + .put(PARTITIONED_1, createSplit(6)) + .build()); + + // target size 286, round to 300 + TaskDescriptor taskDescriptor4 = taskDescriptors.get(4); + assertTaskDescriptor( + taskDescriptor4, + taskDescriptor4.getPartitionId(), + ImmutableListMultimap.builder() + .put(PARTITIONED_1, createSplit(7)) + .put(PARTITIONED_1, createSplit(8)) + .put(PARTITIONED_1, createSplit(9)) + .build()); + } + private void fuzzTesting(boolean withHostRequirements) { Set partitionedSources = new HashSet<>();