diff --git a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java index af1c6bb6823b..507776030ff8 100644 --- a/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java +++ b/core/trino-main/src/main/java/io/trino/SystemSessionProperties.java @@ -187,6 +187,9 @@ public final class SystemSessionProperties public static final String FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT = "fault_tolerant_execution_max_partition_count"; public static final String FAULT_TOLERANT_EXECUTION_MIN_PARTITION_COUNT = "fault_tolerant_execution_min_partition_count"; public static final String FAULT_TOLERANT_EXECUTION_MIN_PARTITION_COUNT_FOR_WRITE = "fault_tolerant_execution_min_partition_count_for_write"; + public static final String FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_ENABLED = "fault_tolerant_execution_runtime_adaptive_partitioning_enabled"; + public static final String FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_PARTITION_COUNT = "fault_tolerant_execution_runtime_adaptive_partitioning_partition_count"; + public static final String FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_MAX_TASK_SIZE = "fault_tolerant_execution_runtime_adaptive_partitioning_max_task_size"; public static final String FAULT_TOLERANT_EXECUTION_MIN_SOURCE_STAGE_PROGRESS = "fault_tolerant_execution_min_source_stage_progress"; private static final String FAULT_TOLERANT_EXECUTION_SMALL_STAGE_ESTIMATION_ENABLED = "fault_tolerant_execution_small_stage_estimation_enabled"; private static final String FAULT_TOLERANT_EXECUTION_SMALL_STAGE_ESTIMATION_THRESHOLD = "fault_tolerant_execution_small_stage_estimation_threshold"; @@ -949,6 +952,22 @@ public SystemSessionProperties( queryManagerConfig.getFaultTolerantExecutionMinPartitionCountForWrite(), value -> validateIntegerValue(value, FAULT_TOLERANT_EXECUTION_MIN_PARTITION_COUNT_FOR_WRITE, 1, FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT_LIMIT, false), false), + booleanProperty( + FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_ENABLED, + "Enables changing the number of partitions at runtime when the intermediate data size is large", + queryManagerConfig.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(), + true), + integerProperty( + FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_PARTITION_COUNT, + "The partition count to use for runtime adaptive partitioning when enabled", + queryManagerConfig.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(), + value -> validateIntegerValue(value, FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_PARTITION_COUNT, 1, FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT_LIMIT, false), + true), + dataSizeProperty( + FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_MAX_TASK_SIZE, + "Max average task input size when deciding runtime adaptive partitioning", + queryManagerConfig.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(), + true), doubleProperty( FAULT_TOLERANT_EXECUTION_MIN_SOURCE_STAGE_PROGRESS, "Minimal progress of source stage to consider scheduling of parent stage", @@ -1790,6 +1809,21 @@ public static int getFaultTolerantExecutionMinPartitionCountForWrite(Session ses return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_MIN_PARTITION_COUNT_FOR_WRITE, Integer.class); } + public static boolean isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_ENABLED, Boolean.class); + } + + public static int getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_PARTITION_COUNT, Integer.class); + } + + public static DataSize getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(Session session) + { + return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_RUNTIME_ADAPTIVE_PARTITIONING_MAX_TASK_SIZE, DataSize.class); + } + public static double getFaultTolerantExecutionMinSourceStageProgress(Session session) { return session.getSystemProperty(FAULT_TOLERANT_EXECUTION_MIN_SOURCE_STAGE_PROGRESS, Double.class);
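Not part of the change itself: a minimal sketch of how downstream code can read the three session properties defined above. The static getters are the ones added in this hunk; the `RuntimeAdaptivePartitioningSettings` holder is a hypothetical name used only for illustration.

```java
import io.trino.Session;

import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize;
import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount;
import static io.trino.SystemSessionProperties.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled;

// Hypothetical holder bundling the three new session properties (illustration only)
final class RuntimeAdaptivePartitioningSettings
{
    private final boolean enabled;
    private final int partitionCount;
    private final long maxTaskSizeInBytes;

    RuntimeAdaptivePartitioningSettings(Session session)
    {
        // defaults come from QueryManagerConfig below: disabled, max partition count limit, 12GB
        this.enabled = isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(session);
        this.partitionCount = getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(session);
        this.maxTaskSizeInBytes = getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(session).toBytes();
    }

    boolean enabled()
    {
        return enabled;
    }

    int partitionCount()
    {
        return partitionCount;
    }

    long maxTaskSizeInBytes()
    {
        return maxTaskSizeInBytes;
    }
}
```

The corresponding server-level defaults are the `fault-tolerant-execution-runtime-adaptive-partitioning-*` config properties in QueryManagerConfig, shown next.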
diff --git a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java index 8e8354ccd9c9..f9092c221638 100644 --- a/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/QueryManagerConfig.java @@ -131,6 +131,11 @@ public class QueryManagerConfig private int faultTolerantExecutionMaxPartitionCount = 50; private int faultTolerantExecutionMinPartitionCount = 4; private int faultTolerantExecutionMinPartitionCountForWrite = 50; + private boolean faultTolerantExecutionRuntimeAdaptivePartitioningEnabled; + private int faultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount = FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT_LIMIT; + // Currently, the initial setup is 5GB of task memory for processing 4GB of data. Given that we triple the memory in case of + // a task OOM, the max task size is set to 12GB so that tasks of stages below the threshold will succeed within one retry. + private DataSize faultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize = DataSize.of(12, GIGABYTE); private boolean faultTolerantExecutionForcePreferredWritePartitioningEnabled = true; private double faultTolerantExecutionMinSourceStageProgress = 0.2; @@ -965,6 +970,46 @@ public QueryManagerConfig setFaultTolerantExecutionMinPartitionCountForWrite(int return this; } + public boolean isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled() + { + return faultTolerantExecutionRuntimeAdaptivePartitioningEnabled; + } + + @Config("fault-tolerant-execution-runtime-adaptive-partitioning-enabled") + public QueryManagerConfig setFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(boolean faultTolerantExecutionRuntimeAdaptivePartitioningEnabled) + { + this.faultTolerantExecutionRuntimeAdaptivePartitioningEnabled = faultTolerantExecutionRuntimeAdaptivePartitioningEnabled; + return this; + } + + @Min(1) + @Max(FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT_LIMIT) + public int getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount() + { + return faultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount; + } + + @Config("fault-tolerant-execution-runtime-adaptive-partitioning-partition-count") + @ConfigDescription("The partition count to use for runtime adaptive partitioning when enabled") + public QueryManagerConfig setFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(int faultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount) + { + this.faultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount = faultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount; + return this; + } + + public DataSize getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize() + { + return faultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize; + } + + 
@Config("fault-tolerant-execution-runtime-adaptive-partitioning-max-task-size") + @ConfigDescription("Max average task input size when deciding runtime adaptive partitioning") + public QueryManagerConfig setFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(DataSize faultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize) + { + this.faultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize = faultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize; + return this; + } + public boolean isFaultTolerantExecutionForcePreferredWritePartitioningEnabled() { return faultTolerantExecutionForcePreferredWritePartitioningEnabled; diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java index 58bcad458e44..7266c67340c1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java @@ -82,6 +82,8 @@ import io.trino.split.RemoteSplit; import io.trino.sql.planner.NodePartitioningManager; import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.PlanFragmentIdAllocator; +import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SubPlan; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.PlanFragmentId; @@ -98,6 +100,7 @@ import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -122,12 +125,16 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.util.concurrent.Futures.getDone; +import static io.airlift.units.DataSize.succinctBytes; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionArbitraryDistributionComputeTaskTargetSizeMin; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultCoordinatorTaskMemory; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionDefaultTaskMemory; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMinSourceStageProgress; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionSmallStageEstimationThreshold; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionSmallStageSourceSizeMultiplier; import static io.trino.SystemSessionProperties.getMaxTasksWaitingForExecutionPerQuery; @@ -137,6 +144,7 @@ import static io.trino.SystemSessionProperties.getRetryMaxDelay; import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.SystemSessionProperties.getTaskRetryAttemptsPerTask; +import static io.trino.SystemSessionProperties.isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled; import static 
io.trino.SystemSessionProperties.isFaultTolerantExecutionSmallStageEstimationEnabled; import static io.trino.SystemSessionProperties.isFaultTolerantExecutionSmallStageRequireNoMorePartitions; import static io.trino.execution.BasicStageStats.aggregateBasicStageStats; @@ -154,9 +162,14 @@ import static io.trino.spi.ErrorType.USER_ERROR; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.consumesHashPartitionedInput; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanFragmentId; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanId; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.overridePartitionCountRecursively; import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.TopologicalOrderSubPlanVisitor.sortPlanInTopologicalOrder; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static io.trino.util.Failures.toFailure; import static java.lang.Math.max; import static java.lang.Math.min; @@ -291,10 +304,11 @@ public synchronized void start() }); Session session = queryStateMachine.getSession(); + int maxPartitionCount = getFaultTolerantExecutionMaxPartitionCount(session); FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory = new FaultTolerantPartitioningSchemeFactory( nodePartitioningManager, session, - getFaultTolerantExecutionMaxPartitionCount(session)); + maxPartitionCount); Closer closer = Closer.create(); NodeAllocator nodeAllocator = closer.register(nodeAllocatorService.getNodeAllocator(session)); try { @@ -327,6 +341,10 @@ public synchronized void start() getRetryDelayScaleFactor(session), Stopwatch.createUnstarted()), originalPlan, + maxPartitionCount, + isFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(session), + getFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(session), + getFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(session), minSourceStageProgress, smallStageEstimationEnabled, smallStageEstimationThreshold, @@ -415,9 +433,10 @@ public void updatePlan(SubPlan plan) public StageInfo getStageInfo() { - SubPlan plan = requireNonNull(this.plan.get(), "plan is null"); Map stageInfos = stages.values().stream() .collect(toImmutableMap(stage -> stage.getFragment().getId(), SqlStage::getStageInfo)); + // make sure that plan is not staler than stageInfos since `getStageInfo` is called asynchronously + SubPlan plan = requireNonNull(this.plan.get(), "plan is null"); Set reportedFragments = new HashSet<>(); StageInfo stageInfo = getStageInfo(plan, stageInfos, reportedFragments); // TODO Some stages may no longer be present in the plan when adaptive re-planning is implemented @@ -509,6 +528,10 @@ private static class Scheduler private final StageRegistry stageRegistry; private final TaskExecutionStats taskExecutionStats; private final DynamicFilterService dynamicFilterService; + private final int maxPartitionCount; + private final boolean runtimeAdaptivePartitioningEnabled; + private final int runtimeAdaptivePartitioningPartitionCount; + private final long runtimeAdaptivePartitioningMaxTaskSizeInBytes; private final double minSourceStageProgress; private final boolean smallStageEstimationEnabled; private final DataSize 
smallStageEstimationThreshold; @@ -518,10 +541,12 @@ private static class Scheduler private final List<Event> eventBuffer = new ArrayList<>(EVENT_BUFFER_CAPACITY); private boolean started; + private boolean runtimeAdaptivePartitioningApplied; private SubPlan plan; private List<SubPlan> planInTopologicalOrder; private final Map<StageId, StageExecution> stageExecutions = new HashMap<>(); + private final Map<SubPlan, IsReadyForExecutionResult> isReadyForExecutionCache = new HashMap<>(); private final SetMultimap<StageId, StageId> stageConsumers = HashMultimap.create(); private final SchedulingQueue schedulingQueue = new SchedulingQueue(); @@ -558,6 +583,10 @@ public Scheduler( DynamicFilterService dynamicFilterService, SchedulingDelayer schedulingDelayer, SubPlan plan, + int maxPartitionCount, + boolean runtimeAdaptivePartitioningEnabled, + int runtimeAdaptivePartitioningPartitionCount, + DataSize runtimeAdaptivePartitioningMaxTaskSize, double minSourceStageProgress, boolean smallStageEstimationEnabled, DataSize smallStageEstimationThreshold, @@ -590,6 +619,10 @@ public Scheduler( this.dynamicFilterService = requireNonNull(dynamicFilterService, "dynamicFilterService is null"); this.schedulingDelayer = requireNonNull(schedulingDelayer, "schedulingDelayer is null"); this.plan = requireNonNull(plan, "plan is null"); + this.maxPartitionCount = maxPartitionCount; + this.runtimeAdaptivePartitioningEnabled = runtimeAdaptivePartitioningEnabled; + this.runtimeAdaptivePartitioningPartitionCount = runtimeAdaptivePartitioningPartitionCount; + this.runtimeAdaptivePartitioningMaxTaskSizeInBytes = requireNonNull(runtimeAdaptivePartitioningMaxTaskSize, "runtimeAdaptivePartitioningMaxTaskSize is null").toBytes(); this.minSourceStageProgress = minSourceStageProgress; this.smallStageEstimationEnabled = smallStageEstimationEnabled; this.smallStageEstimationThreshold = requireNonNull(smallStageEstimationThreshold, "smallStageEstimationThreshold is null"); @@ -769,6 +802,81 @@ private SubPlan optimizePlan(SubPlan plan) { // Re-optimize plan here based on available runtime statistics. // Fragments changed due to re-optimization as well as their downstream stages are expected to be assigned new fragment ids. + plan = updateStagesPartitioning(plan); + return plan; + } + + private SubPlan updateStagesPartitioning(SubPlan plan) + { + if (!runtimeAdaptivePartitioningEnabled || runtimeAdaptivePartitioningApplied) { + return plan; + } + + for (SubPlan subPlan : planInTopologicalOrder) { + PlanFragment fragment = subPlan.getFragment(); + if (!consumesHashPartitionedInput(fragment)) { + // no input hash partitioning present + continue; + } + + StageId stageId = getStageId(fragment.getId()); + if (stageExecutions.containsKey(stageId)) { + // already started + continue; + } + + IsReadyForExecutionResult isReadyForExecutionResult = isReadyForExecution(subPlan); + // Caching is needed not only to avoid duplicate calls, but also to handle the case where a stage that + // is not ready now becomes ready by the time updateStageExecutions runs. + // We want to avoid starting an execution without considering a change of the number of partitions. + // TODO: think about how to eliminate the cache + isReadyForExecutionCache.put(subPlan, isReadyForExecutionResult); + if (!isReadyForExecutionResult.isReadyForExecution()) { + // not ready for execution + continue; + } + + // calculate the (estimated) input data size to determine if we want to change the number of partitions at runtime + List<Long> partitionedInputBytes = fragment.getRemoteSourceNodes().stream() + .filter(remoteSourceNode -> remoteSourceNode.getExchangeType() != REPLICATE) + .map(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().stream() + .mapToLong(sourceFragmentId -> { + StageId sourceStageId = getStageId(sourceFragmentId); + OutputDataSizeEstimate outputDataSizeEstimate = isReadyForExecutionResult.getSourceOutputSizeEstimates().get(sourceStageId); + verify(outputDataSizeEstimate != null, "outputDataSizeEstimate not found for source stage %s", sourceStageId); + return outputDataSizeEstimate.getTotalSizeInBytes(); + }) + .sum()) + .collect(toImmutableList()); + // Currently the memory estimation is simplified: + // if it's an aggregation, then we use the total input bytes as the memory consumption + // if it involves multiple joins, we conservatively assume the smallest remote source will be streamed through + // and use the sum of the input bytes of the other remote sources as the memory consumption + // TODO: more accurate memory estimation based on context (https://github.com/trinodb/trino/issues/18698) + long estimatedMemoryConsumptionInBytes = (partitionedInputBytes.size() == 1) ? partitionedInputBytes.get(0) : + partitionedInputBytes.stream().mapToLong(Long::longValue).sum() - Collections.min(partitionedInputBytes); + + int partitionCount = fragment.getPartitionCount().orElse(maxPartitionCount); + if (estimatedMemoryConsumptionInBytes > runtimeAdaptivePartitioningMaxTaskSizeInBytes * partitionCount) { + log.info("Stage %s has an estimated memory consumption of %s, changing partition count from %s to %s", + stageId, succinctBytes(estimatedMemoryConsumptionInBytes), partitionCount, runtimeAdaptivePartitioningPartitionCount); + runtimeAdaptivePartitioningApplied = true; + PlanFragmentIdAllocator planFragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(planInTopologicalOrder) + 1); + PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(getMaxPlanId(planInTopologicalOrder) + 1); + return overridePartitionCountRecursively( + plan, + partitionCount, + runtimeAdaptivePartitioningPartitionCount, + planFragmentIdAllocator, + planNodeIdAllocator, + planInTopologicalOrder.stream() + .map(SubPlan::getFragment) + .map(PlanFragment::getId) + .filter(planFragmentId -> stageExecutions.containsKey(getStageId(planFragmentId))) + .collect(toImmutableSet())); + } + } + return plan; }
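To make the size check above concrete, here is a small self-contained sketch (illustrative only; the `RuntimeAdaptivePartitioningDecision` name is hypothetical) of the estimation and the decision performed in `updateStagesPartitioning`: with one partitioned input, the estimate is that input's size; with several, it is the total minus the smallest input, which is assumed to be streamed through; repartitioning triggers once the estimate exceeds maxTaskSize * partitionCount.

```java
import java.util.Collections;
import java.util.List;

final class RuntimeAdaptivePartitioningDecision
{
    private RuntimeAdaptivePartitioningDecision() {}

    // Mirrors the estimation above: a single partitioned input is consumed fully;
    // with multiple inputs the smallest is assumed to be streamed through and is
    // excluded from the memory estimate.
    static long estimateMemoryConsumptionInBytes(List<Long> partitionedInputBytes)
    {
        if (partitionedInputBytes.size() == 1) {
            return partitionedInputBytes.get(0);
        }
        long total = partitionedInputBytes.stream().mapToLong(Long::longValue).sum();
        return total - Collections.min(partitionedInputBytes);
    }

    // Repartition when the average task input would exceed the configured max task size
    static boolean shouldRepartition(List<Long> partitionedInputBytes, long maxTaskSizeInBytes, int partitionCount)
    {
        return estimateMemoryConsumptionInBytes(partitionedInputBytes) > maxTaskSizeInBytes * partitionCount;
    }
}
```

For example, with partitioned inputs of 100GB and 40GB the estimate is 100GB, which stays under a 120GB budget (12GB max task size times 10 partitions); growing the larger input to 130GB would trip the check and re-plan the query with the configured runtime adaptive partition count.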
@@ -782,7 +890,7 @@ private void updateStageExecutions() currentPlanStages.add(stageId); StageExecution stageExecution = stageExecutions.get(stageId); if (stageExecution == null) { - IsReadyForExecutionResult result = isReadyForExecution(subPlan); + IsReadyForExecutionResult result = isReadyForExecutionCache.computeIfAbsent(subPlan, ignored -> isReadyForExecution(subPlan)); if (result.isReadyForExecution()) { createStageExecution(subPlan, fragmentId.equals(rootFragmentId), result.getSourceOutputSizeEstimates(), nextSchedulingPriority++); } @@ -798,6 +906,7 @@ stageExecution.abort(); } }); + isReadyForExecutionCache.clear(); }
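The `computeIfAbsent` plus end-of-pass `clear()` above amounts to a per-round memoization of the readiness check. A generic sketch of the pattern (the `PerRoundCache` name is hypothetical, not part of this change):

```java
import java.util.HashMap;
import java.util.Map;
import java.util.function.Function;

// Evaluate an expensive check at most once per scheduling round, then drop the
// cache so the next round observes fresh state.
final class PerRoundCache<K, V>
{
    private final Map<K, V> cache = new HashMap<>();

    V get(K key, Function<K, V> loader)
    {
        return cache.computeIfAbsent(key, loader);
    }

    void endRound()
    {
        cache.clear();
    }
}
```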
private static class IsReadyForExecutionResult @@ -847,8 +956,12 @@ private IsReadyForExecutionResult isReadyForExecution(SubPlan subPlan) boolean nonSpeculativeTasksWaitingForNode = preSchedulingTaskContexts.values().stream() .anyMatch(task -> !task.isSpeculative() && !task.getNodeLease().getNode().isDone()); - // do not start a speculative stage if there is non-speculative work still to be done. - boolean canScheduleSpeculative = !nonSpeculativeTasksInQueue && !nonSpeculativeTasksWaitingForNode; + // Do not start a speculative stage if there is non-speculative work still to be done. + // Do not start a speculative stage after the partition count has been changed at runtime, as when we estimate + // by progress, repartition tasks will produce very uneven output for different output partitions, which + // will result in very bad task bin-packing results; also the fact that runtime adaptive partitioning + // happened already suggests that there is plenty of work ahead. + boolean canScheduleSpeculative = !nonSpeculativeTasksInQueue && !nonSpeculativeTasksWaitingForNode && !runtimeAdaptivePartitioningApplied; boolean speculative = false; int finishedSourcesCount = 0; int estimatedByProgressSourcesCount = 0; diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java index 5344b2207cc3..05b01a3305a3 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningScheme.java @@ -94,4 +94,13 @@ public Optional<List<InternalNode>> getPartitionToNodeMap() { return partitionToNodeMap; } + + public FaultTolerantPartitioningScheme withPartitionCount(int partitionCount) + { + return new FaultTolerantPartitioningScheme( + partitionCount, + this.bucketToPartitionMap, + this.splitToBucketFunction, + this.partitionToNodeMap); + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java index e3b2c398d364..3a7221967641 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FaultTolerantPartitioningSchemeFactory.java @@ -60,6 +60,12 @@ public FaultTolerantPartitioningScheme get(PartitioningHandle handle, Optional createSourcePartitionToTaskPartition( // adjust targetPartitionSizeInBytes based on total input bytes if (targetMaxTaskCount != Integer.MAX_VALUE || targetMinTaskCount != 0) { - long totalBytes = 0; - for (int partitionId = 0; partitionId < partitionCount; partitionId++) { - totalBytes += mergedEstimate.getPartitionSizeInBytes(partitionId); - } + long totalBytes = mergedEstimate.getTotalSizeInBytes(); if (totalBytes / targetPartitionSizeInBytes > targetMaxTaskCount) { // targetMaxTaskCount is only used to adjust targetPartitionSizeInBytes to avoid excessive number diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java index 6de375d802e9..2ee37ba886c9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java @@ -284,4 +284,49 @@ public String toString() .add("outputPartitioningScheme", outputPartitioningScheme) .toString(); } + + public PlanFragment 
withPartitionCount(Optional partitionCount) + { + return new PlanFragment( + this.id, + this.root, + this.symbols, + this.partitioning, + partitionCount, + this.partitionedSources, + this.outputPartitioningScheme, + this.statsAndCosts, + this.activeCatalogs, + this.jsonRepresentation); + } + + public PlanFragment withOutputPartitioningScheme(PartitioningScheme outputPartitioningScheme) + { + return new PlanFragment( + this.id, + this.root, + this.symbols, + this.partitioning, + this.partitionCount, + this.partitionedSources, + outputPartitioningScheme, + this.statsAndCosts, + this.activeCatalogs, + this.jsonRepresentation); + } + + public PlanFragment withRoot(PlanNode root) + { + return new PlanFragment( + this.id, + root, + this.symbols, + this.partitioning, + this.partitionCount, + this.partitionedSources, + this.outputPartitioningScheme, + this.statsAndCosts, + this.activeCatalogs, + this.jsonRepresentation); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmentIdAllocator.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmentIdAllocator.java new file mode 100644 index 000000000000..66baf1b7f702 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmentIdAllocator.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner; + +import io.trino.sql.planner.plan.PlanFragmentId; + +public class PlanFragmentIdAllocator +{ + private int nextId; + + public PlanFragmentIdAllocator(int startId) + { + this.nextId = startId; + } + + public PlanFragmentId getNextId() + { + return new PlanFragmentId(Integer.toString(nextId++)); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index bdcd5ff9204f..6df01b83a1dd 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -214,7 +214,7 @@ private static class Fragmenter private final TypeProvider types; private final StatsAndCosts statsAndCosts; private final List activeCatalogs; - private int nextFragmentId = ROOT_FRAGMENT_ID + 1; + private final PlanFragmentIdAllocator idAllocator = new PlanFragmentIdAllocator(ROOT_FRAGMENT_ID + 1); public Fragmenter(Session session, Metadata metadata, FunctionManager functionManager, TypeProvider types, StatsAndCosts statsAndCosts, List activeCatalogs) { @@ -231,11 +231,6 @@ public SubPlan buildRootFragment(PlanNode root, FragmentProperties properties) return buildFragment(root, properties, new PlanFragmentId(String.valueOf(ROOT_FRAGMENT_ID))); } - private PlanFragmentId nextFragmentId() - { - return new PlanFragmentId(String.valueOf(nextFragmentId++)); - } - private SubPlan buildFragment(PlanNode root, FragmentProperties properties, PlanFragmentId fragmentId) { Set dependencies = SymbolsExtractor.extractOutputSymbols(root); @@ -433,7 +428,7 @@ else if (exchange.getType() == ExchangeNode.Type.REPARTITION) { private SubPlan buildSubPlan(PlanNode node, FragmentProperties properties, RewriteContext context) { - PlanFragmentId planFragmentId = nextFragmentId(); + PlanFragmentId planFragmentId = idAllocator.getNextId(); PlanNode child = context.rewrite(node, properties); return buildFragment(child, properties, planFragmentId); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanNodeIdAllocator.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanNodeIdAllocator.java index 8dc20992d65d..330e0e60a44f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanNodeIdAllocator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanNodeIdAllocator.java @@ -19,6 +19,16 @@ public class PlanNodeIdAllocator { private int nextId; + public PlanNodeIdAllocator() + { + this(0); + } + + public PlanNodeIdAllocator(int startId) + { + this.nextId = startId; + } + public PlanNodeId getNextId() { return new PlanNodeId(Integer.toString(nextId++)); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java b/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java new file mode 100644 index 000000000000..099c825fb9a3 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RuntimeAdaptivePartitioningRewriter.java @@ -0,0 +1,217 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.graph.Traverser; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.sql.planner.plan.SimplePlanRewriter; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterators.getOnlyElement; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_HASH_DISTRIBUTION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; +import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith; +import static java.util.Objects.requireNonNull; + +public final class RuntimeAdaptivePartitioningRewriter +{ + private RuntimeAdaptivePartitioningRewriter() {} + + public static SubPlan overridePartitionCountRecursively( + SubPlan subPlan, + int oldPartitionCount, + int newPartitionCount, + PlanFragmentIdAllocator planFragmentIdAllocator, + PlanNodeIdAllocator planNodeIdAllocator, + Set<PlanFragmentId> startedFragments) + { + PlanFragment fragment = subPlan.getFragment(); + if (startedFragments.contains(fragment.getId())) { + // already started, nothing to change for subPlan and its descendants + return subPlan; + } + + PartitioningScheme outputPartitioningScheme = fragment.getOutputPartitioningScheme(); + if (outputPartitioningScheme.getPartitioning().getHandle().equals(FIXED_BROADCAST_DISTRIBUTION)) { + // the result of the subtree will be broadcast, so there is no need to change the partition count for the subtree, + // as the planner only broadcasts fragment output if it sees that the input data is small or the filter ratio is high + return subPlan; + } + if (producesHashPartitionedOutput(fragment)) { + fragment = fragment.withOutputPartitioningScheme(outputPartitioningScheme.withPartitionCount(Optional.of(newPartitionCount))); + } + + if (consumesHashPartitionedInput(fragment)) { + fragment = fragment.withPartitionCount(Optional.of(newPartitionCount)); + } + else { + // no input partitioning, so no need to insert extra exchanges to the sources + return new SubPlan( + fragment, + subPlan.getChildren().stream() + .map(child -> overridePartitionCountRecursively( + child, + oldPartitionCount, + newPartitionCount, + planFragmentIdAllocator, + planNodeIdAllocator, + startedFragments)) + .collect(toImmutableList())); + } + + // insert extra exchanges to sources + ImmutableList.Builder<SubPlan> newSources = ImmutableList.builder(); + ImmutableMap.Builder<PlanFragmentId, PlanFragmentId> runtimeAdaptivePlanFragmentIdMapping = ImmutableMap.builder(); + for (SubPlan source : subPlan.getChildren()) { + PlanFragment sourceFragment = source.getFragment(); + RemoteSourceNode 
sourceRemoteSourceNode = getOnlyElement(fragment.getRemoteSourceNodes().stream() .filter(remoteSourceNode -> remoteSourceNode.getSourceFragmentIds().contains(sourceFragment.getId())) .iterator()); + requireNonNull(sourceRemoteSourceNode, "sourceRemoteSourceNode is null"); + if (sourceRemoteSourceNode.getExchangeType() == REPLICATE) { + // since the exchange type is REPLICATE, there is likewise no need to change the partition count for the subtree, as the + // planner only broadcasts fragment output if it sees that the input data is small or the filter ratio is high + newSources.add(source); + continue; + } + if (!startedFragments.contains(sourceFragment.getId())) { + // source not started yet, so no need to insert an extra exchange in front of it + newSources.add(overridePartitionCountRecursively( + source, + oldPartitionCount, + newPartitionCount, + planFragmentIdAllocator, + planNodeIdAllocator, + startedFragments)); + runtimeAdaptivePlanFragmentIdMapping.put(sourceFragment.getId(), sourceFragment.getId()); + continue; + } + RemoteSourceNode runtimeAdaptiveRemoteSourceNode = new RemoteSourceNode( + planNodeIdAllocator.getNextId(), + sourceFragment.getId(), + sourceFragment.getOutputPartitioningScheme().getOutputLayout(), + sourceRemoteSourceNode.getOrderingScheme(), + sourceRemoteSourceNode.getExchangeType(), + sourceRemoteSourceNode.getRetryPolicy()); + PlanFragment runtimeAdaptivePlanFragment = new PlanFragment( + planFragmentIdAllocator.getNextId(), + runtimeAdaptiveRemoteSourceNode, + sourceFragment.getSymbols(), + FIXED_HASH_DISTRIBUTION, + Optional.of(oldPartitionCount), + ImmutableList.of(), // partitioned sources will be empty as the fragment will only read from `runtimeAdaptiveRemoteSourceNode` + sourceFragment.getOutputPartitioningScheme().withPartitionCount(Optional.of(newPartitionCount)), + sourceFragment.getStatsAndCosts(), + sourceFragment.getActiveCatalogs(), + sourceFragment.getJsonRepresentation()); + SubPlan newSource = new SubPlan( + runtimeAdaptivePlanFragment, + ImmutableList.of(overridePartitionCountRecursively( + source, + oldPartitionCount, + newPartitionCount, + planFragmentIdAllocator, + planNodeIdAllocator, + startedFragments))); + newSources.add(newSource); + runtimeAdaptivePlanFragmentIdMapping.put(sourceFragment.getId(), runtimeAdaptivePlanFragment.getId()); + } + + return new SubPlan( + fragment.withRoot(rewriteWith( + new UpdateRemoteSourceFragmentIdsRewriter(runtimeAdaptivePlanFragmentIdMapping.buildOrThrow()), + fragment.getRoot())), + newSources.build()); + } + + public static boolean consumesHashPartitionedInput(PlanFragment fragment) + { + return isPartitioned(fragment.getPartitioning()); + } + + public static boolean producesHashPartitionedOutput(PlanFragment fragment) + { + return isPartitioned(fragment.getOutputPartitioningScheme().getPartitioning().getHandle()); + } + + public static int getMaxPlanFragmentId(List<SubPlan> subPlans) + { + return subPlans.stream() + .map(SubPlan::getFragment) + .map(PlanFragment::getId) + .mapToInt(fragmentId -> Integer.parseInt(fragmentId.toString())) + .max() + .orElseThrow(); + } + + public static int getMaxPlanId(List<SubPlan> subPlans) + { + return subPlans.stream() + .map(SubPlan::getFragment) + .map(PlanFragment::getRoot) + .mapToInt(root -> traverse(root) + .map(PlanNode::getId) + .mapToInt(planNodeId -> Integer.parseInt(planNodeId.toString())) + .max() + .orElseThrow()) + .max() + .orElseThrow(); + }
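Since `getMaxPlanFragmentId` and `getMaxPlanId` parse the numeric ids out of the existing plan, callers can seed the allocators one past the current maximum so the rewriter's bridge fragments and new `RemoteSourceNode`s never collide with ids already in use. A minimal sketch of that seeding, mirroring the two lines in `updateStagesPartitioning` above (the `IdAllocatorSeeding` holder is hypothetical; the called APIs are from this change):

```java
import java.util.List;

import io.trino.sql.planner.PlanFragmentIdAllocator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SubPlan;

import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanFragmentId;
import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanId;

final class IdAllocatorSeeding
{
    private IdAllocatorSeeding() {}

    static PlanFragmentIdAllocator seededFragmentIdAllocator(List<SubPlan> planInTopologicalOrder)
    {
        // start one past the largest fragment id so bridge fragments get fresh ids
        return new PlanFragmentIdAllocator(getMaxPlanFragmentId(planInTopologicalOrder) + 1);
    }

    static PlanNodeIdAllocator seededNodeIdAllocator(List<SubPlan> planInTopologicalOrder)
    {
        // start one past the largest plan node id so new RemoteSourceNodes get fresh ids
        return new PlanNodeIdAllocator(getMaxPlanId(planInTopologicalOrder) + 1);
    }
}
```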
+ + private static boolean isPartitioned(PartitioningHandle partitioningHandle) + { + return partitioningHandle.equals(FIXED_HASH_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_HASH_DISTRIBUTION); + } + + private static Stream<PlanNode> traverse(PlanNode node) + { + Iterable<PlanNode> iterable = Traverser.forTree(PlanNode::getSources).depthFirstPreOrder(node); + return StreamSupport.stream(iterable.spliterator(), false); + } + + private static class UpdateRemoteSourceFragmentIdsRewriter + extends SimplePlanRewriter<Void> + { + private final Map<PlanFragmentId, PlanFragmentId> runtimeAdaptivePlanFragmentIdMapping; + + public UpdateRemoteSourceFragmentIdsRewriter(Map<PlanFragmentId, PlanFragmentId> runtimeAdaptivePlanFragmentIdMapping) + { + this.runtimeAdaptivePlanFragmentIdMapping = requireNonNull(runtimeAdaptivePlanFragmentIdMapping, "runtimeAdaptivePlanFragmentIdMapping is null"); + } + + @Override + public PlanNode visitRemoteSource(RemoteSourceNode node, RewriteContext<Void> context) + { + if (node.getExchangeType() == REPLICATE) { + return node; + } + return node.withSourceFragmentIds(node.getSourceFragmentIds().stream() + .map(runtimeAdaptivePlanFragmentIdMapping::get) + .collect(toImmutableList())); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/TopologicalOrderSubPlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/TopologicalOrderSubPlanVisitor.java index 03713a0068ef..ab5dc8d1af03 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/TopologicalOrderSubPlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/TopologicalOrderSubPlanVisitor.java @@ -26,7 +26,7 @@ import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toMap; -public class TopologicalOrderSubPlanVisitor +public final class TopologicalOrderSubPlanVisitor { private TopologicalOrderSubPlanVisitor() {} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java index 07f89cb8ef5e..785b78274239 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/RemoteSourceNode.java @@ -117,4 +117,15 @@ public PlanNode replaceChildren(List<PlanNode> newChildren) checkArgument(newChildren.isEmpty(), "newChildren is not empty"); return this; } + + public RemoteSourceNode withSourceFragmentIds(List<PlanFragmentId> sourceFragmentIds) + { + return new RemoteSourceNode( + this.getId(), + sourceFragmentIds, + this.getOutputSymbols(), + this.getOrderingScheme(), + this.getExchangeType(), + this.getRetryPolicy()); + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java index 73bd6afd9b4b..34dbcd3615dc 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestQueryManagerConfig.java @@ -28,6 +28,7 @@ import static io.airlift.units.DataSize.Unit.KILOBYTE; import static io.airlift.units.DataSize.Unit.MEGABYTE; import static io.trino.execution.QueryManagerConfig.AVAILABLE_HEAP_MEMORY; +import static io.trino.execution.QueryManagerConfig.FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT_LIMIT; import static java.util.concurrent.TimeUnit.DAYS; import static java.util.concurrent.TimeUnit.HOURS; import static java.util.concurrent.TimeUnit.MINUTES; @@ -98,6 +99,9 @@ public void testDefaults() .setFaultTolerantExecutionMaxPartitionCount(50) .setFaultTolerantExecutionMinPartitionCount(4) .setFaultTolerantExecutionMinPartitionCountForWrite(50) + 
.setFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(false) + .setFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(DataSize.of(12, GIGABYTE)) + .setFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(FAULT_TOLERANT_EXECUTION_MAX_PARTITION_COUNT_LIMIT) .setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(true) .setFaultTolerantExecutionMinSourceStageProgress(0.2) .setFaultTolerantExecutionSmallStageEstimationEnabled(true) @@ -170,6 +174,9 @@ public void testExplicitPropertyMappings() .put("fault-tolerant-execution-max-partition-count", "123") .put("fault-tolerant-execution-min-partition-count", "12") .put("fault-tolerant-execution-min-partition-count-for-write", "99") + .put("fault-tolerant-execution-runtime-adaptive-partitioning-enabled", "true") + .put("fault-tolerant-execution-runtime-adaptive-partitioning-partition-count", "888") + .put("fault-tolerant-execution-runtime-adaptive-partitioning-max-task-size", "18GB") .put("experimental.fault-tolerant-execution-force-preferred-write-partitioning-enabled", "false") .put("fault-tolerant-execution-min-source-stage-progress", "0.3") .put("query.max-writer-task-count", "101") @@ -239,6 +246,9 @@ public void testExplicitPropertyMappings() .setFaultTolerantExecutionMaxPartitionCount(123) .setFaultTolerantExecutionMinPartitionCount(12) .setFaultTolerantExecutionMinPartitionCountForWrite(99) + .setFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(true) + .setFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(888) + .setFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(DataSize.of(18, GIGABYTE)) .setFaultTolerantExecutionForcePreferredWritePartitioningEnabled(false) .setFaultTolerantExecutionMinSourceStageProgress(0.3) .setFaultTolerantExecutionSmallStageEstimationEnabled(false) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java index af67cb4e21e7..195cae74788b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java @@ -135,7 +135,7 @@ public void testPartitionCountInPlanFragment() new PlanFragmentId("4"), Optional.empty(), new PlanFragmentId("5"), Optional.empty()); - assertThat(expectedPartitionCount).isEqualTo(actualPartitionCount.buildOrThrow()); + assertThat(actualPartitionCount.buildOrThrow()).isEqualTo(expectedPartitionCount); } private SubPlan fragment(Plan plan) diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java new file mode 100644 index 000000000000..3c70d12c6858 --- /dev/null +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestOverridePartitionCountRecursively.java @@ -0,0 +1,307 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.faulttolerant; + +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.connector.CoordinatorDynamicCatalogManager; +import io.trino.connector.InMemoryCatalogStore; +import io.trino.connector.LazyCatalogFactory; +import io.trino.execution.QueryManagerConfig; +import io.trino.execution.warnings.WarningCollector; +import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin; +import io.trino.plugin.hive.HiveQueryRunner; +import io.trino.security.AllowAllAccessControl; +import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.Plan; +import io.trino.sql.planner.PlanFragment; +import io.trino.sql.planner.PlanFragmentIdAllocator; +import io.trino.sql.planner.PlanFragmenter; +import io.trino.sql.planner.PlanNodeIdAllocator; +import io.trino.sql.planner.SubPlan; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.FaultTolerantExecutionConnectorTestHelper; +import io.trino.testing.QueryRunner; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount; +import static io.trino.execution.querystats.PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.consumesHashPartitionedInput; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanFragmentId; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.getMaxPlanId; +import static io.trino.sql.planner.RuntimeAdaptivePartitioningRewriter.overridePartitionCountRecursively; +import static io.trino.sql.planner.SystemPartitioningHandle.COORDINATOR_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_BROADCAST_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; +import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; +import static io.trino.sql.planner.TopologicalOrderSubPlanVisitor.sortPlanInTopologicalOrder; +import static io.trino.tpch.TpchTable.getTables; +import static io.trino.transaction.TransactionBuilder.transaction; +import static java.util.Objects.requireNonNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestOverridePartitionCountRecursively + extends AbstractTestQueryFramework +{ + private static final int PARTITION_COUNT_OVERRIDE = 40; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + ImmutableMap.Builder extraPropertiesWithRuntimeAdaptivePartitioning = ImmutableMap.builder(); + extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.getExtraProperties()); + 
extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.enforceRuntimeAdaptivePartitioningProperties()); + + return HiveQueryRunner.builder() + .setExtraProperties(extraPropertiesWithRuntimeAdaptivePartitioning.buildOrThrow()) + .setAdditionalSetup(runner -> { + runner.installPlugin(new FileSystemExchangePlugin()); + runner.loadExchangeManager("filesystem", ImmutableMap.of("exchange.base-directories", + System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager")); + }) + .setInitialTables(getTables()) + .build(); + } + + @Test + public void testCreateTableAs() + { + // already started: 3, 5, 6 + // added fragments: 7, 8, 9 + // 0 0 + // | | + // 1 1 + // | | + // 2 2 + // / \ / \ + // 3* 4 => [7] 4 + // / \ | / \ + // 5* 6* 3* [8] [9] + // | | + // 5* 6* + assertOverridePartitionCountRecursively( + noJoinReordering(), + "CREATE TABLE tmp AS " + + "SELECT n1.* FROM nation n1 " + + "RIGHT JOIN " + + "(SELECT n.nationkey FROM (SELECT * FROM lineitem WHERE suppkey BETWEEN 20 and 30) l LEFT JOIN nation n on l.suppkey = n.nationkey) n2" + + " ON n1.nationkey = n2.nationkey + 1", + ImmutableMap.builder() + .put(0, new FragmentPartitioningInfo(COORDINATOR_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) + .put(1, new FragmentPartitioningInfo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) + .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.empty(), SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty())) + .put(3, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .put(4, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .buildOrThrow(), + ImmutableMap.builder() + .put(0, new FragmentPartitioningInfo(COORDINATOR_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) + .put(1, new FragmentPartitioningInfo(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) + .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION, Optional.empty())) + .put(3, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .put(4, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.empty())) + .put(7, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(5), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(8, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(5), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(9, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(5), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .buildOrThrow(), + ImmutableSet.of(3, 5, 6)); + } + + @Test + public void testSkipBroadcastSubtree() + { + // result of fragment 7 will be broadcast, + // so no runtime adaptive 
partitioning will be applied to its subtree + // already started: 4, 10, 11, 12 + // added fragments: 13 + // 0 0 + // | | + // 1 1 + // / \ / \ + // 2 7 => 2 7 + // / \ | / \ | + // 3 6 8 3 6 8 + // / \ / \ / \ / \ + // 4* 5 9 12* [13] 5 9 12* + // / \ | / \ + // 10* 11* 4* 10* 11* + assertOverridePartitionCountRecursively( + noJoinReordering(), + "SELECT\n" + + " ps.partkey,\n" + + " sum(ps.supplycost * ps.availqty) AS value\n" + + "FROM\n" + + " partsupp ps,\n" + + " supplier s,\n" + + " nation n\n" + + "WHERE\n" + + " ps.suppkey = s.suppkey\n" + + " AND s.nationkey = n.nationkey\n" + + " AND n.name = 'GERMANY'\n" + + "GROUP BY\n" + + " ps.partkey\n" + + "HAVING\n" + + " sum(ps.supplycost * ps.availqty) > (\n" + + " SELECT sum(ps.supplycost * ps.availqty) * 0.0001\n" + + " FROM\n" + + " partsupp ps,\n" + + " supplier s,\n" + + " nation n\n" + + " WHERE\n" + + " ps.suppkey = s.suppkey\n" + + " AND s.nationkey = n.nationkey\n" + + " AND n.name = 'GERMANY'\n" + + " )\n" + + "ORDER BY\n" + + " value DESC", + ImmutableMap.builder() + .put(0, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) + .put(1, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), SINGLE_DISTRIBUTION, Optional.empty())) + .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(3, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(4, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(7, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), FIXED_BROADCAST_DISTRIBUTION, Optional.empty())) + .put(8, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), SINGLE_DISTRIBUTION, Optional.empty())) + .put(9, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(10, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(11, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(12, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .buildOrThrow(), + ImmutableMap.builder() + .put(0, new FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), SINGLE_DISTRIBUTION, Optional.empty())) + .put(1, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), SINGLE_DISTRIBUTION, Optional.empty())) + .put(2, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(3, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(40), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(4, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(5, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(6, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .put(7, new 
FragmentPartitioningInfo(SINGLE_DISTRIBUTION, Optional.empty(), FIXED_BROADCAST_DISTRIBUTION, Optional.empty())) + .put(8, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), SINGLE_DISTRIBUTION, Optional.empty())) + .put(9, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(10, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(11, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(12, new FragmentPartitioningInfo(SOURCE_DISTRIBUTION, Optional.empty(), FIXED_HASH_DISTRIBUTION, Optional.of(4))) + .put(13, new FragmentPartitioningInfo(FIXED_HASH_DISTRIBUTION, Optional.of(4), FIXED_HASH_DISTRIBUTION, Optional.of(40))) + .buildOrThrow(), + ImmutableSet.of(4, 10, 11, 12)); + } + + private void assertOverridePartitionCountRecursively( + Session session, + @Language("SQL") String sql, + Map fragmentPartitioningInfoBefore, + Map fragmentPartitioningInfoAfter, + Set startedFragments) + { + SubPlan plan = getSubPlan(session, sql); + List planInTopologicalOrder = sortPlanInTopologicalOrder(plan); + assertThat(planInTopologicalOrder).hasSize(fragmentPartitioningInfoBefore.size()); + for (SubPlan subPlan : planInTopologicalOrder) { + PlanFragment fragment = subPlan.getFragment(); + int fragmentIdAsInt = Integer.parseInt(fragment.getId().toString()); + FragmentPartitioningInfo fragmentPartitioningInfo = fragmentPartitioningInfoBefore.get(fragmentIdAsInt); + assertEquals(fragment.getPartitionCount(), fragmentPartitioningInfo.inputPartitionCount()); + assertEquals(fragment.getPartitioning(), fragmentPartitioningInfo.inputPartitioning()); + assertEquals(fragment.getOutputPartitioningScheme().getPartitionCount(), fragmentPartitioningInfo.outputPartitionCount()); + assertEquals(fragment.getOutputPartitioningScheme().getPartitioning().getHandle(), fragmentPartitioningInfo.outputPartitioning()); + } + + PlanFragmentIdAllocator planFragmentIdAllocator = new PlanFragmentIdAllocator(getMaxPlanFragmentId(planInTopologicalOrder) + 1); + PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(getMaxPlanId(planInTopologicalOrder) + 1); + int oldPartitionCount = planInTopologicalOrder.stream() + .mapToInt(subPlan -> { + PlanFragment fragment = subPlan.getFragment(); + if (consumesHashPartitionedInput(fragment)) { + return fragment.getPartitionCount().orElse(getFaultTolerantExecutionMaxPartitionCount(session)); + } + else { + return 0; + } + }) + .max() + .orElseThrow(); + assertTrue(oldPartitionCount > 0); + + SubPlan newPlan = overridePartitionCountRecursively( + plan, + oldPartitionCount, + PARTITION_COUNT_OVERRIDE, + planFragmentIdAllocator, + planNodeIdAllocator, + startedFragments.stream().map(fragmentIdAsInt -> new PlanFragmentId(String.valueOf(fragmentIdAsInt))).collect(toImmutableSet())); + planInTopologicalOrder = sortPlanInTopologicalOrder(newPlan); + assertThat(planInTopologicalOrder).hasSize(fragmentPartitioningInfoAfter.size()); + for (SubPlan subPlan : planInTopologicalOrder) { + PlanFragment fragment = subPlan.getFragment(); + int fragmentIdAsInt = Integer.parseInt(fragment.getId().toString()); + FragmentPartitioningInfo fragmentPartitioningInfo = fragmentPartitioningInfoAfter.get(fragmentIdAsInt); + assertEquals(fragment.getPartitionCount(), fragmentPartitioningInfo.inputPartitionCount()); + assertEquals(fragment.getPartitioning(), fragmentPartitioningInfo.inputPartitioning()); 
+            assertEquals(fragment.getOutputPartitioningScheme().getPartitionCount(), fragmentPartitioningInfo.outputPartitionCount());
+            assertEquals(fragment.getOutputPartitioningScheme().getPartitioning().getHandle(), fragmentPartitioningInfo.outputPartitioning());
+        }
+    }
+
+    private SubPlan getSubPlan(Session session, @Language("SQL") String sql)
+    {
+        QueryRunner queryRunner = getDistributedQueryRunner();
+        Plan plan = queryRunner.createPlan(session, sql, WarningCollector.NOOP, createPlanOptimizersStatsCollector());
+        return transaction(queryRunner.getTransactionManager(), new AllowAllAccessControl())
+                .singleStatement()
+                .execute(session, transactionSession -> {
+                    // metadata.getCatalogHandle() registers the catalog for the transaction
+                    transactionSession.getCatalog().ifPresent(catalog -> queryRunner.getMetadata().getCatalogHandle(transactionSession, catalog));
+                    return new PlanFragmenter(
+                            queryRunner.getMetadata(),
+                            queryRunner.getFunctionManager(),
+                            queryRunner.getTransactionManager(),
+                            new CoordinatorDynamicCatalogManager(new InMemoryCatalogStore(), new LazyCatalogFactory(), directExecutor()),
+                            new QueryManagerConfig()).createSubPlans(transactionSession, plan, false, WarningCollector.NOOP);
+                });
+    }
+
+    private record FragmentPartitioningInfo(
+            PartitioningHandle inputPartitioning,
+            Optional<Integer> inputPartitionCount,
+            PartitioningHandle outputPartitioning,
+            Optional<Integer> outputPartitionCount)
+    {
+        FragmentPartitioningInfo {
+            requireNonNull(inputPartitioning, "inputPartitioning is null");
+            requireNonNull(inputPartitionCount, "inputPartitionCount is null");
+            requireNonNull(outputPartitioning, "outputPartitioning is null");
+            requireNonNull(outputPartitionCount, "outputPartitionCount is null");
+        }
+    }
+}
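End to end, the behavior these plan-level assertions exercise is driven by the three new session properties. A minimal sketch of toggling them per query (illustrative, not part of this patch; assumes io.trino.testing.TestingSession is available, as it is in Trino's test modules):

    import io.trino.Session;

    import static io.trino.testing.TestingSession.testSessionBuilder;

    // Per-query overrides mirroring the new system properties; the string
    // values are parsed and checked by the validators registered with each
    // property (partition count must stay within the configured limit).
    Session session = testSessionBuilder()
            .setSystemProperty("fault_tolerant_execution_runtime_adaptive_partitioning_enabled", "true")
            .setSystemProperty("fault_tolerant_execution_runtime_adaptive_partitioning_partition_count", "40")
            .setSystemProperty("fault_tolerant_execution_runtime_adaptive_partitioning_max_task_size", "12GB")
            .build();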
diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionAggregations.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionAggregations.java
new file mode 100644
index 000000000000..e09a4ddff0c3
--- /dev/null
+++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionAggregations.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.faulttolerant.hive;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
+import io.trino.plugin.hive.HiveQueryRunner;
+import io.trino.testing.AbstractTestFaultTolerantExecutionAggregations;
+import io.trino.testing.FaultTolerantExecutionConnectorTestHelper;
+import io.trino.testing.QueryRunner;
+
+import java.util.Map;
+
+import static io.trino.tpch.TpchTable.getTables;
+
+public class TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionAggregations
+        extends AbstractTestFaultTolerantExecutionAggregations
+{
+    @Override
+    protected QueryRunner createQueryRunner(Map<String, String> extraProperties)
+            throws Exception
+    {
+        ImmutableMap.Builder<String, String> extraPropertiesWithRuntimeAdaptivePartitioning = ImmutableMap.builder();
+        extraPropertiesWithRuntimeAdaptivePartitioning.putAll(extraProperties);
+        extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.enforceRuntimeAdaptivePartitioningProperties());
+
+        return HiveQueryRunner.builder()
+                .setExtraProperties(extraPropertiesWithRuntimeAdaptivePartitioning.buildOrThrow())
+                .setAdditionalSetup(runner -> {
+                    runner.installPlugin(new FileSystemExchangePlugin());
+                    runner.loadExchangeManager("filesystem", ImmutableMap.of("exchange.base-directories",
+                            System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager"));
+                })
+                .setInitialTables(getTables())
+                .build();
+    }
+}
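The three new Hive test classes share the same wiring: merge the suite's extra properties with the runtime adaptive partitioning overrides, install the filesystem exchange, and load the TPC-H tables. A hypothetical helper (illustrative only, not in this patch) could consolidate the repeated property merging next to enforceRuntimeAdaptivePartitioningProperties:

    // Hypothetical addition to FaultTolerantExecutionConnectorTestHelper:
    // merge a suite's extra properties with the adaptive partitioning overrides.
    public static Map<String, String> withRuntimeAdaptivePartitioning(Map<String, String> extraProperties)
    {
        return ImmutableMap.<String, String>builder()
                .putAll(extraProperties)
                .putAll(enforceRuntimeAdaptivePartitioningProperties())
                .buildOrThrow();
    }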
diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionJoinQueries.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionJoinQueries.java
new file mode 100644
index 000000000000..b87a708dbced
--- /dev/null
+++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionJoinQueries.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.faulttolerant.hive;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.execution.DynamicFilterConfig;
+import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
+import io.trino.plugin.hive.HiveQueryRunner;
+import io.trino.testing.AbstractTestFaultTolerantExecutionJoinQueries;
+import io.trino.testing.FaultTolerantExecutionConnectorTestHelper;
+import io.trino.testing.QueryRunner;
+import org.testng.annotations.Test;
+
+import java.util.Map;
+
+import static com.google.common.base.Verify.verify;
+import static io.trino.tpch.TpchTable.getTables;
+
+public class TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionJoinQueries
+        extends AbstractTestFaultTolerantExecutionJoinQueries
+{
+    @Override
+    protected QueryRunner createQueryRunner(Map<String, String> extraProperties)
+            throws Exception
+    {
+        ImmutableMap.Builder<String, String> extraPropertiesWithRuntimeAdaptivePartitioning = ImmutableMap.builder();
+        extraPropertiesWithRuntimeAdaptivePartitioning.putAll(extraProperties);
+        extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.enforceRuntimeAdaptivePartitioningProperties());
+
+        verify(new DynamicFilterConfig().isEnableDynamicFiltering(), "this class assumes dynamic filtering is enabled by default");
+        return HiveQueryRunner.builder()
+                .setExtraProperties(extraPropertiesWithRuntimeAdaptivePartitioning.buildOrThrow())
+                .setAdditionalSetup(runner -> {
+                    runner.installPlugin(new FileSystemExchangePlugin());
+                    runner.loadExchangeManager("filesystem", ImmutableMap.of("exchange.base-directories",
+                            System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager"));
+                })
+                .setInitialTables(getTables())
+                .addHiveProperty("hive.dynamic-filtering.wait-timeout", "1h")
+                .build();
+    }
+
+    @Test
+    public void verifyDynamicFilteringEnabled()
+    {
+        assertQuery(
+                "SHOW SESSION LIKE 'enable_dynamic_filtering'",
+                "VALUES ('enable_dynamic_filtering', 'true', 'true', 'boolean', 'Enable dynamic filtering')");
+    }
+}
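The string keys enforced by these tests bind to QueryManagerConfig through the @Config annotations added earlier in this diff. For reference, a sketch of the equivalent programmatic configuration, using this patch's setters with the 12GB production-oriented default rather than the 1B test value:

    import io.airlift.units.DataSize;
    import io.trino.execution.QueryManagerConfig;

    import static io.airlift.units.DataSize.Unit.GIGABYTE;

    // Programmatic equivalent of the fault-tolerant-execution-runtime-adaptive-*
    // config keys; each setter corresponds to one key via its @Config annotation.
    QueryManagerConfig config = new QueryManagerConfig()
            .setFaultTolerantExecutionRuntimeAdaptivePartitioningEnabled(true)
            .setFaultTolerantExecutionRuntimeAdaptivePartitioningPartitionCount(40)
            .setFaultTolerantExecutionRuntimeAdaptivePartitioningMaxTaskSize(DataSize.of(12, GIGABYTE));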
diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionWindowQueries.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionWindowQueries.java
new file mode 100644
index 000000000000..c3ef5a24f4b1
--- /dev/null
+++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/hive/TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionWindowQueries.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.faulttolerant.hive;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
+import io.trino.plugin.hive.HiveQueryRunner;
+import io.trino.testing.AbstractTestFaultTolerantExecutionWindowQueries;
+import io.trino.testing.FaultTolerantExecutionConnectorTestHelper;
+import io.trino.testing.QueryRunner;
+
+import java.util.Map;
+
+import static io.trino.tpch.TpchTable.getTables;
+
+public class TestHiveRuntimeAdaptivePartitioningFaultTolerantExecutionWindowQueries
+        extends AbstractTestFaultTolerantExecutionWindowQueries
+{
+    @Override
+    protected QueryRunner createQueryRunner(Map<String, String> extraProperties)
+            throws Exception
+    {
+        ImmutableMap.Builder<String, String> extraPropertiesWithRuntimeAdaptivePartitioning = ImmutableMap.builder();
+        extraPropertiesWithRuntimeAdaptivePartitioning.putAll(extraProperties);
+        extraPropertiesWithRuntimeAdaptivePartitioning.putAll(FaultTolerantExecutionConnectorTestHelper.enforceRuntimeAdaptivePartitioningProperties());
+
+        return HiveQueryRunner.builder()
+                .setExtraProperties(extraPropertiesWithRuntimeAdaptivePartitioning.buildOrThrow())
+                .setAdditionalSetup(runner -> {
+                    runner.installPlugin(new FileSystemExchangePlugin());
+                    runner.loadExchangeManager("filesystem", ImmutableMap.of("exchange.base-directories",
+                            System.getProperty("java.io.tmpdir") + "/trino-local-file-system-exchange-manager"));
+                })
+                .setInitialTables(getTables())
+                .build();
+    }
+}
diff --git a/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java b/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java
index 75754a98ce71..96f2adffd60b 100644
--- a/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java
+++ b/testing/trino-testing/src/main/java/io/trino/testing/FaultTolerantExecutionConnectorTestHelper.java
@@ -49,4 +49,14 @@ public static Map<String, String> getExtraProperties()
                 .put("query.schedule-split-batch-size", "2")
                 .buildOrThrow();
     }
+
+    public static Map<String, String> enforceRuntimeAdaptivePartitioningProperties()
+    {
+        return ImmutableMap.<String, String>builder()
+                .put("fault-tolerant-execution-runtime-adaptive-partitioning-enabled", "true")
+                .put("fault-tolerant-execution-runtime-adaptive-partitioning-partition-count", "40")
+                // to ensure runtime adaptive partitioning is triggered
+                .put("fault-tolerant-execution-runtime-adaptive-partitioning-max-task-size", "1B")
+                .buildOrThrow();
+    }
 }
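The "1B" max task size guarantees the feature fires in these suites: any stage whose estimated input averages more than one byte per task crosses the threshold, so every hash-partitioned stage is repartitioned to the configured count of 40. The trigger reduces to comparing a stage's average task input against the limit, roughly as below (illustrative names, not the actual scheduler API):

    // Sketch of the trigger: switch to the runtime adaptive partition count
    // when a stage's average task input would exceed the configured maximum.
    static boolean exceedsMaxTaskSize(long estimatedStageInputBytes, int partitionCount, DataSize maxTaskSize)
    {
        return estimatedStageInputBytes / partitionCount > maxTaskSize.toBytes();
    }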