diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ArbitraryDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ArbitraryDistributionSplitAssigner.java index 4452ead39c02..89d70cf94cbe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ArbitraryDistributionSplitAssigner.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/ArbitraryDistributionSplitAssigner.java @@ -15,6 +15,7 @@ import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import io.trino.exchange.SpoolingExchangeInput; @@ -131,7 +132,7 @@ private AssignmentResult assignReplicatedSplits(PlanNodeId planNodeId, List singleSourcePartition(int sourcePartitionId, List splits) + { + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + builder.putAll(0, splits); + return builder.build(); + } + private AssignmentResult assignPartitionedSplits(PlanNodeId planNodeId, List splits, boolean noMoreSplits) { AssignmentResult.Builder assignment = AssignmentResult.builder(); @@ -210,7 +218,7 @@ private AssignmentResult assignPartitionedSplits(PlanNodeId planNodeId, List updatePartition( - int partitionId, + int taskPartitionId, PlanNodeId planNodeId, boolean readyForScheduling, - List splits, + ListMultimap splits, // sourcePartitionId -> splits boolean noMoreSplits) { if (getState().isDone()) { return Optional.empty(); } - StagePartition partition = getStagePartition(partitionId); + StagePartition partition = getStagePartition(taskPartitionId); partition.addSplits(planNodeId, splits, noMoreSplits); if (readyForScheduling && !partition.isTaskScheduled()) { partition.setTaskScheduled(true); - return Optional.of(PrioritizedScheduledTask.createSpeculative(stage.getStageId(), partitionId, schedulingPriority, eager)); + return Optional.of(PrioritizedScheduledTask.createSpeculative(stage.getStageId(), taskPartitionId, schedulingPriority, eager)); } return Optional.empty(); } @@ -1823,7 +1824,7 @@ public Optional schedule(int partitionId, ExchangeSinkInstanceHandle Map outputSelectors = getSourceOutputSelectors(); ListMultimap splits = ArrayListMultimap.create(); - splits.putAll(partition.getSplits()); + splits.putAll(partition.getSplits().getSplitsFlat()); outputSelectors.forEach((planNodeId, outputSelector) -> splits.put(planNodeId, createOutputSelectorSplit(outputSelector))); Set noMoreSplits = new HashSet<>(); @@ -2014,6 +2015,11 @@ public List taskFailed(TaskId taskId, ExecutionFailure runningPartitions.remove(partitionId); } + if (!remainingPartitions.contains(partitionId)) { + // another task for this partition finished successfully + return ImmutableList.of(); + } + RuntimeException failure = failureInfo.toException(); ErrorCode errorCode = failureInfo.getErrorCode(); partitionMemoryEstimator.registerPartitionFinished( @@ -2214,7 +2220,7 @@ public StagePartition( this.exchangeSinkHandle = requireNonNull(exchangeSinkHandle, "exchangeSinkHandle is null"); this.remoteSourceIds = ImmutableSet.copyOf(requireNonNull(remoteSourceIds, "remoteSourceIds is null")); requireNonNull(nodeRequirements, "nodeRequirements is null"); - this.openTaskDescriptor = Optional.of(new OpenTaskDescriptor(ImmutableListMultimap.of(), ImmutableSet.of(), nodeRequirements)); + this.openTaskDescriptor = Optional.of(new OpenTaskDescriptor(SplitsMapping.EMPTY, ImmutableSet.of(), nodeRequirements)); this.memoryRequirements = requireNonNull(memoryRequirements, "memoryRequirements is null"); this.remainingAttempts = maxTaskExecutionAttempts; } @@ -2224,7 +2230,7 @@ public ExchangeSinkHandle getExchangeSinkHandle() return exchangeSinkHandle; } - public void addSplits(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + public void addSplits(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) { checkState(openTaskDescriptor.isPresent(), "openTaskDescriptor is empty"); openTaskDescriptor = Optional.of(openTaskDescriptor.get().update(planNodeId, splits, noMoreSplits)); @@ -2233,7 +2239,7 @@ public void addSplits(PlanNodeId planNodeId, List splits, boolean noMoreS } for (RemoteTask task : tasks.values()) { task.addSplits(ImmutableListMultimap.builder() - .putAll(planNodeId, splits) + .putAll(planNodeId, splits.values()) .build()); if (noMoreSplits && isFinalOutputSelectorDelivered(planNodeId)) { task.noMoreSplits(planNodeId); @@ -2270,15 +2276,15 @@ public void seal() } } - public ListMultimap getSplits() + public SplitsMapping getSplits() { if (finished) { - return ImmutableListMultimap.of(); + return SplitsMapping.EMPTY; } return openTaskDescriptor.map(OpenTaskDescriptor::getSplits) .or(() -> taskDescriptorStorage.get(stageId, partitionId).map(TaskDescriptor::getSplits)) // execution is finished - .orElse(ImmutableListMultimap.of()); + .orElse(SplitsMapping.EMPTY); } public boolean isNoMoreSplits(PlanNodeId planNodeId) @@ -2436,18 +2442,25 @@ private static Split createOutputSelectorSplit(ExchangeSourceOutputSelector sele private static class OpenTaskDescriptor { - private final ListMultimap splits; + private final SplitsMapping splits; private final Set noMoreSplits; private final NodeRequirements nodeRequirements; - private OpenTaskDescriptor(ListMultimap splits, Set noMoreSplits, NodeRequirements nodeRequirements) + private OpenTaskDescriptor(SplitsMapping splits, Set noMoreSplits, NodeRequirements nodeRequirements) { - this.splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null")); + this.splits = requireNonNull(splits, "splits is null"); this.noMoreSplits = ImmutableSet.copyOf(requireNonNull(noMoreSplits, "noMoreSplits is null")); this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); } - public ListMultimap getSplits() + private static Map> copySplits(Map> splits) + { + ImmutableMap.Builder> splitsBuilder = ImmutableMap.builder(); + splits.forEach((planNodeId, planNodeSplits) -> splitsBuilder.put(planNodeId, ImmutableListMultimap.copyOf(planNodeSplits))); + return splitsBuilder.buildOrThrow(); + } + + public SplitsMapping getSplits() { return splits; } @@ -2462,12 +2475,15 @@ public NodeRequirements getNodeRequirements() return nodeRequirements; } - public OpenTaskDescriptor update(PlanNodeId planNodeId, List splits, boolean noMoreSplits) + public OpenTaskDescriptor update(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits) { - ListMultimap updatedSplits = ImmutableListMultimap.builder() - .putAll(this.splits) - .putAll(planNodeId, splits) - .build(); + SplitsMapping.Builder updatedSplitsMapping = SplitsMapping.builder(this.splits); + + for (Map.Entry> entry : Multimaps.asMap(splits).entrySet()) { + Integer sourcePartition = entry.getKey(); + List partitionSplits = entry.getValue(); + updatedSplitsMapping.addSplits(planNodeId, sourcePartition, partitionSplits); + } Set updatedNoMoreSplits = this.noMoreSplits; if (noMoreSplits && !updatedNoMoreSplits.contains(planNodeId)) { @@ -2477,14 +2493,14 @@ public OpenTaskDescriptor update(PlanNodeId planNodeId, List splits, bool .build(); } return new OpenTaskDescriptor( - updatedSplits, + updatedSplitsMapping.build(), updatedNoMoreSplits, nodeRequirements); } public TaskDescriptor createTaskDescriptor(int partitionId) { - Set missingNoMoreSplits = Sets.difference(splits.keySet(), noMoreSplits); + Set missingNoMoreSplits = Sets.difference(splits.getPlanNodeIds(), noMoreSplits); checkState(missingNoMoreSplits.isEmpty(), "missing no more splits for plan nodes: %s", missingNoMoreSplits); return new TaskDescriptor( partitionId, diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/HashDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/HashDistributionSplitAssigner.java index df7762680fc4..db0e693e0fd6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/HashDistributionSplitAssigner.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/HashDistributionSplitAssigner.java @@ -16,6 +16,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; @@ -61,6 +62,7 @@ class HashDistributionSplitAssigner private final Set createdTaskPartitions = new HashSet<>(); private final Set completedSources = new HashSet<>(); + private final ListMultimap replicatedSplits = ArrayListMultimap.create(); private boolean allTaskPartitionsCreated; @@ -150,7 +152,7 @@ public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap replicatedSourcePartition(List splits) + { + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + builder.putAll(SINGLE_SOURCE_PARTITION_ID, splits); + return builder.build(); + } + @Override public AssignmentResult finish() { diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SingleDistributionSplitAssigner.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SingleDistributionSplitAssigner.java index 5bbf9a1771a3..3f1801c382b6 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SingleDistributionSplitAssigner.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SingleDistributionSplitAssigner.java @@ -13,7 +13,7 @@ */ package io.trino.execution.scheduler.faulttolerant; -import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import io.trino.metadata.Split; @@ -57,7 +57,7 @@ public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap splits, boolean noMoreSplits); AssignmentResult finish(); @@ -48,14 +52,14 @@ record PartitionUpdate( int partitionId, PlanNodeId planNodeId, boolean readyForScheduling, - List splits, + ListMultimap splits, // sourcePartition -> splits boolean noMoreSplits) { public PartitionUpdate { requireNonNull(planNodeId, "planNodeId is null"); checkArgument(!(readyForScheduling && splits.isEmpty()), "partition update with empty splits marked as ready for scheduling"); - splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")); + splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null")); } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SplitsMapping.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SplitsMapping.java new file mode 100644 index 000000000000..d1f662dff9d4 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/SplitsMapping.java @@ -0,0 +1,291 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler.faulttolerant; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ListMultimap; +import com.google.common.collect.Multimaps; +import com.google.common.collect.Sets; +import io.trino.metadata.Split; +import io.trino.sql.planner.plan.PlanNodeId; +import jakarta.annotation.Nullable; + +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.airlift.slice.SizeOf.INTEGER_INSTANCE_SIZE; +import static io.airlift.slice.SizeOf.estimatedSizeOf; +import static io.airlift.slice.SizeOf.instanceSize; +import static java.util.Objects.requireNonNull; + +public final class SplitsMapping +{ + private static final int INSTANCE_SIZE = instanceSize(SplitsMapping.class); + + public static final SplitsMapping EMPTY = SplitsMapping.builder().build(); + + // not using Multimap to avoid extensive data structure copying when building updated SplitsMapping + private final Map>> splits; // plan-node -> hash-partition -> Split + + private SplitsMapping(ImmutableMap>> splits) + { + // Builder implementations ensure that external map as well as Maps/Lists used in values + // are immutable. + this.splits = splits; + } + + public Set getPlanNodeIds() + { + return splits.keySet(); + } + + public ListMultimap getSplitsFlat() + { + ImmutableListMultimap.Builder splitsFlat = ImmutableListMultimap.builder(); + for (Map.Entry>> entry : splits.entrySet()) { + // TODO can we do less copying? + splitsFlat.putAll(entry.getKey(), entry.getValue().values().stream().flatMap(Collection::stream).collect(toImmutableList())); + } + return splitsFlat.build(); + } + + public List getSplitsFlat(PlanNodeId planNodeId) + { + Map> splits = this.splits.get(planNodeId); + if (splits == null) { + return ImmutableList.of(); + } + verify(!splits.isEmpty(), "expected not empty splits list %s", splits); + + if (splits.size() == 1) { + return getOnlyElement(splits.values()); + } + + // TODO improve to not copy here; return view instead + ImmutableList.Builder result = ImmutableList.builder(); + for (List partitionSplits : splits.values()) { + result.addAll(partitionSplits); + } + return result.build(); + } + + @VisibleForTesting + ListMultimap getSplits(PlanNodeId planNodeId) + { + Map> splits = this.splits.get(planNodeId); + if (splits == null) { + return ImmutableListMultimap.of(); + } + verify(!splits.isEmpty(), "expected not empty splits list %s", splits); + + ImmutableListMultimap.Builder result = ImmutableListMultimap.builder(); + for (Map.Entry> entry : splits.entrySet()) { + result.putAll(entry.getKey(), entry.getValue()); + } + return result.build(); + } + + public long getRetainedSizeInBytes() + { + return INSTANCE_SIZE + + estimatedSizeOf( + splits, + PlanNodeId::getRetainedSizeInBytes, + planNodeSplits -> estimatedSizeOf( + planNodeSplits, + partitionId -> INTEGER_INSTANCE_SIZE, + splitList -> estimatedSizeOf(splitList, Split::getRetainedSizeInBytes))); + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SplitsMapping that = (SplitsMapping) o; + return Objects.equals(splits, that.splits); + } + + @Override + public int hashCode() + { + return Objects.hash(splits); + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("splits", splits) + .toString(); + } + + public static Builder builder() + { + return new NewBuilder(); + } + + public static Builder builder(SplitsMapping mapping) + { + return new UpdatingBuilder(mapping); + } + + public long size() + { + return splits.values().stream() + .flatMap(sourcePartitionToSplits -> sourcePartitionToSplits.values().stream()) + .mapToLong(List::size) + .sum(); + } + + public abstract static class Builder + { + private Builder() {} // close for extension + + public Builder addSplit(PlanNodeId planNodeId, int partitionId, Split split) + { + return addSplits(planNodeId, partitionId, ImmutableList.of(split)); + } + + public Builder addSplits(PlanNodeId planNodeId, ListMultimap splits) + { + Multimaps.asMap(splits).forEach((partitionId, partitionSplits) -> addSplits(planNodeId, partitionId, partitionSplits)); + return this; + } + + public Builder addMapping(SplitsMapping updatingMapping) + { + for (Map.Entry>> entry : updatingMapping.splits.entrySet()) { + PlanNodeId planNodeId = entry.getKey(); + entry.getValue().forEach((partitionId, partitionSplits) -> addSplits(planNodeId, partitionId, partitionSplits)); + } + return this; + } + + public abstract Builder addSplits(PlanNodeId planNodeId, int partitionId, List splits); + + public abstract SplitsMapping build(); + } + + private static class UpdatingBuilder + extends Builder + { + private final SplitsMapping originalMapping; + private final Map>> updates = new HashMap<>(); + + public UpdatingBuilder(SplitsMapping originalMapping) + { + this.originalMapping = requireNonNull(originalMapping, "sourceMapping is null"); + } + + @Override + public Builder addSplits(PlanNodeId planNodeId, int partitionId, List splits) + { + if (splits.isEmpty()) { + // ensure we do not have empty lists in result splits map. + return this; + } + updates.computeIfAbsent(planNodeId, ignored -> new HashMap<>()) + .computeIfAbsent(partitionId, key -> ImmutableList.builder()) + .addAll(splits); + return this; + } + + @Override + public SplitsMapping build() + { + ImmutableMap.Builder>> result = ImmutableMap.builder(); + for (PlanNodeId planNodeId : Sets.union(originalMapping.splits.keySet(), updates.keySet())) { + Map> planNodeOriginalMapping = originalMapping.splits.getOrDefault(planNodeId, ImmutableMap.of()); + Map> planNodeUpdates = updates.getOrDefault(planNodeId, ImmutableMap.of()); + if (planNodeUpdates.isEmpty()) { + // just use original splits for planNodeId + result.put(planNodeId, planNodeOriginalMapping); + continue; + } + // create new mapping for planNodeId reusing as much of source as possible + ImmutableMap.Builder> targetSplitsMapBuilder = ImmutableMap.builder(); + for (Integer sourcePartitionId : Sets.union(planNodeOriginalMapping.keySet(), planNodeUpdates.keySet())) { + @Nullable List originalSplits = planNodeOriginalMapping.get(sourcePartitionId); + @Nullable ImmutableList.Builder splitUpdates = planNodeUpdates.get(sourcePartitionId); + targetSplitsMapBuilder.put(sourcePartitionId, mergeIfPresent(originalSplits, splitUpdates)); + } + result.put(planNodeId, targetSplitsMapBuilder.buildOrThrow()); + } + return new SplitsMapping(result.buildOrThrow()); + } + + private static List mergeIfPresent(@Nullable List list, @Nullable ImmutableList.Builder additionalElements) + { + if (additionalElements == null) { + // reuse source immutable split list + return requireNonNull(list, "list is null"); + } + if (list == null) { + return additionalElements.build(); + } + return ImmutableList.builder() + .addAll(list) + .addAll(additionalElements.build()) + .build(); + } + } + + private static class NewBuilder + extends Builder + { + private final Map>> splitsBuilder = new HashMap<>(); + + @Override + public Builder addSplits(PlanNodeId planNodeId, int partitionId, List splits) + { + if (splits.isEmpty()) { + // ensure we do not have empty lists in result splits map. + return this; + } + splitsBuilder.computeIfAbsent(planNodeId, ignored -> new HashMap<>()) + .computeIfAbsent(partitionId, ignored -> ImmutableList.builder()) + .addAll(splits); + return this; + } + + @Override + public SplitsMapping build() + { + return new SplitsMapping(splitsBuilder.entrySet().stream() + .collect(toImmutableMap( + Map.Entry::getKey, + planNodeMapping -> planNodeMapping.getValue().entrySet().stream() + .collect(toImmutableMap( + Map.Entry::getKey, + sourcePartitionMapping -> sourcePartitionMapping.getValue().build()))))); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptor.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptor.java index 6b0710768264..7521617f5a0b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptor.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptor.java @@ -13,16 +13,9 @@ */ package io.trino.execution.scheduler.faulttolerant; -import com.google.common.collect.ImmutableListMultimap; -import com.google.common.collect.ListMultimap; -import io.trino.metadata.Split; -import io.trino.sql.planner.plan.PlanNodeId; - import java.util.Objects; import static com.google.common.base.MoreObjects.toStringHelper; -import static com.google.common.collect.Multimaps.asMap; -import static io.airlift.slice.SizeOf.estimatedSizeOf; import static io.airlift.slice.SizeOf.instanceSize; import static java.util.Objects.requireNonNull; @@ -31,18 +24,18 @@ public class TaskDescriptor private static final int INSTANCE_SIZE = instanceSize(TaskDescriptor.class); private final int partitionId; - private final ListMultimap splits; + private final SplitsMapping splits; private final NodeRequirements nodeRequirements; private transient volatile long retainedSizeInBytes; public TaskDescriptor( int partitionId, - ListMultimap splits, + SplitsMapping splitsMapping, NodeRequirements nodeRequirements) { this.partitionId = partitionId; - this.splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null")); + this.splits = requireNonNull(splitsMapping, "splitsMapping is null"); this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null"); } @@ -51,7 +44,7 @@ public int getPartitionId() return partitionId; } - public ListMultimap getSplits() + public SplitsMapping getSplits() { return splits; } @@ -95,7 +88,7 @@ public long getRetainedSizeInBytes() long result = retainedSizeInBytes; if (result == 0) { result = INSTANCE_SIZE - + estimatedSizeOf(asMap(splits), PlanNodeId::getRetainedSizeInBytes, splits -> estimatedSizeOf(splits, Split::getRetainedSizeInBytes)) + + splits.getRetainedSizeInBytes() + nodeRequirements.getRetainedSizeInBytes(); retainedSizeInBytes = result; } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage.java index 676e036352a7..599a593b4a7f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/TaskDescriptorStorage.java @@ -330,7 +330,7 @@ private String getDebugInfo() entry -> getDebugInfo(entry.getValue()))); List biggestSplits = descriptorsByStageId.entries().stream() - .flatMap(entry -> entry.getValue().getSplits().entries().stream().map(splitEntry -> Map.entry("%s/%s".formatted(entry.getKey(), splitEntry.getKey()), splitEntry.getValue()))) + .flatMap(entry -> entry.getValue().getSplits().getSplitsFlat().entries().stream().map(splitEntry -> Map.entry("%s/%s".formatted(entry.getKey(), splitEntry.getKey()), splitEntry.getValue()))) .sorted(Comparator.>comparingLong(entry -> entry.getValue().getRetainedSizeInBytes()).reversed()) .limit(3) .map(entry -> "{nodeId=%s, size=%s, split=%s}".formatted(entry.getKey(), entry.getValue().getRetainedSizeInBytes(), splitJsonCodec.toJson(entry.getValue()))) @@ -344,11 +344,11 @@ private String getDebugInfo(Collection taskDescriptors) int taskDescriptorsCount = taskDescriptors.size(); Stats taskDescriptorsRetainedSizeStats = Stats.of(taskDescriptors.stream().mapToLong(TaskDescriptor::getRetainedSizeInBytes)); - Set planNodeIds = taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getSplits().keySet().stream()).collect(toImmutableSet()); + Set planNodeIds = taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().keySet().stream()).collect(toImmutableSet()); Map splitsDebugInfo = new HashMap<>(); for (PlanNodeId planNodeId : planNodeIds) { - Stats splitCountStats = Stats.of(taskDescriptors.stream().mapToLong(taskDescriptor -> taskDescriptor.getSplits().asMap().get(planNodeId).size())); - Stats splitSizeStats = Stats.of(taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getSplits().get(planNodeId).stream()).mapToLong(Split::getRetainedSizeInBytes)); + Stats splitCountStats = Stats.of(taskDescriptors.stream().mapToLong(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().asMap().get(planNodeId).size())); + Stats splitSizeStats = Stats.of(taskDescriptors.stream().flatMap(taskDescriptor -> taskDescriptor.getSplits().getSplitsFlat().get(planNodeId).stream()).mapToLong(Split::getRetainedSizeInBytes)); splitsDebugInfo.put( planNodeId, "{splitCountMean=%s, splitCountStdDev=%s, splitSizeMean=%s, splitSizeStdDev=%s}".formatted( diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/SplitAssignerTester.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/SplitAssignerTester.java index fe9cd438fda4..92bef3ff0890 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/SplitAssignerTester.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/SplitAssignerTester.java @@ -38,15 +38,17 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.execution.scheduler.faulttolerant.SplitAssigner.SINGLE_SOURCE_PARTITION_ID; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.guava.api.Assertions.assertThat; class SplitAssignerTester { private final Map nodeRequirements = new HashMap<>(); - private final Map> splits = new HashMap<>(); + private final Map splits = new HashMap<>(); private final SetMultimap noMoreSplits = HashMultimap.create(); - private final Set sealedPartitions = new HashSet<>(); - private boolean noMorePartitions; + private final Set sealedTaskPartitions = new HashSet<>(); + private boolean noMoreTaskPartitions; private Optional> taskDescriptors = Optional.empty(); public Optional> getTaskDescriptors() @@ -54,40 +56,49 @@ public Optional> getTaskDescriptors() return taskDescriptors; } - public synchronized int getPartitionCount() + public synchronized int getTaskPartitionCount() { return nodeRequirements.size(); } - public synchronized NodeRequirements getNodeRequirements(int partition) + public synchronized NodeRequirements getNodeRequirements(int taskPartition) { - NodeRequirements result = nodeRequirements.get(partition); - checkArgument(result != null, "partition not found: %s", partition); + NodeRequirements result = nodeRequirements.get(taskPartition); + checkArgument(result != null, "task partition not found: %s", taskPartition); return result; } - public synchronized Set getSplitIds(int partition, PlanNodeId planNodeId) + public synchronized Set getSplitIds(int taskPartition, PlanNodeId planNodeId) { - ListMultimap partitionSplits = splits.getOrDefault(partition, ImmutableListMultimap.of()); - return partitionSplits.get(planNodeId).stream() + SplitsMapping taskPartitionSplits = splits.getOrDefault(taskPartition, SplitsMapping.EMPTY); + List splitsFlat = taskPartitionSplits.getSplitsFlat(planNodeId); + return splitsFlat.stream() .map(split -> (TestingConnectorSplit) split.getConnectorSplit()) .map(TestingConnectorSplit::getId) .collect(toImmutableSet()); } - public synchronized boolean isNoMoreSplits(int partition, PlanNodeId planNodeId) + public synchronized ListMultimap getSplitIdsBySourcePartition(int taskPartition, PlanNodeId planNodeId) { - return noMoreSplits.get(partition).contains(planNodeId); + SplitsMapping taskPartitionSplits = splits.getOrDefault(taskPartition, SplitsMapping.EMPTY); + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + taskPartitionSplits.getSplits(planNodeId).forEach((sourcePartition, split) -> builder.put(sourcePartition, TestingConnectorSplit.getSplitId(split))); + return builder.build(); } - public synchronized boolean isSealed(int partition) + public synchronized boolean isNoMoreSplits(int taskPartition, PlanNodeId planNodeId) { - return sealedPartitions.contains(partition); + return noMoreSplits.get(taskPartition).contains(planNodeId); } - public synchronized boolean isNoMorePartitions() + public synchronized boolean isSealed(int taskPartition) { - return noMorePartitions; + return sealedTaskPartitions.contains(taskPartition); + } + + public synchronized boolean isNoMoreTaskPartitions() + { + return noMoreTaskPartitions; } public void checkContainsSplits(PlanNodeId planNodeId, Collection splits, boolean replicated) @@ -95,13 +106,38 @@ public void checkContainsSplits(PlanNodeId planNodeId, Collection splits, Set expectedSplitIds = splits.stream() .map(TestingConnectorSplit::getSplitId) .collect(Collectors.toSet()); - for (int partitionId = 0; partitionId < getPartitionCount(); partitionId++) { - Set partitionSplitIds = getSplitIds(partitionId, planNodeId); + for (int taskPartitionId = 0; taskPartitionId < getTaskPartitionCount(); taskPartitionId++) { + Set taskPartitionSplitIds = getSplitIds(taskPartitionId, planNodeId); + if (replicated) { + assertThat(taskPartitionSplitIds).containsAll(expectedSplitIds); + } + else { + expectedSplitIds.removeAll(taskPartitionSplitIds); + } + } + if (!replicated) { + assertThat(expectedSplitIds).isEmpty(); + } + } + + public void checkContainsSplits(PlanNodeId planNodeId, ListMultimap splitsBySourcePartition, boolean replicated) + { + ListMultimap expectedSplitIds; + if (replicated) { + expectedSplitIds = ArrayListMultimap.create(); + expectedSplitIds.putAll(SINGLE_SOURCE_PARTITION_ID, buildSplitIds(splitsBySourcePartition).values()); + } + else { + expectedSplitIds = ArrayListMultimap.create(buildSplitIds(splitsBySourcePartition)); + } + + for (int taskPartitionId = 0; taskPartitionId < getTaskPartitionCount(); taskPartitionId++) { + ListMultimap taskPartitionSplitIds = getSplitIdsBySourcePartition(taskPartitionId, planNodeId); if (replicated) { - assertThat(partitionSplitIds).containsAll(expectedSplitIds); + assertThat(taskPartitionSplitIds).containsAllEntriesOf(expectedSplitIds); } else { - expectedSplitIds.removeAll(partitionSplitIds); + taskPartitionSplitIds.forEach(expectedSplitIds::remove); } } if (!replicated) { @@ -109,48 +145,61 @@ public void checkContainsSplits(PlanNodeId planNodeId, Collection splits, } } + private ListMultimap buildSplitIds(ListMultimap splitsBySourcePartition) + { + ImmutableListMultimap.Builder builder = ImmutableListMultimap.builder(); + splitsBySourcePartition.forEach((sourcePartition, split) -> builder.put(sourcePartition, TestingConnectorSplit.getSplitId(split))); + return builder.build(); + } + public void update(AssignmentResult assignment) { - for (Partition partition : assignment.partitionsAdded()) { - verify(!noMorePartitions, "noMorePartitions is set"); - verify(nodeRequirements.put(partition.partitionId(), partition.nodeRequirements()) == null, "partition already exist: %s", partition.partitionId()); + for (Partition taskPartition : assignment.partitionsAdded()) { + verify(!noMoreTaskPartitions, "noMoreTaskPartitions is set"); + verify(nodeRequirements.put(taskPartition.partitionId(), taskPartition.nodeRequirements()) == null, "task partition already exist: %s", taskPartition.partitionId()); } - for (PartitionUpdate partitionUpdate : assignment.partitionUpdates()) { - int partitionId = partitionUpdate.partitionId(); - verify(nodeRequirements.get(partitionId) != null, "partition does not exist: %s", partitionId); - verify(!sealedPartitions.contains(partitionId), "partition is sealed: %s", partitionId); - PlanNodeId planNodeId = partitionUpdate.planNodeId(); - if (!partitionUpdate.splits().isEmpty()) { - verify(!noMoreSplits.get(partitionId).contains(planNodeId), "noMoreSplits is set for partition %s and plan node %s", partitionId, planNodeId); - splits.computeIfAbsent(partitionId, (key) -> ArrayListMultimap.create()).putAll(planNodeId, partitionUpdate.splits()); + for (PartitionUpdate taskPartitionUpdate : assignment.partitionUpdates()) { + int taskPartitionId = taskPartitionUpdate.partitionId(); + verify(nodeRequirements.get(taskPartitionId) != null, "task partition does not exist: %s", taskPartitionId); + verify(!sealedTaskPartitions.contains(taskPartitionId), "task partition is sealed: %s", taskPartitionId); + PlanNodeId planNodeId = taskPartitionUpdate.planNodeId(); + if (!taskPartitionUpdate.splits().isEmpty()) { + verify(!noMoreSplits.get(taskPartitionId).contains(planNodeId), "noMoreSplits is set for task partition %s and plan node %s", taskPartitionId, planNodeId); + splits.merge( + taskPartitionId, + SplitsMapping.builder().addSplits(planNodeId, taskPartitionUpdate.splits()).build(), + (originalMapping, updatedMapping) -> + SplitsMapping.builder(originalMapping) + .addMapping(updatedMapping) + .build()); } - if (partitionUpdate.noMoreSplits()) { - noMoreSplits.put(partitionId, planNodeId); + if (taskPartitionUpdate.noMoreSplits()) { + noMoreSplits.put(taskPartitionId, planNodeId); } } - assignment.sealedPartitions().forEach(sealedPartitions::add); + assignment.sealedPartitions().forEach(sealedTaskPartitions::add); if (assignment.noMorePartitions()) { - noMorePartitions = true; + noMoreTaskPartitions = true; } checkFinished(); } private synchronized void checkFinished() { - if (noMorePartitions && sealedPartitions.containsAll(nodeRequirements.keySet())) { - verify(sealedPartitions.equals(nodeRequirements.keySet()), "unknown sealed partitions: %s", Sets.difference(sealedPartitions, nodeRequirements.keySet())); + if (noMoreTaskPartitions && sealedTaskPartitions.containsAll(nodeRequirements.keySet())) { + verify(sealedTaskPartitions.equals(nodeRequirements.keySet()), "unknown sealed partitions: %s", Sets.difference(sealedTaskPartitions, nodeRequirements.keySet())); ImmutableList.Builder result = ImmutableList.builder(); - for (Integer partitionId : sealedPartitions) { - ListMultimap taskSplits = splits.getOrDefault(partitionId, ImmutableListMultimap.of()); + for (Integer taskPartitionId : sealedTaskPartitions) { + SplitsMapping taskSplits = splits.getOrDefault(taskPartitionId, SplitsMapping.EMPTY); verify( - noMoreSplits.get(partitionId).containsAll(taskSplits.keySet()), - "no more split is missing for partition %s: %s", - partitionId, - Sets.difference(taskSplits.keySet(), noMoreSplits.get(partitionId))); + noMoreSplits.get(taskPartitionId).containsAll(taskSplits.getPlanNodeIds()), + "no more split is missing for task partition %s: %s", + taskPartitionId, + Sets.difference(taskSplits.getPlanNodeIds(), noMoreSplits.get(taskPartitionId))); result.add(new TaskDescriptor( - partitionId, + taskPartitionId, taskSplits, - nodeRequirements.get(partitionId))); + nodeRequirements.get(taskPartitionId))); } taskDescriptors = Optional.of(result.build()); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestArbitraryDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestArbitraryDistributionSplitAssigner.java index eb7864c578b7..51d605dc674b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestArbitraryDistributionSplitAssigner.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestArbitraryDistributionSplitAssigner.java @@ -43,6 +43,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.execution.scheduler.faulttolerant.SplitAssigner.SINGLE_SOURCE_PARTITION_ID; import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; import static java.util.Collections.shuffle; import static java.util.Objects.requireNonNull; @@ -770,9 +771,13 @@ private static void assertTaskDescriptor( ListMultimap expectedSplits) { assertEquals(taskDescriptor.getPartitionId(), expectedPartitionId); - assertSplitsEqual(taskDescriptor.getSplits(), expectedSplits); + taskDescriptor.getSplits().getPlanNodeIds().forEach(planNodeId -> { + // we expect single source partition for arbitrary distributed tasks + assertThat(taskDescriptor.getSplits().getSplits(planNodeId).keySet()).isEqualTo(ImmutableSet.of(SINGLE_SOURCE_PARTITION_ID)); + }); + assertSplitsEqual(taskDescriptor.getSplits().getSplitsFlat(), expectedSplits); Set hostRequirement = null; - for (Split split : taskDescriptor.getSplits().values()) { + for (Split split : taskDescriptor.getSplits().getSplitsFlat().values()) { if (!split.isRemotelyAccessible()) { if (hostRequirement == null) { hostRequirement = ImmutableSet.copyOf(split.getAddresses()); diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestEventDrivenTaskSource.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestEventDrivenTaskSource.java index 05a70b539e6b..ab41210d9f16 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestEventDrivenTaskSource.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestEventDrivenTaskSource.java @@ -361,7 +361,7 @@ private void testStageTaskSourceSuccess( Map> actualSplits = new HashMap<>(); for (TaskDescriptor taskDescriptor : taskDescriptors) { int partitionId = taskDescriptor.getPartitionId(); - for (Map.Entry entry : taskDescriptor.getSplits().entries()) { + for (Map.Entry entry : taskDescriptor.getSplits().getSplitsFlat().entries()) { if (entry.getValue().getCatalogHandle().equals(REMOTE_CATALOG_HANDLE)) { RemoteSplit remoteSplit = (RemoteSplit) entry.getValue().getConnectorSplit(); SpoolingExchangeInput input = (SpoolingExchangeInput) remoteSplit.getExchangeInput(); @@ -671,15 +671,18 @@ public AssignmentResult assign(PlanNodeId planNodeId, ListMultimap partitionSplits = ImmutableListMultimap.builder().putAll(partition, splits).build(); + result.updatePartition(new PartitionUpdate(partition, planNodeId, true, partitionSplits, noMoreSplits)); }); if (noMoreSplits) { finishedSources.add(planNodeId); for (Integer partition : partitions) { - result.updatePartition(new PartitionUpdate(partition, planNodeId, false, ImmutableList.of(), true)); + result.updatePartition(new PartitionUpdate(partition, planNodeId, false, ImmutableListMultimap.of(), true)); } } if (finishedSources.containsAll(allSources)) { diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestHashDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestHashDistributionSplitAssigner.java index 6538216db1d8..1a762aa87cd3 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestHashDistributionSplitAssigner.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestHashDistributionSplitAssigner.java @@ -14,11 +14,13 @@ package io.trino.execution.scheduler.faulttolerant; import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.HashMultimap; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; +import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.common.collect.SetMultimap; import com.google.common.primitives.ImmutableLongArray; @@ -48,6 +50,8 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.ImmutableSetMultimap.toImmutableSetMultimap; import static io.trino.execution.scheduler.faulttolerant.HashDistributionSplitAssigner.createSourcePartitionToTaskPartition; +import static io.trino.execution.scheduler.faulttolerant.SplitAssigner.SINGLE_SOURCE_PARTITION_ID; +import static io.trino.execution.scheduler.faulttolerant.TestingConnectorSplit.getSplitId; import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; import static java.util.Objects.requireNonNull; import static org.assertj.core.api.Assertions.assertThat; @@ -724,16 +728,16 @@ public void run() sourcePartitionToTaskPartition); SplitAssignerTester tester = new SplitAssignerTester(); Map> partitionedSplitIds = new HashMap<>(); - Set replicatedSplitIds = new HashSet<>(); + Multimap replicatedSplitIds = HashMultimap.create(); for (SplitBatch batch : splits) { tester.update(assigner.assign(batch.getPlanNodeId(), batch.getSplits(), batch.isNoMoreSplits())); boolean replicated = replicatedSources.contains(batch.getPlanNodeId()); - tester.checkContainsSplits(batch.getPlanNodeId(), batch.getSplits().values(), replicated); + tester.checkContainsSplits(batch.getPlanNodeId(), batch.getSplits(), replicated); for (Map.Entry entry : batch.getSplits().entries()) { - int splitId = TestingConnectorSplit.getSplitId(entry.getValue()); + int splitId = getSplitId(entry.getValue()); if (replicated) { - assertThat(replicatedSplitIds).doesNotContain(splitId); - replicatedSplitIds.add(splitId); + assertThat(replicatedSplitIds.containsValue(splitId)).isFalse(); + replicatedSplitIds.put(batch.getPlanNodeId(), splitId); } else { partitionedSplitIds.computeIfAbsent(entry.getKey(), key -> ArrayListMultimap.create()).put(batch.getPlanNodeId(), splitId); @@ -751,15 +755,19 @@ public void run() NodeRequirements nodeRequirements = taskDescriptor.getNodeRequirements(); assertEquals(nodeRequirements.getCatalogHandle(), Optional.of(TEST_CATALOG_HANDLE)); partitionToNodeMap.ifPresent(partitionToNode -> { - if (!taskDescriptor.getSplits().isEmpty()) { + if (!taskDescriptor.getSplits().getSplitsFlat().isEmpty()) { InternalNode node = partitionToNode.get(partitionId); assertThat(nodeRequirements.getAddresses()).containsExactly(node.getHostAndPort()); } }); - Set taskDescriptorSplitIds = taskDescriptor.getSplits().values().stream() - .map(TestingConnectorSplit::getSplitId) - .collect(toImmutableSet()); - assertThat(taskDescriptorSplitIds).containsAll(replicatedSplitIds); + Set taskDescriptorSplitIds = new HashSet<>(); + replicatedSplitIds.keySet().forEach(planNodeId -> { + // all replicated splits should be assigned to single source partition in task descriptor + taskDescriptor.getSplits().getSplits(planNodeId).get(SINGLE_SOURCE_PARTITION_ID).stream() + .map(TestingConnectorSplit::getSplitId) + .forEach(taskDescriptorSplitIds::add); + }); + assertThat(taskDescriptorSplitIds).containsAll(replicatedSplitIds.values()); } // validate partitioned splits @@ -771,13 +779,15 @@ public void run() .map(taskDescriptors::get) .collect(toImmutableList()); for (TaskDescriptor descriptor : descriptors) { - Set taskDescriptorSplitIds = descriptor.getSplits().values().stream() - .map(TestingConnectorSplit::getSplitId) - .collect(toImmutableSet()); - if (taskDescriptorSplitIds.contains(splitId) && splittableSources.contains(source)) { + Multimap taskDescriptorSplitIds = descriptor.getSplits().getSplits(source).entries().stream() + .collect(toImmutableListMultimap( + Map.Entry::getKey, + entry -> getSplitId(entry.getValue()))); + + if (taskDescriptorSplitIds.get(partitionId).contains(splitId) && splittableSources.contains(source)) { return; } - if (!taskDescriptorSplitIds.contains(splitId) && !splittableSources.contains(source)) { + if (!taskDescriptorSplitIds.get(partitionId).contains(splitId) && !splittableSources.contains(source)) { fail("expected split not found: ." + splitId); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSingleDistributionSplitAssigner.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSingleDistributionSplitAssigner.java index c95e3d17479e..4da174434ecb 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSingleDistributionSplitAssigner.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSingleDistributionSplitAssigner.java @@ -43,10 +43,10 @@ public void testNoSources() tester.update(splitAssigner.finish()); - assertEquals(tester.getPartitionCount(), 1); + assertEquals(tester.getTaskPartitionCount(), 1); assertEquals(tester.getNodeRequirements(0), new NodeRequirements(Optional.empty(), hostRequirement)); assertTrue(tester.isSealed(0)); - assertTrue(tester.isNoMorePartitions()); + assertTrue(tester.isNoMoreTaskPartitions()); } @Test @@ -61,12 +61,12 @@ public void testEmptySource() tester.update(splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(), true)); tester.update(splitAssigner.finish()); - assertEquals(tester.getPartitionCount(), 1); + assertEquals(tester.getTaskPartitionCount(), 1); assertEquals(tester.getNodeRequirements(0), new NodeRequirements(Optional.empty(), hostRequirement)); assertThat(tester.getSplitIds(0, PLAN_NODE_1)).isEmpty(); assertTrue(tester.isNoMoreSplits(0, PLAN_NODE_1)); assertTrue(tester.isSealed(0)); - assertTrue(tester.isNoMorePartitions()); + assertTrue(tester.isNoMoreTaskPartitions()); } @Test @@ -77,18 +77,18 @@ public void testSingleSource() ImmutableSet.of(PLAN_NODE_1)); SplitAssignerTester tester = new SplitAssignerTester(); - assertEquals(tester.getPartitionCount(), 0); - assertFalse(tester.isNoMorePartitions()); + assertEquals(tester.getTaskPartitionCount(), 0); + assertFalse(tester.isNoMoreTaskPartitions()); tester.update(splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(1)), false)); tester.update(splitAssigner.finish()); - assertEquals(tester.getPartitionCount(), 1); + assertEquals(tester.getTaskPartitionCount(), 1); assertThat(tester.getSplitIds(0, PLAN_NODE_1)).containsExactly(1); - assertTrue(tester.isNoMorePartitions()); + assertTrue(tester.isNoMoreTaskPartitions()); tester.update(splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(2), 1, createSplit(3)), false)); tester.update(splitAssigner.finish()); - assertEquals(tester.getPartitionCount(), 1); + assertEquals(tester.getTaskPartitionCount(), 1); assertThat(tester.getSplitIds(0, PLAN_NODE_1)).containsExactly(1, 2, 3); assertFalse(tester.isNoMoreSplits(0, PLAN_NODE_1)); @@ -107,31 +107,31 @@ public void testMultipleSources() ImmutableSet.of(PLAN_NODE_1, PLAN_NODE_2)); SplitAssignerTester tester = new SplitAssignerTester(); - assertEquals(tester.getPartitionCount(), 0); - assertFalse(tester.isNoMorePartitions()); + assertEquals(tester.getTaskPartitionCount(), 0); + assertFalse(tester.isNoMoreTaskPartitions()); tester.update(splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(0, createSplit(1)), false)); tester.update(splitAssigner.finish()); - assertEquals(tester.getPartitionCount(), 1); - assertThat(tester.getSplitIds(0, PLAN_NODE_1)).containsExactly(1); - assertTrue(tester.isNoMorePartitions()); + assertEquals(tester.getTaskPartitionCount(), 1); + assertThat(tester.getSplitIds(0, PLAN_NODE_1)).containsExactlyInAnyOrder(1); + assertTrue(tester.isNoMoreTaskPartitions()); tester.update(splitAssigner.assign(PLAN_NODE_2, ImmutableListMultimap.of(0, createSplit(2), 1, createSplit(3)), false)); tester.update(splitAssigner.finish()); - assertEquals(tester.getPartitionCount(), 1); - assertThat(tester.getSplitIds(0, PLAN_NODE_2)).containsExactly(2, 3); + assertEquals(tester.getTaskPartitionCount(), 1); + assertThat(tester.getSplitIds(0, PLAN_NODE_2)).containsExactlyInAnyOrder(2, 3); assertFalse(tester.isNoMoreSplits(0, PLAN_NODE_1)); tester.update(splitAssigner.assign(PLAN_NODE_1, ImmutableListMultimap.of(2, createSplit(4)), true)); tester.update(splitAssigner.finish()); - assertThat(tester.getSplitIds(0, PLAN_NODE_1)).containsExactly(1, 4); + assertThat(tester.getSplitIds(0, PLAN_NODE_1)).containsExactlyInAnyOrder(1, 4); assertTrue(tester.isNoMoreSplits(0, PLAN_NODE_1)); assertFalse(tester.isNoMoreSplits(0, PLAN_NODE_2)); assertFalse(tester.isSealed(0)); tester.update(splitAssigner.assign(PLAN_NODE_2, ImmutableListMultimap.of(3, createSplit(5)), true)); tester.update(splitAssigner.finish()); - assertThat(tester.getSplitIds(0, PLAN_NODE_2)).containsExactly(2, 3, 5); + assertThat(tester.getSplitIds(0, PLAN_NODE_2)).containsExactlyInAnyOrder(2, 3, 5); assertTrue(tester.isNoMoreSplits(0, PLAN_NODE_2)); assertTrue(tester.isSealed(0)); } diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSplitsMapping.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSplitsMapping.java new file mode 100644 index 000000000000..0c761703140d --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestSplitsMapping.java @@ -0,0 +1,129 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.scheduler.faulttolerant; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import io.trino.metadata.Split; +import io.trino.sql.planner.plan.PlanNodeId; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.Optional; +import java.util.OptionalInt; + +import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.guava.api.Assertions.assertThat; + +public class TestSplitsMapping +{ + @Test + public void testNewSplitMappingBuilder() + { + SplitsMapping.Builder newBuilder = SplitsMapping.builder(); + newBuilder.addSplit(new PlanNodeId("N1"), 0, createSplit(1)); + newBuilder.addSplit(new PlanNodeId("N1"), 1, createSplit(2)); + newBuilder.addSplits(new PlanNodeId("N1"), 1, ImmutableList.of(createSplit(3), createSplit(4))); + newBuilder.addSplits(new PlanNodeId("N1"), 2, ImmutableList.of(createSplit(5), createSplit(6))); // addSplits(list) creating new source partition + newBuilder.addSplits(new PlanNodeId("N1"), ImmutableListMultimap.of( + 0, createSplit(7), + 1, createSplit(8), + 3, createSplit(9))); // create new source partition + newBuilder.addSplit(new PlanNodeId("N2"), 0, createSplit(10)); // another plan node + newBuilder.addSplit(new PlanNodeId("N2"), 3, createSplit(11)); + newBuilder.addMapping(SplitsMapping.builder() + .addSplit(new PlanNodeId("N1"), 0, createSplit(20)) + .addSplit(new PlanNodeId("N1"), 4, createSplit(21)) + .addSplit(new PlanNodeId("N3"), 0, createSplit(22)) + .build()); + + SplitsMapping splitsMapping1 = newBuilder.build(); + + assertThat(splitsMapping1.getPlanNodeIds()).containsExactlyInAnyOrder(new PlanNodeId("N1"), new PlanNodeId("N2"), new PlanNodeId("N3")); + assertThat(splitIds(splitsMapping1, "N1")).isEqualTo( + ImmutableListMultimap.builder() + .putAll(0, 1, 7, 20) + .putAll(1, 2, 3, 4, 8) + .putAll(2, 5, 6) + .putAll(3, 9) + .put(4, 21) + .build()); + assertThat(splitIds(splitsMapping1, "N2")).isEqualTo( + ImmutableListMultimap.builder() + .put(0, 10) + .put(3, 11) + .build()); + assertThat(splitIds(splitsMapping1, "N3")).isEqualTo( + ImmutableListMultimap.builder() + .put(0, 22) + .build()); + } + + @Test + public void testUpdatingSplitMappingBuilder() + { + SplitsMapping.Builder newBuilder = SplitsMapping.builder(SplitsMapping.builder() + .addSplit(new PlanNodeId("N1"), 0, createSplit(20)) + .addSplit(new PlanNodeId("N1"), 4, createSplit(21)) + .addSplit(new PlanNodeId("N3"), 0, createSplit(22)) + .build()); + + newBuilder.addSplit(new PlanNodeId("N1"), 0, createSplit(1)); + newBuilder.addSplit(new PlanNodeId("N1"), 1, createSplit(2)); + newBuilder.addSplits(new PlanNodeId("N1"), 1, ImmutableList.of(createSplit(3), createSplit(4))); + newBuilder.addSplits(new PlanNodeId("N1"), 2, ImmutableList.of(createSplit(5), createSplit(6))); // addSplits(list) creating new source partition + newBuilder.addSplits(new PlanNodeId("N1"), ImmutableListMultimap.of( + 0, createSplit(7), + 1, createSplit(8), + 3, createSplit(9))); // create new source partition + newBuilder.addSplit(new PlanNodeId("N2"), 0, createSplit(10)); // another plan node + newBuilder.addSplit(new PlanNodeId("N2"), 3, createSplit(11)); + + SplitsMapping splitsMapping1 = newBuilder.build(); + + assertThat(splitsMapping1.getPlanNodeIds()).containsExactlyInAnyOrder(new PlanNodeId("N1"), new PlanNodeId("N2"), new PlanNodeId("N3")); + assertThat(splitIds(splitsMapping1, "N1")).isEqualTo( + ImmutableListMultimap.builder() + .putAll(0, 20, 1, 7) + .putAll(1, 2, 3, 4, 8) + .putAll(2, 5, 6) + .putAll(3, 9) + .put(4, 21) + .build()); + assertThat(splitIds(splitsMapping1, "N2")).isEqualTo( + ImmutableListMultimap.builder() + .put(0, 10) + .put(3, 11) + .build()); + assertThat(splitIds(splitsMapping1, "N3")).isEqualTo( + ImmutableListMultimap.builder() + .put(0, 22) + .build()); + } + + private ListMultimap splitIds(SplitsMapping splitsMapping, String planNodeId) + { + return splitsMapping.getSplits(new PlanNodeId(planNodeId)).entries().stream() + .collect(ImmutableListMultimap.toImmutableListMultimap( + Map.Entry::getKey, + entry -> ((TestingConnectorSplit) entry.getValue().getConnectorSplit()).getId())); + } + + private static Split createSplit(int id) + { + return new Split(TEST_CATALOG_HANDLE, new TestingConnectorSplit(id, OptionalInt.empty(), Optional.empty())); + } +} diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestTaskDescriptorStorage.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestTaskDescriptorStorage.java index 6a763eaf9a88..deede00937fa 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestTaskDescriptorStorage.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestTaskDescriptorStorage.java @@ -14,7 +14,6 @@ package io.trino.execution.scheduler.faulttolerant; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableSet; import io.airlift.units.DataSize; import io.trino.exchange.SpoolingExchangeInput; @@ -199,9 +198,9 @@ private static TaskDescriptor createTaskDescriptor(int partitionId, DataSize ret { return new TaskDescriptor( partitionId, - ImmutableListMultimap.of( - new PlanNodeId("1"), - new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(new TestingExchangeSourceHandle(retainedSize.toBytes())), Optional.empty())))), + SplitsMapping.builder() + .addSplit(new PlanNodeId("1"), 1, new Split(REMOTE_CATALOG_HANDLE, new RemoteSplit(new SpoolingExchangeInput(ImmutableList.of(new TestingExchangeSourceHandle(retainedSize.toBytes())), Optional.empty())))) + .build(), new NodeRequirements(catalog, ImmutableSet.of())); }