57 changes: 53 additions & 4 deletions core/trino-main/src/main/java/io/trino/execution/SqlStage.java
@@ -27,18 +27,24 @@
import io.trino.metadata.InternalNode;
import io.trino.metadata.Split;
import io.trino.spi.metrics.Metrics;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanFragment;
import io.trino.sql.planner.plan.DynamicFilterId;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.SimplePlanRewriter;

import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.base.Preconditions.checkArgument;
@@ -65,6 +71,7 @@ public final class SqlStage
private final boolean summarizeTaskInfo;

private final Set<DynamicFilterId> outboundDynamicFilterIds;
private final LocalExchangeBucketCountProvider bucketCountProvider;

private final Map<TaskId, RemoteTask> tasks = new ConcurrentHashMap<>();
@GuardedBy("this")
@@ -85,7 +92,8 @@ public static SqlStage createSqlStage(
Executor stateMachineExecutor,
Tracer tracer,
Span schedulerSpan,
SplitSchedulerStats schedulerStats)
SplitSchedulerStats schedulerStats,
LocalExchangeBucketCountProvider bucketCountProvider)
{
requireNonNull(stageId, "stageId is null");
requireNonNull(fragment, "fragment is null");
@@ -112,7 +120,8 @@ public static SqlStage createSqlStage(
stateMachine,
remoteTaskFactory,
nodeTaskMap,
summarizeTaskInfo);
summarizeTaskInfo,
bucketCountProvider);
sqlStage.initialize();
return sqlStage;
}
@@ -122,13 +131,15 @@ private SqlStage(
StageStateMachine stateMachine,
RemoteTaskFactory remoteTaskFactory,
NodeTaskMap nodeTaskMap,
boolean summarizeTaskInfo)
boolean summarizeTaskInfo,
LocalExchangeBucketCountProvider bucketCountProvider)
{
this.session = requireNonNull(session, "session is null");
this.stateMachine = stateMachine;
this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null");
this.summarizeTaskInfo = summarizeTaskInfo;
this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");

this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment());
}
@@ -243,6 +254,7 @@ public synchronized Optional<RemoteTask> createTask(
int partition,
int attempt,
Optional<int[]> bucketToPartition,
OptionalInt skewedBucketCount,
OutputBuffers outputBuffers,
Multimap<PlanNodeId, Split> splits,
Set<PlanNodeId> noMoreSplits,
@@ -257,13 +269,21 @@

stateMachine.transitionToScheduling();

// set partitioning information on coordinator side
PlanFragment fragment = stateMachine.getFragment();
fragment = fragment.withOutputPartitioning(bucketToPartition, skewedBucketCount);
PlanNode newRoot = fragment.getRoot();
LocalExchangePartitionRewriter rewriter = new LocalExchangePartitionRewriter(handle -> bucketCountProvider.getBucketCount(session, handle));
newRoot = SimplePlanRewriter.rewriteWith(rewriter, newRoot);
fragment = fragment.withRoot(newRoot);

RemoteTask task = remoteTaskFactory.createRemoteTask(
session,
stateMachine.getStageSpan(),
taskId,
node,
speculative,
stateMachine.getFragment().withBucketToPartition(bucketToPartition),
fragment,
splits,
outputBuffers,
nodeTaskMap.createPartitionedSplitCountTracker(node, taskId),
@@ -375,4 +395,33 @@ public synchronized void stateChanged(TaskStatus taskStatus)
}
}
}

public interface LocalExchangeBucketCountProvider
{
Optional<Integer> getBucketCount(Session session, PartitioningHandle partitioning);
}

private static final class LocalExchangePartitionRewriter
extends SimplePlanRewriter<Void>
{
private final Function<PartitioningHandle, Optional<Integer>> bucketCountProvider;

public LocalExchangePartitionRewriter(Function<PartitioningHandle, Optional<Integer>> bucketCountProvider)
{
this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");
}

@Override
public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
{
return new ExchangeNode(
node.getId(),
node.getType(),
node.getScope(),
node.getPartitioningScheme().withBucketCount(bucketCountProvider.apply(node.getPartitioningScheme().getPartitioning().getHandle())),
node.getSources(),
node.getInputs(),
node.getOrderingScheme());
}
}
}
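The new hook is a single-method interface, so the scheduler can satisfy it with a method reference (PipelinedQueryScheduler below passes nodePartitioningManager::getBucketCount). A minimal standalone sketch of the pattern follows, using stand-in record types because the real io.trino.Session and io.trino.sql.planner.PartitioningHandle classes are not reproduced here; the catalog map and handle names are invented for illustration only.

import java.util.Map;
import java.util.Optional;

public class BucketCountProviderSketch
{
    // Stand-ins for io.trino.Session and io.trino.sql.planner.PartitioningHandle.
    record StubSession(String queryId) {}
    record StubPartitioningHandle(String name) {}

    // Mirrors SqlStage.LocalExchangeBucketCountProvider: resolve an optional
    // local-exchange bucket count for a partitioning handle under a session.
    interface LocalExchangeBucketCountProvider
    {
        Optional<Integer> getBucketCount(StubSession session, StubPartitioningHandle partitioning);
    }

    public static void main(String[] args)
    {
        // Pretend catalog metadata: only the bucketed handle has a known bucket count.
        Map<String, Integer> bucketCounts = Map.of("hive:bucketed_by_day", 32);
        LocalExchangeBucketCountProvider provider =
                (session, handle) -> Optional.ofNullable(bucketCounts.get(handle.name()));

        // SqlStage.createTask() applies such a provider to every ExchangeNode in the
        // fragment via LocalExchangePartitionRewriter before the remote task is created.
        System.out.println(provider.getBucketCount(new StubSession("q1"), new StubPartitioningHandle("hive:bucketed_by_day")));
        System.out.println(provider.getBucketCount(new StubSession("q1"), new StubPartitioningHandle("round_robin")));
    }
}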
@@ -84,6 +84,7 @@ public SqlTaskExecution create(
taskContext,
fragment.getRoot(),
fragment.getOutputPartitioningScheme(),
fragment.getOutputSkewedBucketCount(),
fragment.getPartitionedSources(),
outputBuffer);
}
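The only change in this file is threading the fragment's skewed bucket count into SqlTaskExecution. The value is an OptionalInt: empty means the stage does no skew rebalancing, while a present value carries the bucket count chosen on the coordinator by SkewedPartitionRebalancer.getSkewedBucketCount. A small sketch of that convention; the describe helper is made up for the example and is not Trino API.

import java.util.OptionalInt;

public class SkewedBucketCountSketch
{
    // Illustrative helper: show how an OptionalInt encodes
    // "no skew handling" versus "rebalance across N skewed buckets".
    static String describe(OptionalInt skewedBucketCount)
    {
        return skewedBucketCount.isPresent()
                ? "skew-aware output partitioning over " + skewedBucketCount.getAsInt() + " buckets"
                : "plain output partitioning (no skew rebalancing)";
    }

    public static void main(String[] args)
    {
        System.out.println(describe(OptionalInt.empty()));
        System.out.println(describe(OptionalInt.of(48)));
    }
}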
@@ -63,6 +63,7 @@
import io.trino.spi.connector.CatalogHandle;
import io.trino.split.SplitSource;
import io.trino.sql.planner.NodePartitionMap;
import io.trino.sql.planner.NodePartitionMap.BucketToPartition;
import io.trino.sql.planner.NodePartitioningManager;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PlanFragment;
@@ -86,6 +87,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.Set;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
@@ -134,6 +136,7 @@
import static io.trino.execution.scheduler.StageExecution.State.SCHEDULED;
import static io.trino.operator.RetryPolicy.NONE;
import static io.trino.operator.RetryPolicy.QUERY;
import static io.trino.operator.output.SkewedPartitionRebalancer.getSkewedBucketCount;
import static io.trino.spi.ErrorType.EXTERNAL;
import static io.trino.spi.ErrorType.INTERNAL_ERROR;
import static io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY;
@@ -242,7 +245,8 @@ public PipelinedQueryScheduler(
schedulerSpan,
schedulerStats,
plan,
summarizeTaskInfo);
summarizeTaskInfo,
nodePartitioningManager::getBucketCount);

coordinatorStagesScheduler = CoordinatorStagesScheduler.create(
queryStateMachine,
@@ -530,12 +534,12 @@ public void noMoreTasks(PlanFragmentId fragmentId)
*/
private static class CoordinatorStagesScheduler
{
private static final int[] SINGLE_PARTITION = new int[] {0};
private static final BucketToPartition SINGLE_PARTITION = new BucketToPartition(new int[1], false);

private final QueryStateMachine queryStateMachine;
private final NodeScheduler nodeScheduler;
private final Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator;
private final Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator;
private final Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator;
private final TaskLifecycleListener taskLifecycleListener;
private final StageManager stageManager;
private final List<StageExecution> stageExecutions;
@@ -554,7 +558,7 @@ public static CoordinatorStagesScheduler create(
SqlTaskManager coordinatorTaskManager)
{
Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator = createOutputBuffersForStagesConsumedByCoordinator(stageManager);
Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator = createBucketToPartitionForStagesConsumedByCoordinator(stageManager);
Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator = createBucketToPartitionForStagesConsumedByCoordinator(stageManager);

TaskLifecycleListener taskLifecycleListener = new QueryOutputTaskLifecycleListener(queryStateMachine);
// create executions
@@ -566,7 +570,8 @@ public static CoordinatorStagesScheduler create(
taskLifecycleListener,
failureDetector,
executor,
bucketToPartitionForStagesConsumedByCoordinator.get(stage.getFragment().getId()),
bucketToPartitionForStagesConsumedByCoordinator.get(stage.getFragment().getId()).map(BucketToPartition::bucketToPartition),
OptionalInt.empty(),
0);
stageExecutions.add(stageExecution);
taskLifecycleListener = stageExecution.getTaskLifecycleListener();
@@ -612,9 +617,9 @@ private static PipelinedOutputBufferManager createSingleStreamOutputBuffer(SqlSt
return new PartitionedPipelinedOutputBufferManager(partitioningHandle, 1);
}

private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionForStagesConsumedByCoordinator(StageManager stageManager)
private static Map<PlanFragmentId, Optional<BucketToPartition>> createBucketToPartitionForStagesConsumedByCoordinator(StageManager stageManager)
{
ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
ImmutableMap.Builder<PlanFragmentId, Optional<BucketToPartition>> result = ImmutableMap.builder();

SqlStage outputStage = stageManager.getOutputStage();
result.put(outputStage.getFragment().getId(), Optional.of(SINGLE_PARTITION));
@@ -632,7 +637,7 @@ private CoordinatorStagesScheduler(
QueryStateMachine queryStateMachine,
NodeScheduler nodeScheduler,
Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator,
Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator,
Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator,
TaskLifecycleListener taskLifecycleListener,
StageManager stageManager,
List<StageExecution> stageExecutions,
@@ -767,7 +772,7 @@ public Map<PlanFragmentId, PipelinedOutputBufferManager> getOutputBuffersForStag
return outputBuffersForStagesConsumedByCoordinator;
}

public Map<PlanFragmentId, Optional<int[]>> getBucketToPartitionForStagesConsumedByCoordinator()
public Map<PlanFragmentId, Optional<BucketToPartition>> getBucketToPartitionForStagesConsumedByCoordinator()
{
return bucketToPartitionForStagesConsumedByCoordinator;
}
@@ -878,7 +883,7 @@ public static DistributedStagesScheduler create(
partitioning.handle.equals(SCALED_WRITER_HASH_DISTRIBUTION) ? FIXED_HASH_DISTRIBUTION : partitioning.handle,
partitioning.partitionCount));

Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap = createBucketToPartitionMap(
Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionMap = createBucketToPartitionMap(
queryStateMachine.getSession(),
coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(),
stageManager,
@@ -916,13 +921,22 @@ public static DistributedStagesScheduler create(
}

PlanFragment fragment = stage.getFragment();
// TODO partitioning should be locked down during the planning phase
// This is a compromise to compute output partitioning and skew handling in the
// coordinator without having to change the planner code.
Optional<BucketToPartition> bucketToPartition = bucketToPartitionMap.get(fragment.getId());
OptionalInt skewedBucketCount = OptionalInt.empty();
if (bucketToPartition.isPresent()) {
skewedBucketCount = getSkewedBucketCount(queryStateMachine.getSession(), fragment.getOutputPartitioningScheme(), bucketToPartition.get(), nodePartitioningManager);
}
StageExecution stageExecution = createPipelinedStageExecution(
stageManager.get(fragment.getId()),
outputBufferManagers,
taskLifecycleListener,
failureDetector,
executor,
bucketToPartitionMap.get(fragment.getId()),
bucketToPartition.map(BucketToPartition::bucketToPartition),
skewedBucketCount,
attempt);
stageExecutions.put(stage.getStageId(), stageExecution);
}
@@ -960,18 +974,18 @@ public static DistributedStagesScheduler create(
return distributedStagesScheduler;
}

private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(
private static Map<PlanFragmentId, Optional<BucketToPartition>> createBucketToPartitionMap(
Session session,
Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator,
Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator,
StageManager stageManager,
Function<PartitioningKey, NodePartitionMap> partitioningCache)
{
ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
ImmutableMap.Builder<PlanFragmentId, Optional<BucketToPartition>> result = ImmutableMap.builder();
result.putAll(bucketToPartitionForStagesConsumedByCoordinator);
for (SqlStage stage : stageManager.getDistributedStagesInTopologicalOrder()) {
PlanFragment fragment = stage.getFragment();
BucketToPartitionKey bucketToPartitionKey = getKeyForFragment(fragment, session);
Optional<int[]> bucketToPartition = getBucketToPartition(bucketToPartitionKey, partitioningCache);
Optional<BucketToPartition> bucketToPartition = getBucketToPartition(bucketToPartitionKey, partitioningCache);
for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) {
result.put(childStage.getFragment().getId(), bucketToPartition);
}
@@ -994,10 +1008,10 @@ private static BucketToPartitionKey getKeyForFragment(PlanFragment fragment, Ses
return new PartitioningKey(partitioningHandle, partitionCount);
}

private static Optional<int[]> getBucketToPartition(BucketToPartitionKey bucketToPartitionKey, Function<PartitioningKey, NodePartitionMap> partitioningCache)
private static Optional<BucketToPartition> getBucketToPartition(BucketToPartitionKey bucketToPartitionKey, Function<PartitioningKey, NodePartitionMap> partitioningCache)
{
return switch (bucketToPartitionKey) {
case ONE -> Optional.of(new int[1]);
case ONE -> Optional.of(new BucketToPartition(new int[1], false));
case EMPTY -> Optional.empty();
case PartitioningKey key -> {
NodePartitionMap nodePartitionMap = partitioningCache.apply(key);
@@ -1012,7 +1026,7 @@ private static Map<PlanFragmentId, PipelinedOutputBufferManager> createOutputBufferManagers(
private static Map<PlanFragmentId, PipelinedOutputBufferManager> createOutputBufferManagers(
Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator,
StageManager stageManager,
Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap)
Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionMap)
{
ImmutableMap.Builder<PlanFragmentId, PipelinedOutputBufferManager> result = ImmutableMap.builder();
result.putAll(outputBuffersForStagesConsumedByCoordinator);
@@ -1029,9 +1043,9 @@ else if (partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
outputBufferManager = new ScaledPipelinedOutputBufferManager();
}
else {
Optional<int[]> bucketToPartition = bucketToPartitionMap.get(fragmentId);
Optional<BucketToPartition> bucketToPartition = bucketToPartitionMap.get(fragmentId);
checkArgument(bucketToPartition.isPresent(), "bucketToPartition is expected to be present for fragment: %s", fragmentId);
int partitionCount = Ints.max(bucketToPartition.get()) + 1;
int partitionCount = Ints.max(bucketToPartition.get().bucketToPartition()) + 1;
outputBufferManager = new PartitionedPipelinedOutputBufferManager(partitioningHandle, partitionCount);
}
result.put(fragmentId, outputBufferManager);
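The scheduler now passes the bucket-to-partition mapping around as NodePartitionMap.BucketToPartition instead of a bare int[]. Its definition is not part of this diff; the sketch below assumes a record with the bucketToPartition() accessor visible above and a boolean flag whose name is a guess for the second constructor argument in new BucketToPartition(new int[1], false), and it reproduces the partition-count computation from createOutputBufferManagers.

import java.util.Arrays;

public class BucketToPartitionSketch
{
    // Assumed shape of io.trino.sql.planner.NodePartitionMap.BucketToPartition:
    // bucketToPartition() appears in the diff; hasInitialMap is a guessed name.
    record BucketToPartition(int[] bucketToPartition, boolean hasInitialMap)
    {
        int partitionCount()
        {
            // Same computation as createOutputBufferManagers:
            // highest partition id referenced by any bucket, plus one.
            return Arrays.stream(bucketToPartition).max().orElseThrow() + 1;
        }
    }

    public static void main(String[] args)
    {
        // SINGLE_PARTITION above is new BucketToPartition(new int[1], false):
        // one bucket, mapped to partition 0, so the partition count is 1.
        BucketToPartition singlePartition = new BucketToPartition(new int[1], false);
        System.out.println(singlePartition.partitionCount()); // 1

        // Three buckets spread over partitions 0..2.
        BucketToPartition spread = new BucketToPartition(new int[] {0, 1, 2}, true);
        System.out.println(spread.partitionCount()); // 3
    }
}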