diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java
index af0ac72b1d8e..fda974055a00 100644
--- a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java
+++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java
@@ -27,18 +27,24 @@
 import io.trino.metadata.InternalNode;
 import io.trino.metadata.Split;
 import io.trino.spi.metrics.Metrics;
+import io.trino.sql.planner.PartitioningHandle;
 import io.trino.sql.planner.PlanFragment;
 import io.trino.sql.planner.plan.DynamicFilterId;
+import io.trino.sql.planner.plan.ExchangeNode;
+import io.trino.sql.planner.plan.PlanNode;
 import io.trino.sql.planner.plan.PlanNodeId;
+import io.trino.sql.planner.plan.SimplePlanRewriter;
 
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
 import java.util.concurrent.TimeUnit;
+import java.util.function.Function;
 
 import static com.google.common.base.MoreObjects.toStringHelper;
 import static com.google.common.base.Preconditions.checkArgument;
@@ -65,6 +71,7 @@ public final class SqlStage
     private final boolean summarizeTaskInfo;
 
     private final Set<DynamicFilterId> outboundDynamicFilterIds;
+    private final LocalExchangeBucketCountProvider bucketCountProvider;
 
     private final Map<TaskId, RemoteTask> tasks = new ConcurrentHashMap<>();
     @GuardedBy("this")
@@ -85,7 +92,8 @@ public static SqlStage createSqlStage(
             Executor stateMachineExecutor,
             Tracer tracer,
             Span schedulerSpan,
-            SplitSchedulerStats schedulerStats)
+            SplitSchedulerStats schedulerStats,
+            LocalExchangeBucketCountProvider bucketCountProvider)
     {
         requireNonNull(stageId, "stageId is null");
         requireNonNull(fragment, "fragment is null");
@@ -112,7 +120,8 @@ public static SqlStage createSqlStage(
                 stateMachine,
                 remoteTaskFactory,
                 nodeTaskMap,
-                summarizeTaskInfo);
+                summarizeTaskInfo,
+                bucketCountProvider);
         sqlStage.initialize();
         return sqlStage;
     }
@@ -122,13 +131,15 @@ private SqlStage(
             StageStateMachine stateMachine,
             RemoteTaskFactory remoteTaskFactory,
             NodeTaskMap nodeTaskMap,
-            boolean summarizeTaskInfo)
+            boolean summarizeTaskInfo,
+            LocalExchangeBucketCountProvider bucketCountProvider)
     {
         this.session = requireNonNull(session, "session is null");
         this.stateMachine = stateMachine;
         this.remoteTaskFactory = requireNonNull(remoteTaskFactory, "remoteTaskFactory is null");
         this.nodeTaskMap = requireNonNull(nodeTaskMap, "nodeTaskMap is null");
         this.summarizeTaskInfo = summarizeTaskInfo;
+        this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");
         this.outboundDynamicFilterIds = getOutboundDynamicFilters(stateMachine.getFragment());
     }
@@ -243,6 +254,7 @@ public synchronized Optional<RemoteTask> createTask(
             int partition,
             int attempt,
             Optional<int[]> bucketToPartition,
+            OptionalInt skewedBucketCount,
             OutputBuffers outputBuffers,
             Multimap<PlanNodeId, Split> splits,
             Set<PlanNodeId> noMoreSplits,
@@ -257,13 +269,21 @@
         stateMachine.transitionToScheduling();
 
+        // set partitioning information on coordinator side
+        PlanFragment fragment = stateMachine.getFragment();
+        fragment = fragment.withOutputPartitioning(bucketToPartition, skewedBucketCount);
+        PlanNode newRoot = fragment.getRoot();
+        LocalExchangePartitionRewriter rewriter = new LocalExchangePartitionRewriter(handle -> bucketCountProvider.getBucketCount(session, handle));
+        newRoot = SimplePlanRewriter.rewriteWith(rewriter, newRoot);
+        fragment = fragment.withRoot(newRoot);
+
         RemoteTask task = remoteTaskFactory.createRemoteTask(
                 session,
                 stateMachine.getStageSpan(),
                 taskId,
                 node,
                 speculative,
-                stateMachine.getFragment().withBucketToPartition(bucketToPartition),
+                fragment,
                 splits,
                 outputBuffers,
                 nodeTaskMap.createPartitionedSplitCountTracker(node, taskId),
@@ -375,4 +395,33 @@ public synchronized void stateChanged(TaskStatus taskStatus)
             }
         }
     }
+
+    public interface LocalExchangeBucketCountProvider
+    {
+        Optional<Integer> getBucketCount(Session session, PartitioningHandle partitioning);
+    }
+
+    private static final class LocalExchangePartitionRewriter
+            extends SimplePlanRewriter<Void>
+    {
+        private final Function<PartitioningHandle, Optional<Integer>> bucketCountProvider;
+
+        public LocalExchangePartitionRewriter(Function<PartitioningHandle, Optional<Integer>> bucketCountProvider)
+        {
+            this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");
+        }
+
+        @Override
+        public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
+        {
+            return new ExchangeNode(
+                    node.getId(),
+                    node.getType(),
+                    node.getScope(),
+                    node.getPartitioningScheme().withBucketCount(bucketCountProvider.apply(node.getPartitioningScheme().getPartitioning().getHandle())),
+                    node.getSources(),
+                    node.getInputs(),
+                    node.getOrderingScheme());
+        }
+    }
 }
diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java
index b223720047c6..0238d450b3ed 100644
--- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java
+++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java
@@ -84,6 +84,7 @@ public SqlTaskExecution create(
                 taskContext,
                 fragment.getRoot(),
                 fragment.getOutputPartitioningScheme(),
+                fragment.getOutputSkewedBucketCount(),
                 fragment.getPartitionedSources(),
                 outputBuffer);
     }
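The SqlStage change above rewrites each fragment on the coordinator so that every local exchange carries a resolved bucket count before the fragment is serialized to a task. A minimal, self-contained sketch of that rewrite pattern, using simplified stand-ins rather than Trino's actual plan-node and rewriter classes:

    // Illustrative only: toy plan nodes showing a bottom-up rewrite that stamps a
    // bucket count onto every exchange before a fragment is shipped to a worker.
    import java.util.List;
    import java.util.Optional;
    import java.util.function.Function;
    import java.util.stream.Collectors;

    sealed interface Node permits Exchange, Scan {}

    record Scan(String table) implements Node {}

    // bucketCount is empty until the coordinator resolves it from the connector
    record Exchange(String partitioningHandle, Optional<Integer> bucketCount, List<Node> sources) implements Node {}

    final class BucketCountRewriter
    {
        private final Function<String, Optional<Integer>> bucketCountProvider;

        BucketCountRewriter(Function<String, Optional<Integer>> bucketCountProvider)
        {
            this.bucketCountProvider = bucketCountProvider;
        }

        Node rewrite(Node node)
        {
            if (node instanceof Exchange(String handle, Optional<Integer> ignored, List<Node> sources)) {
                // rewrite children first, then rebuild the node with the resolved bucket count
                List<Node> newSources = sources.stream().map(this::rewrite).collect(Collectors.toList());
                return new Exchange(handle, bucketCountProvider.apply(handle), newSources);
            }
            return node;
        }
    }

Rebuilding the exchange node instead of mutating it mirrors the immutable-plan convention that the real SimplePlanRewriter-based code follows.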
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
index a60a1097f9ce..f1564f285a44 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java
@@ -63,6 +63,7 @@
 import io.trino.spi.connector.CatalogHandle;
 import io.trino.split.SplitSource;
 import io.trino.sql.planner.NodePartitionMap;
+import io.trino.sql.planner.NodePartitionMap.BucketToPartition;
 import io.trino.sql.planner.NodePartitioningManager;
 import io.trino.sql.planner.PartitioningHandle;
 import io.trino.sql.planner.PlanFragment;
@@ -86,6 +87,7 @@
 import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
@@ -134,6 +136,7 @@
 import static io.trino.execution.scheduler.StageExecution.State.SCHEDULED;
 import static io.trino.operator.RetryPolicy.NONE;
 import static io.trino.operator.RetryPolicy.QUERY;
+import static io.trino.operator.output.SkewedPartitionRebalancer.getSkewedBucketCount;
 import static io.trino.spi.ErrorType.EXTERNAL;
 import static io.trino.spi.ErrorType.INTERNAL_ERROR;
 import static io.trino.spi.StandardErrorCode.CLUSTER_OUT_OF_MEMORY;
@@ -242,7 +245,8 @@ public PipelinedQueryScheduler(
                 schedulerSpan,
                 schedulerStats,
                 plan,
-                summarizeTaskInfo);
+                summarizeTaskInfo,
+                nodePartitioningManager::getBucketCount);
 
         coordinatorStagesScheduler = CoordinatorStagesScheduler.create(
                 queryStateMachine,
@@ -530,12 +534,12 @@ public void noMoreTasks(PlanFragmentId fragmentId)
      */
     private static class CoordinatorStagesScheduler
     {
-        private static final int[] SINGLE_PARTITION = new int[] {0};
+        private static final BucketToPartition SINGLE_PARTITION = new BucketToPartition(new int[1], false);
 
         private final QueryStateMachine queryStateMachine;
         private final NodeScheduler nodeScheduler;
         private final Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator;
-        private final Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator;
+        private final Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator;
         private final TaskLifecycleListener taskLifecycleListener;
         private final StageManager stageManager;
         private final List<StageExecution> stageExecutions;
@@ -554,7 +558,7 @@ public static CoordinatorStagesScheduler create(
                 SqlTaskManager coordinatorTaskManager)
         {
             Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator = createOutputBuffersForStagesConsumedByCoordinator(stageManager);
-            Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator = createBucketToPartitionForStagesConsumedByCoordinator(stageManager);
+            Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator = createBucketToPartitionForStagesConsumedByCoordinator(stageManager);
             TaskLifecycleListener taskLifecycleListener = new QueryOutputTaskLifecycleListener(queryStateMachine);
 
             // create executions
@@ -566,7 +570,8 @@ public static CoordinatorStagesScheduler create(
                         taskLifecycleListener,
                         failureDetector,
                         executor,
-                        bucketToPartitionForStagesConsumedByCoordinator.get(stage.getFragment().getId()),
+                        bucketToPartitionForStagesConsumedByCoordinator.get(stage.getFragment().getId()).map(BucketToPartition::bucketToPartition),
+                        OptionalInt.empty(),
                         0);
                 stageExecutions.add(stageExecution);
                 taskLifecycleListener = stageExecution.getTaskLifecycleListener();
@@ -612,9 +617,9 @@ private static PipelinedOutputBufferManager createSingleStreamOutputBuffer(SqlStage stage)
             return new PartitionedPipelinedOutputBufferManager(partitioningHandle, 1);
         }
 
-        private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionForStagesConsumedByCoordinator(StageManager stageManager)
+        private static Map<PlanFragmentId, Optional<BucketToPartition>> createBucketToPartitionForStagesConsumedByCoordinator(StageManager stageManager)
         {
-            ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
+            ImmutableMap.Builder<PlanFragmentId, Optional<BucketToPartition>> result = ImmutableMap.builder();
 
             SqlStage outputStage = stageManager.getOutputStage();
             result.put(outputStage.getFragment().getId(), Optional.of(SINGLE_PARTITION));
@@ -632,7 +637,7 @@ private CoordinatorStagesScheduler(
                 QueryStateMachine queryStateMachine,
                 NodeScheduler nodeScheduler,
                 Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator,
-                Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator,
+                Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator,
                 TaskLifecycleListener taskLifecycleListener,
                 StageManager stageManager,
                 List<StageExecution> stageExecutions,
@@ -767,7 +772,7 @@ public Map<PlanFragmentId, PipelinedOutputBufferManager> getOutputBuffersForStagesConsumedByCoordinator()
         {
             return outputBuffersForStagesConsumedByCoordinator;
         }
 
-        public Map<PlanFragmentId, Optional<int[]>> getBucketToPartitionForStagesConsumedByCoordinator()
+        public Map<PlanFragmentId, Optional<BucketToPartition>> getBucketToPartitionForStagesConsumedByCoordinator()
         {
             return bucketToPartitionForStagesConsumedByCoordinator;
         }
@@ -878,7 +883,7 @@ public static DistributedStagesScheduler create(
                             partitioning.handle.equals(SCALED_WRITER_HASH_DISTRIBUTION) ? FIXED_HASH_DISTRIBUTION : partitioning.handle,
                             partitioning.partitionCount));
 
-            Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap = createBucketToPartitionMap(
+            Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionMap = createBucketToPartitionMap(
                     queryStateMachine.getSession(),
                     coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(),
                     stageManager,
@@ -916,13 +921,22 @@ public static DistributedStagesScheduler create(
                 }
 
                 PlanFragment fragment = stage.getFragment();
+                // TODO partitioning should be locked down during the planning phase
+                // This is a compromise to compute output partitioning and skew handling in the
+                // coordinator without having to change the planner code.
+                Optional<BucketToPartition> bucketToPartition = bucketToPartitionMap.get(fragment.getId());
+                OptionalInt skewedBucketCount = OptionalInt.empty();
+                if (bucketToPartition.isPresent()) {
+                    skewedBucketCount = getSkewedBucketCount(queryStateMachine.getSession(), fragment.getOutputPartitioningScheme(), bucketToPartition.get(), nodePartitioningManager);
+                }
                 StageExecution stageExecution = createPipelinedStageExecution(
                         stageManager.get(fragment.getId()),
                         outputBufferManagers,
                         taskLifecycleListener,
                         failureDetector,
                         executor,
-                        bucketToPartitionMap.get(fragment.getId()),
+                        bucketToPartition.map(BucketToPartition::bucketToPartition),
+                        skewedBucketCount,
                         attempt);
                 stageExecutions.put(stage.getStageId(), stageExecution);
             }
@@ -960,18 +974,18 @@ public static DistributedStagesScheduler create(
             return distributedStagesScheduler;
         }
 
-        private static Map<PlanFragmentId, Optional<int[]>> createBucketToPartitionMap(
+        private static Map<PlanFragmentId, Optional<BucketToPartition>> createBucketToPartitionMap(
                 Session session,
-                Map<PlanFragmentId, Optional<int[]>> bucketToPartitionForStagesConsumedByCoordinator,
+                Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionForStagesConsumedByCoordinator,
                 StageManager stageManager,
                 Function<BucketToPartitionKey, NodePartitionMap> partitioningCache)
        {
-            ImmutableMap.Builder<PlanFragmentId, Optional<int[]>> result = ImmutableMap.builder();
+            ImmutableMap.Builder<PlanFragmentId, Optional<BucketToPartition>> result = ImmutableMap.builder();
             result.putAll(bucketToPartitionForStagesConsumedByCoordinator);
             for (SqlStage stage : stageManager.getDistributedStagesInTopologicalOrder()) {
                 PlanFragment fragment = stage.getFragment();
                 BucketToPartitionKey bucketToPartitionKey = getKeyForFragment(fragment, session);
-                Optional<int[]> bucketToPartition = getBucketToPartition(bucketToPartitionKey, partitioningCache);
+                Optional<BucketToPartition> bucketToPartition = getBucketToPartition(bucketToPartitionKey, partitioningCache);
                 for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) {
                     result.put(childStage.getFragment().getId(), bucketToPartition);
                 }
@@ -994,10 +1008,10 @@ private static BucketToPartitionKey getKeyForFragment(PlanFragment fragment, Session session)
             return new PartitioningKey(partitioningHandle, partitionCount);
         }
 
-        private static Optional<int[]> getBucketToPartition(BucketToPartitionKey bucketToPartitionKey, Function<BucketToPartitionKey, NodePartitionMap> partitioningCache)
+        private static Optional<BucketToPartition> getBucketToPartition(BucketToPartitionKey bucketToPartitionKey, Function<BucketToPartitionKey, NodePartitionMap> partitioningCache)
         {
             return switch (bucketToPartitionKey) {
-                case ONE -> Optional.of(new int[1]);
+                case ONE -> Optional.of(new BucketToPartition(new int[1], false));
                 case EMPTY -> Optional.empty();
                 case PartitioningKey key -> {
                     NodePartitionMap nodePartitionMap = partitioningCache.apply(key);
@@ -1012,7 +1026,7 @@ private static Optional<BucketToPartition> getBucketToPartition(BucketToPartitionKey bucketToPartitionKey, ...)
         private static Map<PlanFragmentId, PipelinedOutputBufferManager> createOutputBufferManagers(
                 Map<PlanFragmentId, PipelinedOutputBufferManager> outputBuffersForStagesConsumedByCoordinator,
                 StageManager stageManager,
-                Map<PlanFragmentId, Optional<int[]>> bucketToPartitionMap)
+                Map<PlanFragmentId, Optional<BucketToPartition>> bucketToPartitionMap)
         {
             ImmutableMap.Builder<PlanFragmentId, PipelinedOutputBufferManager> result = ImmutableMap.builder();
             result.putAll(outputBuffersForStagesConsumedByCoordinator);
@@ -1029,9 +1043,9 @@ else if (partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) {
                     outputBufferManager = new ScaledPipelinedOutputBufferManager();
                 }
                 else {
-                    Optional<int[]> bucketToPartition = bucketToPartitionMap.get(fragmentId);
+                    Optional<BucketToPartition> bucketToPartition = bucketToPartitionMap.get(fragmentId);
                     checkArgument(bucketToPartition.isPresent(), "bucketToPartition is expected to be present for fragment: %s", fragmentId);
-                    int partitionCount = Ints.max(bucketToPartition.get()) + 1;
+                    int partitionCount = Ints.max(bucketToPartition.get().bucketToPartition()) + 1;
                     outputBufferManager = new PartitionedPipelinedOutputBufferManager(partitioningHandle, partitionCount);
                 }
                 result.put(fragmentId, outputBufferManager);
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java
index e88a35fef094..0d44da600b76 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java
@@ -52,6 +52,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.Executor;
@@ -113,6 +114,7 @@ public class PipelinedStageExecution
     private final TaskLifecycleListener taskLifecycleListener;
     private final FailureDetector failureDetector;
     private final Optional<int[]> bucketToPartition;
+    private final OptionalInt skewedBucketCount;
     private final Map<PlanFragmentId, RemoteSourceNode> exchangeSources;
     private final int attempt;
@@ -139,6 +141,7 @@ public static PipelinedStageExecution createPipelinedStageExecution(
             FailureDetector failureDetector,
             Executor executor,
             Optional<int[]> bucketToPartition,
+            OptionalInt skewedBucketCount,
             int attempt)
     {
         PipelinedStageStateMachine stateMachine = new PipelinedStageStateMachine(stage.getStageId(), executor);
@@ -155,6 +158,7 @@
                 taskLifecycleListener,
                 failureDetector,
                 bucketToPartition,
+                skewedBucketCount,
                 exchangeSources.buildOrThrow(),
                 attempt);
         execution.initialize();
@@ -168,6 +172,7 @@ private PipelinedStageExecution(
             TaskLifecycleListener taskLifecycleListener,
             FailureDetector failureDetector,
             Optional<int[]> bucketToPartition,
+            OptionalInt skewedBucketCount,
             Map<PlanFragmentId, RemoteSourceNode> exchangeSources,
             int attempt)
     {
@@ -177,6 +182,7 @@
         this.taskLifecycleListener = requireNonNull(taskLifecycleListener, "taskLifecycleListener is null");
         this.failureDetector = requireNonNull(failureDetector, "failureDetector is null");
         this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null");
+        this.skewedBucketCount = requireNonNull(skewedBucketCount, "skewedBucketCount is null");
        this.exchangeSources = ImmutableMap.copyOf(requireNonNull(exchangeSources, "exchangeSources is null"));
         this.attempt = attempt;
     }
@@ -296,6 +302,7 @@ public synchronized Optional<RemoteTask> scheduleTask(
                 partition,
                 attempt,
                 bucketToPartition,
+                skewedBucketCount,
                 outputBuffers,
                 initialSplits,
                 ImmutableSet.of(),
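PipelinedQueryScheduler now threads a BucketToPartition value (rather than a bare int[]) from the node-partitioning layer down to stage executions. A minimal sketch of the shape this record is assumed to have, inferred from its call sites above (bucketToPartition(), hasFixedMapping(), and the max-partition-id-plus-one computation in createOutputBufferManagers):

    // Assumed shape of NodePartitionMap.BucketToPartition, based on how it is used here.
    import java.util.stream.IntStream;

    record BucketToPartition(int[] bucketToPartition, boolean hasFixedMapping)
    {
        // the scheduler derives the output-buffer partition count as max partition id + 1
        int partitionCount()
        {
            return IntStream.of(bucketToPartition).max().orElseThrow() + 1;
        }
    }

Pairing the array with the hasFixedMapping flag lets downstream code (e.g. the skew rebalancer) know whether buckets may be spread across tasks without re-querying the connector.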
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java
index 0034c2713629..959e7ada8322 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/StageManager.java
@@ -26,6 +26,7 @@
 import io.trino.execution.QueryStateMachine;
 import io.trino.execution.RemoteTaskFactory;
 import io.trino.execution.SqlStage;
+import io.trino.execution.SqlStage.LocalExchangeBucketCountProvider;
 import io.trino.execution.StageId;
 import io.trino.execution.StageInfo;
 import io.trino.execution.TableInfo;
@@ -72,7 +73,8 @@ static StageManager create(
             Span schedulerSpan,
             SplitSchedulerStats schedulerStats,
             SubPlan planTree,
-            boolean summarizeTaskInfo)
+            boolean summarizeTaskInfo,
+            LocalExchangeBucketCountProvider bucketCountProvider)
     {
         Session session = queryStateMachine.getSession();
         ImmutableMap.Builder<StageId, SqlStage> stages = ImmutableMap.builder();
@@ -95,7 +97,8 @@
                     queryStateMachine.getStateMachineExecutor(),
                     tracer,
                     schedulerSpan,
-                    schedulerStats);
+                    schedulerStats,
+                    bucketCountProvider);
             StageId stageId = stage.getStageId();
             stages.put(stageId, stage);
             stagesInTopologicalOrder.add(stage);
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java
index a059475880e0..9e9f9e9bcd58 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/faulttolerant/EventDrivenFaultTolerantQueryScheduler.java
@@ -55,6 +55,7 @@
 import io.trino.execution.RemoteTask;
 import io.trino.execution.RemoteTaskFactory;
 import io.trino.execution.SqlStage;
+import io.trino.execution.SqlStage.LocalExchangeBucketCountProvider;
 import io.trino.execution.StageId;
 import io.trino.execution.StageInfo;
 import io.trino.execution.StageState;
@@ -127,6 +128,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.Queue;
 import java.util.Set;
@@ -370,7 +372,8 @@ public synchronized void start()
                     originalPlan,
                     maxPartitionCount,
                     stageEstimationForEagerParentEnabled,
-                    adaptivePlanner);
+                    adaptivePlanner,
+                    nodePartitioningManager::getBucketCount);
             queryExecutor.submit(scheduler::run);
         }
         catch (Throwable t) {
@@ -697,6 +700,7 @@ private static class Scheduler
         private final OutputStatsEstimator outputStatsEstimator;
         private final FaultTolerantPartitioningSchemeFactory partitioningSchemeFactory;
         private final ExchangeManager exchangeManager;
+        private final LocalExchangeBucketCountProvider bucketCountProvider;
         private final int maxTaskExecutionAttempts;
         private final int maxTasksWaitingForNode;
         private final int maxTasksWaitingForExecution;
@@ -765,7 +769,8 @@ public Scheduler(
                 SubPlan plan,
                 int maxPartitionCount,
                 boolean stageEstimationForEagerParentEnabled,
-                Optional<AdaptivePlanner> adaptivePlanner)
+                Optional<AdaptivePlanner> adaptivePlanner,
+                LocalExchangeBucketCountProvider bucketCountProvider)
         {
             this.queryStateMachine = requireNonNull(queryStateMachine, "queryStateMachine is null");
             this.metadata = requireNonNull(metadata, "metadata is null");
@@ -796,6 +801,7 @@ public Scheduler(
             this.plan = requireNonNull(plan, "plan is null");
             this.maxPartitionCount = maxPartitionCount;
             this.adaptivePlanner = requireNonNull(adaptivePlanner, "adaptivePlanner is null");
+            this.bucketCountProvider = requireNonNull(bucketCountProvider, "bucketCountProvider is null");
             this.stageEstimationForEagerParentEnabled = stageEstimationForEagerParentEnabled;
             this.schedulerSpan = tracer.spanBuilder("scheduler")
                     .setParent(Context.current().with(queryStateMachine.getSession().getQuerySpan()))
@@ -1400,7 +1406,8 @@ private void createStageExecution(
                     queryStateMachine.getStateMachineExecutor(),
                     tracer,
                     schedulerSpan,
-                    schedulerStats);
+                    schedulerStats,
+                    bucketCountProvider);
             closer.register(stage::abort);
             stageRegistry.add(stage);
             stage.addFinalStageInfoListener(_ -> queryStateMachine.updateQueryInfo(Optional.ofNullable(stageRegistry.getStageInfo())));
@@ -2301,13 +2308,14 @@ public Optional<RemoteTask> schedule(int partitionId, ExchangeSinkInstanceHandle exchangeSinkInstanceHandle, ...)
                     noMoreSplits.add(partitionedSource);
                 }
             }
-
             SpoolingOutputBuffers outputBuffers = SpoolingOutputBuffers.createInitial(exchangeSinkInstanceHandle, sinkPartitioningScheme.getPartitionCount());
             Optional<RemoteTask> task = stage.createTask(
                     node,
                     partitionId,
                     attempt,
                     sinkPartitioningScheme.getBucketToPartitionMap(),
+                    // FTE does not support writer scaling
+                    OptionalInt.empty(),
                     outputBuffers,
                     splits,
                     noMoreSplits,
diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
index 76be2ced34d3..6c81b471d427 100644
--- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
+++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java
@@ -21,7 +21,6 @@
 import io.airlift.slice.XxHash64;
 import io.airlift.units.DataSize;
 import io.trino.Session;
-import io.trino.operator.BucketPartitionFunction;
 import io.trino.operator.HashGenerator;
 import io.trino.operator.PartitionFunction;
 import io.trino.operator.output.SkewedPartitionRebalancer;
@@ -29,13 +28,14 @@
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
 import io.trino.sql.planner.MergePartitioningHandle;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.PartitioningHandle;
 import io.trino.sql.planner.SystemPartitioningHandle;
 
 import java.io.Closeable;
 import java.util.HashSet;
 import java.util.List;
+import java.util.Optional;
 import java.util.Set;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
@@ -86,10 +86,11 @@ public class LocalExchange
     private int nextSourceIndex;
 
     public LocalExchange(
-            NodePartitioningManager nodePartitioningManager,
+            PartitionFunctionProvider partitionFunctionProvider,
             Session session,
             int defaultConcurrency,
             PartitioningHandle partitioning,
+            Optional<Integer> bucketCount,
             List<Integer> partitionChannels,
             List<Type> partitionChannelTypes,
             DataSize maxBufferedBytes,
@@ -150,9 +151,10 @@ else if (isScaledWriterHashDistribution(partitioning)) {
             exchangerSupplier = () -> {
                 PartitionFunction partitionFunction = createPartitionFunction(
-                        nodePartitioningManager,
+                        partitionFunctionProvider,
                         session,
                         typeOperators,
+                        bucketCount,
                         partitioning,
                         partitionCount,
                         partitionChannels,
@@ -177,9 +179,10 @@ else if (partitioning.equals(FIXED_HASH_DISTRIBUTION) || partitioning.getCatalogHandle().isPresent()) {
                     .collect(toImmutableList());
             exchangerSupplier = () -> {
                 PartitionFunction partitionFunction = createPartitionFunction(
-                        nodePartitioningManager,
+                        partitionFunctionProvider,
                         session,
                         typeOperators,
+                        bucketCount,
                         partitioning,
                         bufferCount,
                         partitionChannels,
@@ -231,17 +234,18 @@ private static Function<Page, Page> createPartitionPagePreparer(PartitioningHandle partitioning, List<Integer> partitionChannels)
     }
 
     private static PartitionFunction createPartitionFunction(
-            NodePartitioningManager nodePartitioningManager,
+            PartitionFunctionProvider partitionFunctionProvider,
             Session session,
             TypeOperators typeOperators,
-            PartitioningHandle partitioning,
+            Optional<Integer> optionalBucketCount,
+            PartitioningHandle partitioningHandle,
             int partitionCount,
             List<Integer> partitionChannels,
             List<Type> partitionChannelTypes)
     {
         checkArgument(Integer.bitCount(partitionCount) == 1, "partitionCount must be a power of 2");
 
-        if (isSystemPartitioning(partitioning)) {
+        if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle) {
             HashGenerator hashGenerator = createChannelsHashGenerator(partitionChannelTypes, Ints.toArray(partitionChannels), typeOperators);
             return new LocalPartitionGenerator(hashGenerator, partitionCount);
         }
@@ -250,7 +254,7 @@ private static PartitionFunction createPartitionFunction(
         // The same bucket function (with the same bucket count) as for node
         // partitioning must be used. This way rows within a single bucket
         // will be being processed by single thread.
-        int bucketCount = getBucketCount(session, nodePartitioningManager, partitioning);
+        int bucketCount = optionalBucketCount.orElseThrow(() -> new IllegalArgumentException("Bucket count must be set before non-system partition function can be created"));
 
         int[] bucketToPartition = new int[bucketCount];
         for (int bucket = 0; bucket < bucketCount; bucket++) {
@@ -259,30 +263,7 @@ private static PartitionFunction createPartitionFunction(
             bucketToPartition[bucket] = hashedBucket & (partitionCount - 1);
         }
 
-        if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle handle) {
-            return handle.getPartitionFunction(
-                    (scheme, types) -> nodePartitioningManager.getPartitionFunction(session, scheme, types, bucketToPartition),
-                    partitionChannelTypes,
-                    bucketToPartition);
-        }
-
-        return new BucketPartitionFunction(
-                nodePartitioningManager.getBucketFunction(session, partitioning, partitionChannelTypes, bucketCount),
-                bucketToPartition);
-    }
-
-    public static int getBucketCount(Session session, NodePartitioningManager nodePartitioningManager, PartitioningHandle partitioning)
-    {
-        if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) {
-            // TODO: can we always use this code path?
-            return nodePartitioningManager.getNodePartitioningMap(session, partitioning, 1000).getBucketToPartition().length;
-        }
-        return nodePartitioningManager.getBucketCount(session, partitioning);
-    }
-
-    private static boolean isSystemPartitioning(PartitioningHandle partitioning)
-    {
-        return partitioning.getConnectorHandle() instanceof SystemPartitioningHandle;
+        return partitionFunctionProvider.getPartitionFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition);
     }
 
     private void checkAllSourcesFinished()
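createPartitionFunction above maps each connector bucket to one of partitionCount local partitions by hashing the bucket id and masking with partitionCount - 1, which is why partitionCount must be a power of two. A self-contained sketch of the same derivation; it uses a SplitMix64-style finalizer instead of airlift's XxHash64, so the concrete values differ from Trino's:

    // Sketch of the bucket-to-local-partition derivation. Not Trino's exact hash.
    final class BucketMixer
    {
        static int[] bucketToPartition(int bucketCount, int partitionCount)
        {
            if (Integer.bitCount(partitionCount) != 1) {
                throw new IllegalArgumentException("partitionCount must be a power of 2");
            }
            int[] mapping = new int[bucketCount];
            for (int bucket = 0; bucket < bucketCount; bucket++) {
                // mix the bits so the local assignment is uncorrelated with the remote one
                long hashed = mix64(Long.reverse(bucket));
                mapping[bucket] = (int) (hashed & (partitionCount - 1));
            }
            return mapping;
        }

        private static long mix64(long z)
        {
            z = (z ^ (z >>> 33)) * 0xff51afd7ed558ccdL;
            z = (z ^ (z >>> 33)) * 0xc4ceb9fe1a85ec53L;
            return z ^ (z >>> 33);
        }
    }

Mixing the bucket bits keeps the local partition assignment uncorrelated with the bucket-based distribution already used between stages, as the removed comment in the original code explains.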
-        boolean hasFixedNodeMapping = partitioningHandle.getCatalogHandle()
-                .map(_ -> nodePartitioningManager.getConnectorBucketNodeMap(session, partitioningHandle)
-                        .map(ConnectorBucketNodeMap::hasFixedMapping)
-                        .orElse(false))
-                .orElse(false);
-        // Use skewed partition rebalancer only when there are more than one tasks
-        return taskCount > 1 && !hasFixedNodeMapping && isScaledWriterHashDistribution(partitioningHandle);
+
+        PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle();
+        if (!isScaledWriterHashDistribution(partitioningHandle)) {
+            return OptionalInt.empty();
+        }
+
+        // Use skewed partition rebalancer only when there are more than one tasks (partition)
+        if (IntStream.of(bucketToPartition.bucketToPartition()).max().orElseThrow() == 0) {
+            return OptionalInt.empty();
+        }
+
+        int bucketCount = (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle)
+                ? SCALE_WRITERS_PARTITION_COUNT
+                : nodePartitioningManager.getDefaultBucketCount(session);
+        return OptionalInt.of(bucketCount);
     }
 
     public static PartitionFunction createPartitionFunction(
             Session session,
-            NodePartitioningManager nodePartitioningManager,
-            PartitioningScheme scheme,
+            PartitionFunctionProvider partitionFunctionProvider,
+            PartitioningHandle partitioningHandle,
+            int bucketCount,
             List<Type> partitionChannelTypes)
     {
-        PartitioningHandle handle = scheme.getPartitioning().getHandle();
         // In case of SystemPartitioningHandle we can use arbitrary bucket count so that skewness mitigation
         // is more granular.
         // Whereas, in the case of connector partitioning we have to use connector provided bucketCount
@@ -145,14 +159,9 @@ public static PartitionFunction createPartitionFunction(
         // five artificial buckets resemble the first hive bucket. Therefore, these artificial buckets
         // have to write minPartitionDataProcessedRebalanceThreshold before they get scaled to task 1, which is slow
         // compared to only a single hive bucket reaching the min limit.
-        int bucketCount = (handle.getConnectorHandle() instanceof SystemPartitioningHandle)
-                ? SCALE_WRITERS_PARTITION_COUNT
-                : nodePartitioningManager.getBucketCount(session, handle);
-        return nodePartitioningManager.getPartitionFunction(
-                session,
-                scheme,
-                partitionChannelTypes,
-                IntStream.range(0, bucketCount).toArray());
+        int[] bucketToPartition = IntStream.range(0, bucketCount).toArray();
+
+        return partitionFunctionProvider.getPartitionFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition);
     }
 
     public static int getMaxWritersBasedOnMemory(Session session)
diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java
index 7e5049889ba8..91997f8fe96d 100644
--- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java
+++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java
@@ -117,6 +117,7 @@
 import io.trino.sql.SessionPropertyResolver;
 import io.trino.sql.analyzer.AnalyzerFactory;
 import io.trino.sql.analyzer.QueryExplainerFactory;
+import io.trino.sql.planner.NodePartitioningManager;
 import io.trino.sql.planner.OptimizerStatsMBeanExporter;
 import io.trino.sql.planner.PlanFragmenter;
 import io.trino.sql.planner.PlanOptimizers;
@@ -240,6 +241,9 @@ protected void setup(Binder binder)
         newExporter(binder).export(ClusterMemoryManager.class).withGeneratedName();
 
+        // node partitioning manager
+        binder.bind(NodePartitioningManager.class).in(Scopes.SINGLETON);
+
         // node allocator
         binder.bind(BinPackingNodeAllocatorService.class).in(Scopes.SINGLETON);
         newExporter(binder).export(BinPackingNodeAllocatorService.class).withGeneratedName();
diff --git a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java
index 799cf8f179a5..ecbd661fc179 100644
--- a/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java
+++ b/core/trino-main/src/main/java/io/trino/server/ServerMainModule.java
@@ -142,8 +142,8 @@
 import io.trino.sql.parser.SqlParser;
 import io.trino.sql.planner.CompilerConfig;
 import io.trino.sql.planner.LocalExecutionPlanner;
-import io.trino.sql.planner.NodePartitioningManager;
 import io.trino.sql.planner.OptimizerConfig;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.RuleStatsRecorder;
 import io.trino.sql.planner.Symbol;
 import io.trino.sql.planner.SymbolKeyDeserializer;
@@ -431,8 +431,8 @@ protected void setup(Binder binder)
         // split manager
         binder.bind(SplitManager.class).in(Scopes.SINGLETON);
 
-        // node partitioning manager
-        binder.bind(NodePartitioningManager.class).in(Scopes.SINGLETON);
+        // partitioning function provider
+        binder.bind(PartitionFunctionProvider.class).in(Scopes.SINGLETON);
 
         // index manager
         binder.bind(IndexManager.class).in(Scopes.SINGLETON);
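getSkewedBucketCount in SkewedPartitionRebalancer above gates skew-aware writer scaling on several conditions: no task-level retries (FTE), no connector-fixed bucket mapping, a scaled-writer hash distribution, and more than one output partition. A condensed model of that decision with the session, scheme, and node-map lookups replaced by pre-computed inputs (all names here are illustrative):

    // Condensed model of the getSkewedBucketCount gating logic, not the real signature.
    import java.util.OptionalInt;
    import java.util.stream.IntStream;

    final class SkewDecision
    {
        static OptionalInt skewedBucketCount(
                boolean taskRetriesEnabled,           // FTE (task retries) disables rebalancing
                boolean fixedBucketMapping,           // connector pins buckets to nodes
                boolean scaledWriterHashDistribution,
                boolean systemPartitioning,
                int[] bucketToPartition,
                int connectorBucketCount)
        {
            int maxPartition = IntStream.of(bucketToPartition).max().orElseThrow();
            if (taskRetriesEnabled || fixedBucketMapping || !scaledWriterHashDistribution || maxPartition == 0) {
                return OptionalInt.empty();
            }
            // 4096 mirrors SCALE_WRITERS_PARTITION_COUNT for system partitioning
            return OptionalInt.of(systemPartitioning ? 4096 : connectorBucketCount);
        }
    }

Returning an OptionalInt instead of a boolean lets the coordinator compute the bucket count once and ship it to workers inside the fragment, which is the point of this patch.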
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
index 541dac738693..ed45e3dd6fda 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java
@@ -347,7 +347,6 @@
 import static io.trino.operator.join.JoinUtils.isBuildSideReplicated;
 import static io.trino.operator.join.NestedLoopBuildOperator.NestedLoopBuildOperatorFactory;
 import static io.trino.operator.join.NestedLoopJoinOperator.NestedLoopJoinOperatorFactory;
-import static io.trino.operator.output.SkewedPartitionRebalancer.checkCanScalePartitionsRemotely;
 import static io.trino.operator.output.SkewedPartitionRebalancer.createPartitionFunction;
 import static io.trino.operator.output.SkewedPartitionRebalancer.getMaxWritersBasedOnMemory;
 import static io.trino.operator.output.SkewedPartitionRebalancer.getTaskCount;
@@ -416,7 +415,7 @@ public class LocalExecutionPlanner
     private final Optional<ExplainAnalyzeContext> explainAnalyzeContext;
     private final PageSourceManager pageSourceManager;
     private final IndexManager indexManager;
-    private final NodePartitioningManager nodePartitioningManager;
+    private final PartitionFunctionProvider partitionFunctionProvider;
     private final PageSinkManager pageSinkManager;
     private final DirectExchangeClientSupplier directExchangeClientSupplier;
     private final ExpressionCompiler expressionCompiler;
@@ -473,7 +472,7 @@ public LocalExecutionPlanner(
             Optional<ExplainAnalyzeContext> explainAnalyzeContext,
             PageSourceManager pageSourceManager,
             IndexManager indexManager,
-            NodePartitioningManager nodePartitioningManager,
+            PartitionFunctionProvider partitionFunctionProvider,
             PageSinkManager pageSinkManager,
             DirectExchangeClientSupplier directExchangeClientSupplier,
             ExpressionCompiler expressionCompiler,
@@ -503,7 +502,7 @@
         this.explainAnalyzeContext = requireNonNull(explainAnalyzeContext, "explainAnalyzeContext is null");
         this.pageSourceManager = requireNonNull(pageSourceManager, "pageSourceManager is null");
         this.indexManager = requireNonNull(indexManager, "indexManager is null");
-        this.nodePartitioningManager = requireNonNull(nodePartitioningManager, "nodePartitioningManager is null");
+        this.partitionFunctionProvider = requireNonNull(partitionFunctionProvider, "partitionFunctionProvider is null");
         this.directExchangeClientSupplier = requireNonNull(directExchangeClientSupplier, "directExchangeClientSupplier is null");
         this.pageSinkManager = requireNonNull(pageSinkManager, "pageSinkManager is null");
         this.expressionCompiler = requireNonNull(expressionCompiler, "expressionCompiler is null");
@@ -552,6 +551,7 @@ public LocalExecutionPlan plan(
             TaskContext taskContext,
             PlanNode plan,
             PartitioningScheme partitioningScheme,
+            OptionalInt outputSkewedBucketCount,
             List<PlanNodeId> partitionedSourceOrder,
             OutputBuffer outputBuffer)
     {
@@ -594,8 +594,8 @@ public LocalExecutionPlan plan(
         PartitionFunction partitionFunction;
         Optional<SkewedPartitionRebalancer> skewedPartitionRebalancer = Optional.empty();
         int taskCount = getTaskCount(partitioningScheme);
-        if (checkCanScalePartitionsRemotely(taskContext.getSession(), taskCount, partitioningScheme.getPartitioning().getHandle(), nodePartitioningManager)) {
-            partitionFunction = createPartitionFunction(taskContext.getSession(), nodePartitioningManager, partitioningScheme, partitionChannelTypes);
+        if (outputSkewedBucketCount.isPresent()) {
+            partitionFunction = createPartitionFunction(taskContext.getSession(), partitionFunctionProvider, partitioningScheme.getPartitioning().getHandle(), outputSkewedBucketCount.getAsInt(), partitionChannelTypes);
             int partitionedWriterCount = getPartitionedWriterCountBasedOnMemory(taskContext.getSession());
             // Keep the task bucket count to 50% of total local writers
             int taskBucketCount = (int) ceil(0.5 * partitionedWriterCount);
@@ -607,7 +612,12 @@
                     getSkewedPartitionMinDataProcessedRebalanceThreshold(taskContext.getSession()).toBytes()));
         }
         else {
-            partitionFunction = nodePartitioningManager.getPartitionFunction(taskContext.getSession(), partitioningScheme, partitionChannelTypes);
+            partitionFunction = partitionFunctionProvider.getPartitionFunction(
+                    taskContext.getSession(),
+                    partitioningScheme.getPartitioning().getHandle(),
+                    partitionChannelTypes,
+                    partitioningScheme.getBucketToPartition()
+                            .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before a partition function can be created")));
         }
         OptionalInt nullChannel = OptionalInt.empty();
         Set<Symbol> partitioningColumns = partitioningScheme.getPartitioning().getColumns();
@@ -3690,10 +3695,11 @@ private PhysicalOperation createLocalMerge(ExchangeNode node, LocalExecutionPlanContext context)
         int operatorsCount = subContext.getDriverInstanceCount().orElse(1);
         List<Type> types = getSourceOperatorTypes(node);
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                partitionFunctionProvider,
                 session,
                 operatorsCount,
                 node.getPartitioningScheme().getPartitioning().getHandle(),
+                node.getPartitioningScheme().getBucketCount(),
                 ImmutableList.of(),
                 ImmutableList.of(),
                 maxLocalExchangeBufferSize,
@@ -3764,10 +3770,11 @@ else if (context.getDriverInstanceCount().isPresent()) {
         }
 
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                partitionFunctionProvider,
                 session,
                 driverInstanceCount,
                 node.getPartitioningScheme().getPartitioning().getHandle(),
+                node.getPartitioningScheme().getBucketCount(),
                 partitionChannels,
                 partitionChannelTypes,
                 maxLocalExchangeBufferSize,
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
index c1d69a158f5a..b4efd176508a 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java
@@ -724,6 +724,7 @@ else if (isUsePreferredWritePartitioning(session)) {
                             outputLayout,
                             false,
                             Optional.empty(),
+                            Optional.empty(),
                             maxWritersNodesCount));
         }
     }
@@ -1016,6 +1017,7 @@ else if (isUsePreferredWritePartitioning(session)) {
                             outputLayout,
                             false,
                             Optional.empty(),
+                            Optional.empty(),
                             maxWritersNodesCount));
         }
     }
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java
index 95f522c66606..32c5892c142d 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/MergePartitioningHandle.java
@@ -100,6 +100,22 @@ public String toString()
         return "MERGE " + parts;
     }
 
+    public Optional<Integer> getBucketCount(Function<PartitioningHandle, Optional<Integer>> getBucketCount)
+    {
+        Optional<Integer> optionalInsertBucketCount = insertPartitioning.map(scheme -> scheme.getPartitioning().getHandle()).flatMap(getBucketCount);
+        Optional<Integer> optionalUpdateBucketCount = updatePartitioning.map(scheme -> scheme.getPartitioning().getHandle()).flatMap(getBucketCount);
+
+        if (optionalInsertBucketCount.isPresent() && optionalUpdateBucketCount.isPresent()) {
+            int insertBucketCount = optionalInsertBucketCount.get();
+            int updateBucketCount = optionalUpdateBucketCount.get();
+            if (insertBucketCount != updateBucketCount) {
+                throw new TrinoException(NOT_SUPPORTED, "Insert and update layout have mismatched bucket counts: " + insertBucketCount + " vs " + updateBucketCount);
+            }
+        }
+
+        return optionalInsertBucketCount.or(() -> optionalUpdateBucketCount);
+    }
+
     public NodePartitionMap getNodePartitioningMap(Function<PartitioningHandle, NodePartitionMap> getMap)
     {
         Optional<NodePartitionMap> optionalInsertMap = insertPartitioning.map(scheme -> scheme.getPartitioning().getHandle()).map(getMap);
@@ -109,7 +125,7 @@ public NodePartitionMap getNodePartitioningMap(Function<PartitioningHandle, NodePartitionMap> getMap)
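MergePartitioningHandle.getBucketCount reconciles the insert-side and update-side bucket counts: either side may be absent, but when both are present they must agree. The same Optional-merging rule in isolation (IllegalStateException stands in for Trino's TrinoException(NOT_SUPPORTED, ...)):

    // Standalone version of the bucket-count reconciliation rule shown above.
    import java.util.Optional;

    final class BucketCounts
    {
        static Optional<Integer> merge(Optional<Integer> insertCount, Optional<Integer> updateCount)
        {
            if (insertCount.isPresent() && updateCount.isPresent() && !insertCount.get().equals(updateCount.get())) {
                throw new IllegalStateException(
                        "Insert and update layout have mismatched bucket counts: " + insertCount.get() + " vs " + updateCount.get());
            }
            return insertCount.or(() -> updateCount);
        }
    }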
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitionMap.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitionMap.java
--- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitionMap.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitionMap.java
     private final List<InternalNode> partitionToNode;
-    private final int[] bucketToPartition;
+    private final BucketToPartition bucketToPartition;
     private final ToIntFunction<Split> splitToBucket;
 
     public NodePartitionMap(List<InternalNode> partitionToNode, ToIntFunction<Split> splitToBucket)
     {
         this.partitionToNode = ImmutableList.copyOf(requireNonNull(partitionToNode, "partitionToNode is null"));
-        this.bucketToPartition = IntStream.range(0, partitionToNode.size()).toArray();
+        this.bucketToPartition = new BucketToPartition(IntStream.range(0, partitionToNode.size()).toArray(), false);
         this.splitToBucket = requireNonNull(splitToBucket, "splitToBucket is null");
     }
 
-    public NodePartitionMap(List<InternalNode> partitionToNode, int[] bucketToPartition, ToIntFunction<Split> splitToBucket)
+    public NodePartitionMap(List<InternalNode> partitionToNode, BucketToPartition bucketToPartition, ToIntFunction<Split> splitToBucket)
     {
-        this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null");
         this.partitionToNode = ImmutableList.copyOf(requireNonNull(partitionToNode, "partitionToNode is null"));
+        this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null");
         this.splitToBucket = requireNonNull(splitToBucket, "splitToBucket is null");
     }
@@ -57,7 +66,7 @@ public List<InternalNode> getPartitionToNode()
         return partitionToNode;
     }
 
-    public int[] getBucketToPartition()
+    public BucketToPartition getBucketToPartition()
     {
         return bucketToPartition;
     }
@@ -65,14 +74,14 @@ public BucketToPartition getBucketToPartition()
     public InternalNode getNode(Split split)
     {
         int bucket = splitToBucket.applyAsInt(split);
-        int partition = bucketToPartition[bucket];
+        int partition = bucketToPartition.bucketToPartition()[bucket];
         return requireNonNull(partitionToNode.get(partition));
     }
 
     public BucketNodeMap asBucketNodeMap()
     {
         ImmutableList.Builder<InternalNode> bucketToNode = ImmutableList.builder();
-        for (int partition : bucketToPartition) {
+        for (int partition : bucketToPartition.bucketToPartition()) {
             bucketToNode.add(partitionToNode.get(partition));
         }
         return new BucketNodeMap(splitToBucket, bucketToNode.build());
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
index 074f4fff70d5..54e31c945270 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java
@@ -25,17 +25,14 @@
 import io.trino.execution.scheduler.NodeSelector;
 import io.trino.metadata.InternalNode;
 import io.trino.metadata.Split;
-import io.trino.operator.BucketPartitionFunction;
-import io.trino.operator.PartitionFunction;
 import io.trino.operator.RetryPolicy;
-import io.trino.spi.connector.BucketFunction;
 import io.trino.spi.connector.CatalogHandle;
 import io.trino.spi.connector.ConnectorBucketNodeMap;
 import io.trino.spi.connector.ConnectorNodePartitioningProvider;
+import io.trino.spi.connector.ConnectorPartitioningHandle;
 import io.trino.spi.connector.ConnectorSplit;
-import io.trino.spi.type.Type;
-import io.trino.spi.type.TypeOperators;
 import io.trino.split.EmptySplit;
+import io.trino.sql.planner.NodePartitionMap.BucketToPartition;
 import io.trino.sql.planner.SystemPartitioningHandle.SystemPartitioning;
 
 import java.util.HashMap;
@@ -63,76 +60,17 @@
 public class NodePartitioningManager
 {
     private final NodeScheduler nodeScheduler;
-    private final TypeOperators typeOperators;
     private final CatalogServiceProvider<ConnectorNodePartitioningProvider> partitioningProvider;
 
     @Inject
     public NodePartitioningManager(
             NodeScheduler nodeScheduler,
-            TypeOperators typeOperators,
             CatalogServiceProvider<ConnectorNodePartitioningProvider> partitioningProvider)
     {
         this.nodeScheduler = requireNonNull(nodeScheduler, "nodeScheduler is null");
-        this.typeOperators = requireNonNull(typeOperators, "typeOperators is null");
         this.partitioningProvider = requireNonNull(partitioningProvider, "partitioningProvider is null");
     }
 
-    public PartitionFunction getPartitionFunction(
-            Session session,
-            PartitioningScheme partitioningScheme,
-            List<Type> partitionChannelTypes)
-    {
-        int[] bucketToPartition = partitioningScheme.getBucketToPartition()
-                .orElseThrow(() -> new IllegalArgumentException("Bucket to partition must be set before a partition function can be created"));
-
-        PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle();
-        if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle) {
-            return ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getPartitionFunction(
-                    partitionChannelTypes,
-                    bucketToPartition,
-                    typeOperators);
-        }
-
-        if (partitioningHandle.getConnectorHandle() instanceof MergePartitioningHandle handle) {
-            return handle.getPartitionFunction(
-                    (scheme, types) -> getPartitionFunction(session, scheme, types, bucketToPartition),
-                    partitionChannelTypes,
-                    bucketToPartition);
-        }
-
-        return getPartitionFunction(session, partitioningScheme, partitionChannelTypes, bucketToPartition);
-    }
-
-    public PartitionFunction getPartitionFunction(Session session, PartitioningScheme partitioningScheme, List<Type> partitionChannelTypes, int[] bucketToPartition)
-    {
-        PartitioningHandle partitioningHandle = partitioningScheme.getPartitioning().getHandle();
-
-        if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle handle) {
-            return handle.getPartitionFunction(
-                    partitionChannelTypes,
-                    bucketToPartition,
-                    typeOperators);
-        }
-
-        BucketFunction bucketFunction = getBucketFunction(session, partitioningHandle, partitionChannelTypes, bucketToPartition.length);
-        return new BucketPartitionFunction(bucketFunction, bucketToPartition);
-    }
-
-    public BucketFunction getBucketFunction(Session session, PartitioningHandle partitioningHandle, List<Type> partitionChannelTypes, int bucketCount)
-    {
-        CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle);
-        ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(catalogHandle);
-
-        BucketFunction bucketFunction = partitioningProvider.getBucketFunction(
-                partitioningHandle.getTransactionHandle().orElseThrow(),
-                session.toConnectorSession(),
-                partitioningHandle.getConnectorHandle(),
-                partitionChannelTypes,
-                bucketCount);
-        checkArgument(bucketFunction != null, "No bucket function for partitioning: %s", partitioningHandle);
-        return bucketFunction;
-    }
-
     public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, int partitionCount)
     {
         return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), partitionCount);
@@ -204,7 +142,11 @@ private NodePartitionMap getNodePartitioningMap(
                 .mapToObj(partitionId -> nodeToPartition.inverse().get(partitionId))
                 .collect(toImmutableList());
 
-        return new NodePartitionMap(partitionToNode, bucketToPartition, getSplitToBucket(session, partitioningHandle, bucketToNode.size()));
+        boolean hasFixedMapping = optionalMap.map(ConnectorBucketNodeMap::hasFixedMapping).orElse(false);
+        return new NodePartitionMap(
+                partitionToNode,
+                new BucketToPartition(bucketToPartition, hasFixedMapping),
+                getSplitToBucket(session, partitioningHandle, bucketToNode.size()));
     }
 
     private List<InternalNode> systemBucketToNode(Session session, PartitioningHandle partitioningHandle, AtomicReference<List<InternalNode>> nodesCache, int partitionCount)
@@ -230,10 +172,19 @@
         return nodes;
     }
 
-    public int getBucketCount(Session session, PartitioningHandle partitioningHandle)
+    public Optional<Integer> getBucketCount(Session session, PartitioningHandle partitioningHandle)
     {
-        // we don't care about partition count at all, just bucket count
-        return getBucketNodeMap(session, partitioningHandle, 1000).getBucketCount();
+        if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle) {
+            return Optional.empty();
+        }
+
+        ConnectorPartitioningHandle connectorHandle = partitioningHandle.getConnectorHandle();
+        if (connectorHandle instanceof MergePartitioningHandle mergeHandle) {
+            return mergeHandle.getBucketCount(handle -> getBucketCount(session, handle))
+                    .or(() -> Optional.of(getDefaultBucketCount(session)));
+        }
+        Optional<ConnectorBucketNodeMap> bucketNodeMap = getConnectorBucketNodeMap(session, partitioningHandle);
+        return Optional.of(bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount).orElseGet(() -> getDefaultBucketCount(session)));
     }
 
     public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle, int partitionCount)
@@ -258,7 +209,7 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle, int partitionCount)
      *
      * @return The default bucket count to use when the connector doesn't provide a number.
      */
-    private int getDefaultBucketCount(Session session)
+    public int getDefaultBucketCount(Session session)
    {
         // The default bucket count is used by both remote and local exchanges to assign buckets to nodes and drivers. The goal is to have enough
         // buckets to evenly distribute them across tasks or drivers. If number of buckets is too low, then some tasks or drivers will be idle.
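The reworked NodePartitioningManager.getBucketCount encodes a clear fallback order: system partitioning has no bucket count at all, merge handles reconcile their insert and update sides, and plain connector partitioning uses the connector-reported count or the session-derived default. A toy model of the non-merge path (names and the Supplier-based default are illustrative):

    // Toy model of the bucket-count fallback for non-merge partitioning handles.
    import java.util.Optional;
    import java.util.function.Supplier;

    final class BucketCountResolution
    {
        static Optional<Integer> resolve(
                boolean systemPartitioning,
                Optional<Integer> connectorBucketCount,
                Supplier<Integer> defaultBucketCount)
        {
            if (systemPartitioning) {
                // system partitioning picks partition counts elsewhere; no bucket count here
                return Optional.empty();
            }
            return Optional.of(connectorBucketCount.orElseGet(defaultBucketCount));
        }
    }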
+ */ +package io.trino.sql.planner; + +import com.google.inject.Inject; +import io.trino.Session; +import io.trino.connector.CatalogServiceProvider; +import io.trino.operator.BucketPartitionFunction; +import io.trino.operator.PartitionFunction; +import io.trino.spi.connector.BucketFunction; +import io.trino.spi.connector.CatalogHandle; +import io.trino.spi.connector.ConnectorNodePartitioningProvider; +import io.trino.spi.type.Type; +import io.trino.spi.type.TypeOperators; + +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class PartitionFunctionProvider +{ + private final TypeOperators typeOperators; + private final CatalogServiceProvider partitioningProvider; + + @Inject + public PartitionFunctionProvider(TypeOperators typeOperators, CatalogServiceProvider partitioningProvider) + { + this.typeOperators = requireNonNull(typeOperators, "typeOperators is null"); + this.partitioningProvider = requireNonNull(partitioningProvider, "partitioningProvider is null"); + } + + public PartitionFunction getPartitionFunction(Session session, PartitioningHandle partitioningHandle, List partitionChannelTypes, int[] bucketToPartition) + { + if (partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle handle) { + return handle.getPartitionFunction(partitionChannelTypes, bucketToPartition, typeOperators); + } + + if (partitioningHandle.getConnectorHandle() instanceof MergePartitioningHandle handle) { + return handle.getPartitionFunction( + (scheme, types) -> getPartitionFunction(session, scheme.getPartitioning().getHandle(), types, bucketToPartition), + partitionChannelTypes, + bucketToPartition); + } + + ConnectorNodePartitioningProvider partitioningProvider = getPartitioningProvider(partitioningHandle); + + BucketFunction bucketFunction = partitioningProvider.getBucketFunction( + partitioningHandle.getTransactionHandle().orElseThrow(), + session.toConnectorSession(), + partitioningHandle.getConnectorHandle(), + partitionChannelTypes, + bucketToPartition.length); + checkArgument(bucketFunction != null, "No bucket function for partitioning: %s", partitioningHandle); + return new BucketPartitionFunction(bucketFunction, bucketToPartition); + } + + // NOTE: Do not access any function other than getBucketFunction as the other functions are not usable on workers + private ConnectorNodePartitioningProvider getPartitioningProvider(PartitioningHandle partitioningHandle) + { + CatalogHandle catalogHandle = partitioningHandle.getCatalogHandle().orElseThrow(() -> + new IllegalStateException("No catalog handle for partitioning handle: " + partitioningHandle)); + return partitioningProvider.getService(requireNonNull(catalogHandle, "catalogHandle is null")); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java b/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java index f2ddb4e4f24d..79e089547864 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PartitioningScheme.java @@ -33,6 +33,7 @@ public class PartitioningScheme private final List outputLayout; private final boolean replicateNullsAndAny; private final Optional bucketToPartition; + private final Optional bucketCount; private final Optional partitionCount; public PartitioningScheme(Partitioning partitioning, List outputLayout) @@ -42,6 +43,7 @@ public PartitioningScheme(Partitioning 
partitioning, List<Symbol> outputLayout)
                 outputLayout,
                 false,
                 Optional.empty(),
+                Optional.empty(),
                 Optional.empty());
     }
@@ -51,6 +53,7 @@ public PartitioningScheme(
             @JsonProperty("outputLayout") List<Symbol> outputLayout,
             @JsonProperty("replicateNullsAndAny") boolean replicateNullsAndAny,
             @JsonProperty("bucketToPartition") Optional<int[]> bucketToPartition,
+            @JsonProperty("bucketCount") Optional<Integer> bucketCount,
             @JsonProperty("partitionCount") Optional<Integer> partitionCount)
     {
         this.partitioning = requireNonNull(partitioning, "partitioning is null");
@@ -63,6 +66,11 @@ public PartitioningScheme(
         checkArgument(!replicateNullsAndAny || columns.size() <= 1, "Must have at most one partitioning column when nullPartition is REPLICATE.");
         this.replicateNullsAndAny = replicateNullsAndAny;
         this.bucketToPartition = requireNonNull(bucketToPartition, "bucketToPartition is null");
+        this.bucketCount = bucketCount;
+        checkArgument(bucketCount.isEmpty() || !(partitioning.getHandle().getConnectorHandle() instanceof SystemPartitioningHandle),
+                "Bucket count cannot be set on a system partitioning handle");
+        checkArgument(bucketToPartition.isEmpty() || bucketCount.isEmpty() || bucketToPartition.get().length == bucketCount.get(),
+                "bucketToPartition length does not match bucketCount");
         this.partitionCount = requireNonNull(partitionCount, "partitionCount is null");
         checkArgument(
                 partitionCount.isEmpty() || partitioning.getHandle().getConnectorHandle() instanceof SystemPartitioningHandle,
@@ -93,6 +101,12 @@ public Optional<int[]> getBucketToPartition()
         return bucketToPartition;
     }

+    @JsonProperty
+    public Optional<Integer> getBucketCount()
+    {
+        return bucketCount;
+    }
+
     @JsonProperty
     public Optional<Integer> getPartitionCount()
     {
@@ -101,18 +115,23 @@ public Optional<Integer> getPartitionCount()

     public PartitioningScheme withBucketToPartition(Optional<int[]> bucketToPartition)
     {
-        return new PartitioningScheme(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, partitionCount);
+        return new PartitioningScheme(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, bucketCount, partitionCount);
+    }
+
+    public PartitioningScheme withBucketCount(Optional<Integer> bucketCount)
+    {
+        return new PartitioningScheme(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, bucketCount, partitionCount);
     }

     public PartitioningScheme withPartitioningHandle(PartitioningHandle partitioningHandle)
     {
         Partitioning newPartitioning = partitioning.withAlternativePartitioningHandle(partitioningHandle);
-        return new PartitioningScheme(newPartitioning, outputLayout, replicateNullsAndAny, bucketToPartition, partitionCount);
+        return new PartitioningScheme(newPartitioning, outputLayout, replicateNullsAndAny, bucketToPartition, bucketCount, partitionCount);
     }

     public PartitioningScheme withPartitionCount(Optional<Integer> partitionCount)
     {
-        return new PartitioningScheme(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, partitionCount);
+        return new PartitioningScheme(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, bucketCount, partitionCount);
     }

     public PartitioningScheme translateOutputLayout(List<Symbol> newOutputLayout)
@@ -123,7 +142,7 @@ public PartitioningScheme translateOutputLayout(List<Symbol> newOutputLayout)

         Partitioning newPartitioning = partitioning.translate(symbol -> newOutputLayout.get(outputLayout.indexOf(symbol)));

-        return new PartitioningScheme(newPartitioning, newOutputLayout, replicateNullsAndAny, bucketToPartition, partitionCount);
+        return new PartitioningScheme(newPartitioning, newOutputLayout, replicateNullsAndAny, bucketToPartition, bucketCount, partitionCount);
     }

     @Override
@@ -140,13 +159,14 @@ public boolean equals(Object o)
                 Objects.equals(outputLayout, that.outputLayout) &&
                 replicateNullsAndAny == that.replicateNullsAndAny &&
                 Objects.equals(bucketToPartition, that.bucketToPartition) &&
+                Objects.equals(bucketCount, that.bucketCount) &&
                 Objects.equals(partitionCount, that.partitionCount);
     }

     @Override
     public int hashCode()
     {
-        return Objects.hash(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, partitionCount);
+        return Objects.hash(partitioning, outputLayout, replicateNullsAndAny, bucketToPartition, bucketCount, partitionCount);
     }

     @Override
@@ -157,6 +177,7 @@ public String toString()
                 .add("outputLayout", outputLayout)
                 .add("replicateNullsAndAny", replicateNullsAndAny)
                 .add("bucketToPartition", bucketToPartition)
+                .add("bucketCount", bucketCount)
                 .add("partitionCount", partitionCount)
                 .toString();
     }
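The two checkArgument calls added above are the contract for the new field: an explicit bucket count is only legal for connector partitionings, and when both bucketToPartition and bucketCount are present the mapping must cover every bucket. A self-contained sketch of the same checks (plain Java; the class and method names are illustrative, not part of the patch):

    import java.util.Optional;

    final class BucketCountInvariants
    {
        // isSystemPartitioning stands in for "connector handle instanceof SystemPartitioningHandle"
        static void validate(boolean isSystemPartitioning, Optional<int[]> bucketToPartition, Optional<Integer> bucketCount)
        {
            if (bucketCount.isPresent() && isSystemPartitioning) {
                throw new IllegalArgumentException("Bucket count cannot be set on a system partitioning handle");
            }
            if (bucketToPartition.isPresent() && bucketCount.isPresent() && bucketToPartition.get().length != bucketCount.get()) {
                throw new IllegalArgumentException("bucketToPartition length does not match bucketCount");
            }
        }

        public static void main(String[] args)
        {
            // 8 buckets folded onto 2 partitions round-robin: lengths agree, so this passes
            validate(false, Optional.of(new int[] {0, 1, 0, 1, 0, 1, 0, 1}), Optional.of(8));
            // a 4-entry mapping with bucketCount 8 would throw
        }
    }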
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java
index 1016c05e48d9..3fe6a244ae77 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java
@@ -33,6 +33,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;

 import static com.google.common.base.MoreObjects.toStringHelper;
@@ -54,6 +55,7 @@ public class PlanFragment
     private final Set partitionedSourceNodes;
     private final List remoteSourceNodes;
     private final PartitioningScheme outputPartitioningScheme;
+    private final OptionalInt outputSkewedBucketCount;
     private final StatsAndCosts statsAndCosts;
     private final List activeCatalogs;
     private final Map languageFunctions;
@@ -73,6 +75,7 @@ private PlanFragment(
             Set partitionedSourceNodes,
             List remoteSourceNodes,
             PartitioningScheme outputPartitioningScheme,
+            OptionalInt outputSkewedBucketCount,
             StatsAndCosts statsAndCosts,
             List activeCatalogs,
             Map languageFunctions)
@@ -88,6 +91,7 @@ private PlanFragment(
         this.partitionedSourceNodes = requireNonNull(partitionedSourceNodes, "partitionedSourceNodes is null");
         this.remoteSourceNodes = requireNonNull(remoteSourceNodes, "remoteSourceNodes is null");
         this.outputPartitioningScheme = requireNonNull(outputPartitioningScheme, "outputPartitioningScheme is null");
+        this.outputSkewedBucketCount = requireNonNull(outputSkewedBucketCount, "outputSkewedBucketCount is null");
         this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null");
         this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null");
         this.languageFunctions = ImmutableMap.copyOf(languageFunctions);
@@ -104,6 +108,7 @@ public PlanFragment(
             @JsonProperty("partitionCount") Optional partitionCount,
             @JsonProperty("partitionedSources") List partitionedSources,
             @JsonProperty("outputPartitioningScheme") PartitioningScheme outputPartitioningScheme,
+            @JsonProperty("outputSkewedBucketCount") OptionalInt outputSkewedBucketCount,
             @JsonProperty("statsAndCosts") StatsAndCosts statsAndCosts,
             @JsonProperty("activeCatalogs") List activeCatalogs,
             @JsonProperty("languageFunctions") Map languageFunctions,
@@ -116,6 +121,7 @@ public PlanFragment(
         this.partitionCount = requireNonNull(partitionCount, "partitionCount is null");
         this.partitionedSources = ImmutableList.copyOf(requireNonNull(partitionedSources, "partitionedSources is null"));
         this.partitionedSourcesSet = ImmutableSet.copyOf(partitionedSources);
+        this.outputSkewedBucketCount = requireNonNull(outputSkewedBucketCount, "outputSkewedBucketCount is null");
         this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null");
         this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null");
         this.languageFunctions = ImmutableMap.copyOf(languageFunctions);
@@ -190,6 +196,12 @@ public PartitioningScheme getOutputPartitioningScheme()
         return outputPartitioningScheme;
     }

+    @JsonProperty
+    public OptionalInt getOutputSkewedBucketCount()
+    {
+        return outputSkewedBucketCount;
+    }
+
     @JsonProperty
     public StatsAndCosts getStatsAndCosts()
     {
@@ -233,6 +245,7 @@ public PlanFragment withoutEmbeddedJsonRepresentation()
                 this.partitionedSourceNodes,
                 this.remoteSourceNodes,
                 this.outputPartitioningScheme,
+                this.outputSkewedBucketCount,
                 this.statsAndCosts,
                 this.activeCatalogs,
                 this.languageFunctions);
@@ -287,7 +300,27 @@ private static void findRemoteSourceNodes(PlanNode node, ImmutableList.Builder<RemoteSourceNode> builder)
     }

-    public PlanFragment withBucketToPartition(Optional<int[]> bucketToPartition)
+    public PlanFragment withRoot(PlanNode root)
+    {
+        return new PlanFragment(
+                id,
+                root,
+                symbols,
+                partitioning,
+                partitionCount,
+                partitionedSources,
+                partitionedSourcesSet,
+                types,
+                partitionedSourceNodes,
+                remoteSourceNodes,
+                outputPartitioningScheme,
+                outputSkewedBucketCount,
+                statsAndCosts,
+                activeCatalogs,
+                languageFunctions);
+    }
+
+    public PlanFragment withOutputPartitioning(Optional<int[]> bucketToPartition, OptionalInt skewedBucketCount)
     {
         return new PlanFragment(
                 id,
@@ -297,6 +330,7 @@ public PlanFragment withBucketToPartition(Optional<int[]> bucketToPartition)
                 partitionCount,
                 partitionedSources,
                 outputPartitioningScheme.withBucketToPartition(bucketToPartition),
+                skewedBucketCount,
                 statsAndCosts,
                 activeCatalogs,
                 languageFunctions,
@@ -325,6 +359,7 @@ public PlanFragment withActiveCatalogs(List activeCatalogs)
                 this.partitionCount,
                 this.partitionedSources,
                 this.outputPartitioningScheme,
+                this.outputSkewedBucketCount,
                 this.statsAndCosts,
                 activeCatalogs,
                 this.languageFunctions,
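withOutputPartitioning replaces the old withBucketToPartition so that the bucket-to-partition mapping and the skewed bucket count can be stamped onto a fragment in a single copy once task placement is known. A minimal sketch of that copy-on-write shape, using a record with hypothetical field names rather than the real PlanFragment:

    import java.util.Optional;
    import java.util.OptionalInt;

    record FragmentOutput(Optional<int[]> bucketToPartition, OptionalInt skewedBucketCount)
    {
        // all other fragment state would be copied verbatim; only the output partitioning changes
        FragmentOutput withOutputPartitioning(Optional<int[]> bucketToPartition, OptionalInt skewedBucketCount)
        {
            return new FragmentOutput(bucketToPartition, skewedBucketCount);
        }

        public static void main(String[] args)
        {
            FragmentOutput planned = new FragmentOutput(Optional.empty(), OptionalInt.empty());
            // scheduler side: fill in placement-dependent details once tasks are assigned
            FragmentOutput scheduled = planned.withOutputPartitioning(Optional.of(new int[] {0, 0, 1, 1}), OptionalInt.of(4));
            System.out.println(scheduled);
        }
    }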
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java
index 3d4eac0c0292..e9c5efa78ca8 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java
@@ -65,6 +65,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.stream.Stream;
@@ -226,7 +227,9 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub
                         outputPartitioningScheme.getOutputLayout(),
                         outputPartitioningScheme.isReplicateNullsAndAny(),
                         outputPartitioningScheme.getBucketToPartition(),
+                        outputPartitioningScheme.getBucketCount(),
                         outputPartitioningScheme.getPartitionCount()),
+                OptionalInt.empty(),
                 fragment.getStatsAndCosts(),
                 fragment.getActiveCatalogs(),
                 fragment.getLanguageFunctions(),
@@ -294,6 +297,7 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan
                 properties.getPartitionCount(),
                 schedulingOrder,
                 properties.getPartitioningScheme(),
+                OptionalInt.empty(),
                 statsAndCosts.getForSubplan(root),
                 activeCatalogs,
                 languageFunctions,
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
index f4f52fdd410e..295c6f8bf641 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/AddExchangesBelowPartialAggregationOverGroupIdRuleSet.java
@@ -301,6 +301,7 @@ protected Optional transform(AggregationNode aggregation, GroupIdNode
                         source.getOutputSymbols(),
                         false,
                         Optional.empty(),
+                        Optional.empty(),
                         // It's fine to reuse partitionCount since that is computed by considering all the expanding nodes and table scans in a query
                         partitionCount));
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java
index e0265b6d55c2..49ccacbf390a 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PruneExchangeColumns.java
@@ -104,6 +104,7 @@ protected Optional pushDownProjectOff(Context context, ExchangeNode ex
                 newOutputs.build(),
                 exchangeNode.getPartitioningScheme().isReplicateNullsAndAny(),
                 exchangeNode.getPartitioningScheme().getBucketToPartition(),
+                exchangeNode.getPartitioningScheme().getBucketCount(),
                 exchangeNode.getPartitioningScheme().getPartitionCount());

         return Optional.of(new ExchangeNode(
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java
index 7c7346f8e928..60ac0def9714 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushPartialAggregationThroughExchange.java
@@ -176,6 +176,7 @@ private PlanNode pushPartial(AggregationNode aggregation, ExchangeNode exchange,
                 aggregation.getOutputSymbols(),
                 exchange.getPartitioningScheme().isReplicateNullsAndAny(),
                 exchange.getPartitioningScheme().getBucketToPartition(),
+                exchange.getPartitioningScheme().getBucketCount(),
                 exchange.getPartitioningScheme().getPartitionCount());

         return new ExchangeNode(
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java
index 19b69c33794e..ecc4ce3dfc1d 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushProjectionThroughExchange.java
@@ -160,6 +160,7 @@ public Result apply(ProjectNode project, Captures captures, Context context)
                 outputBuilder.build(),
                 exchange.getPartitioningScheme().isReplicateNullsAndAny(),
                 exchange.getPartitioningScheme().getBucketToPartition(),
+                exchange.getPartitioningScheme().getBucketCount(),
                 exchange.getPartitioningScheme().getPartitionCount());

         PlanNode result = new ExchangeNode(
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java
index 6bbdbfb4122f..f8f9fb88db87 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushRemoteExchangeThroughAssignUniqueId.java
@@ -80,6 +80,7 @@ public Result apply(ExchangeNode node, Captures captures, Context context)
                         removeSymbol(partitioningScheme.getOutputLayout(), assignUniqueId.getIdColumn()),
                         partitioningScheme.isReplicateNullsAndAny(),
                         partitioningScheme.getBucketToPartition(),
+                        partitioningScheme.getBucketCount(),
                         partitioningScheme.getPartitionCount()),
                 ImmutableList.of(assignUniqueId.getSource()),
                 ImmutableList.of(removeSymbol(getOnlyElement(node.getInputs()), assignUniqueId.getIdColumn())),
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
index 7d5e8b81488a..3c42ce6c9dbb 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java
@@ -793,10 +793,10 @@ private PlanWithProperties getWriterPlanWithProperties(Optional sourceLayo
                 mapAndDistinct(sourceLayout),
                 scheme.isReplicateNullsAndAny(),
                 scheme.getBucketToPartition(),
+                scheme.getBucketCount(),
                 scheme.getPartitionCount());
     }
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java
index f23b5e4086b7..1aaa95842f3c 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/ExchangeNode.java
@@ -132,6 +132,7 @@ public static ExchangeNode partitionedExchange(PlanNodeId id, Scope scope, PlanN
                         child.getOutputSymbols(),
                         replicateNullsAndAny,
                         Optional.empty(),
+                        Optional.empty(),
                         Optional.empty()));
     }
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java
index 4b0bc719aa12..ddcd15ed248b 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java
@@ -144,6 +144,7 @@
 import java.util.Map.Entry;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Stream;
@@ -643,6 +644,7 @@ public static String graphvizLogicalPlan(PlanNode plan)
                 Optional.empty(),
                 ImmutableList.of(plan.getId()),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputSymbols()),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java
index 6284732a778c..654b40d0cf8c 100644
--- a/core/trino-main/src/main/java/io/trino/testing/PlanTester.java
+++ b/core/trino-main/src/main/java/io/trino/testing/PlanTester.java
@@ -177,6 +177,7 @@
 import io.trino.sql.planner.LogicalPlanner;
 import io.trino.sql.planner.NodePartitioningManager;
 import io.trino.sql.planner.OptimizerConfig;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.Plan;
 import io.trino.sql.planner.PlanFragmenter;
 import io.trino.sql.planner.PlanNodeIdAllocator;
@@ -291,6 +292,7 @@ public class PlanTester
     private final PageSourceManager pageSourceManager;
     private final IndexManager indexManager;
     private final NodePartitioningManager nodePartitioningManager;
+    private final PartitionFunctionProvider partitionFunctionProvider;
     private final PageSinkManager pageSinkManager;
     private final TransactionManager transactionManager;
     private final SessionPropertyManager sessionPropertyManager;
@@ -409,7 +411,8 @@ private PlanTester(Session defaultSession, int nodeCountForStats)
         this.indexManager = new IndexManager(createIndexProvider(catalogManager));
         NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(nodeManager, nodeSchedulerConfig, new NodeTaskMap(finalizerService)));
         this.sessionPropertyManager = createSessionPropertyManager(catalogManager, taskManagerConfig, optimizerConfig);
-        this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler, typeOperators, createNodePartitioningProvider(catalogManager));
+        this.nodePartitioningManager = new NodePartitioningManager(nodeScheduler, createNodePartitioningProvider(catalogManager));
+        this.partitionFunctionProvider = new PartitionFunctionProvider(typeOperators, createNodePartitioningProvider(catalogManager));
         TableProceduresRegistry tableProceduresRegistry = new TableProceduresRegistry(createTableProceduresProvider(catalogManager));
         FunctionManager functionManager = new FunctionManager(createFunctionProvider(catalogManager), globalFunctionCatalog, languageFunctionManager);
         this.schemaPropertyManager = createSchemaPropertyManager(catalogManager);
@@ -755,7 +758,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out
                 Optional.empty(),
                 pageSourceManager,
                 indexManager,
-                nodePartitioningManager,
+                partitionFunctionProvider,
                 pageSinkManager,
                 (_, _, _, _, _, _) -> {
                     throw new UnsupportedOperationException();
diff --git a/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java b/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java
index d49d7601e276..334b70765b64 100644
--- a/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java
+++ b/core/trino-main/src/test/java/io/trino/cost/TestRemoteSourceStatsRule.java
@@ -28,6 +28,7 @@
 import org.junit.jupiter.api.Test;

 import java.util.Optional;
+import java.util.OptionalInt;

 import static io.trino.execution.scheduler.faulttolerant.OutputStatsEstimator.OutputStatsEstimateResult;
 import static io.trino.operator.RetryPolicy.TASK;
@@ -139,6 +140,7 @@ private PlanFragment createPlanFragment(StatsAndCosts statsAndCosts)
                 Optional.empty(),
                 ImmutableList.of(new PlanNodeId("plan_id")),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(new Symbol(BIGINT, "col_c"))),
+                OptionalInt.empty(),
                 statsAndCosts,
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java
index bfcdc05e2432..adc7371241ab 100644
--- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java
+++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java
@@ -67,6 +67,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledExecutorService;
@@ -125,6 +126,7 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L
                 Optional.empty(),
                 ImmutableList.of(sourceId),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java
index 3d069b66f2f4..f637f4e0ba7e 100644
--- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java
+++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java
@@ -28,10 +28,6 @@
 import io.trino.exchange.ExchangeManagerRegistry;
 import io.trino.execution.BaseTestSqlTaskManager.MockDirectExchangeClientSupplier;
 import io.trino.execution.buffer.OutputBuffers;
-import io.trino.execution.scheduler.NodeScheduler;
-import io.trino.execution.scheduler.NodeSchedulerConfig;
-import io.trino.execution.scheduler.UniformNodeSelectorFactory;
-import io.trino.metadata.InMemoryNodeManager;
 import io.trino.metadata.Split;
 import io.trino.operator.FlatHashStrategyCompiler;
 import io.trino.operator.PagesIndex;
@@ -53,7 +49,7 @@
 import io.trino.sql.gen.columnar.ColumnarFilterCompiler;
 import io.trino.sql.planner.CompilerConfig;
 import io.trino.sql.planner.LocalExecutionPlanner;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.Partitioning;
 import io.trino.sql.planner.PartitioningScheme;
 import io.trino.sql.planner.PlanFragment;
@@ -66,10 +62,10 @@
 import io.trino.testing.TestingMetadata.TestingColumnHandle;
 import io.trino.testing.TestingSplit;
 import io.trino.type.BlockTypeOperators;
-import io.trino.util.FinalizerService;

 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.Set;

 import static io.airlift.tracing.Tracing.noopTracer;
@@ -112,6 +108,7 @@ private TaskTestUtils() {}
             ImmutableList.of(TABLE_SCAN_NODE_ID),
             new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(SYMBOL))
                     .withBucketToPartition(Optional.of(new int[1])),
+            OptionalInt.empty(),
             StatsAndCosts.empty(),
             ImmutableList.of(),
             ImmutableMap.of(),
@@ -139,6 +136,7 @@ private TaskTestUtils() {}
             ImmutableList.of(TABLE_SCAN_NODE_ID),
             new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(SYMBOL))
                     .withBucketToPartition(Optional.of(new int[1])),
+            OptionalInt.empty(),
             StatsAndCosts.empty(),
             ImmutableList.of(),
             ImmutableMap.of(),
@@ -148,16 +146,8 @@ public static LocalExecutionPlanner createTestingPlanner()
     {
         PageSourceManager pageSourceManager = new PageSourceManager(CatalogServiceProvider.singleton(CATALOG_HANDLE, new TestingPageSourceProvider()));

-        // we don't start the finalizer so nothing will be collected, which is ok for a test
-        FinalizerService finalizerService = new FinalizerService();
-
         BlockTypeOperators blockTypeOperators = new BlockTypeOperators(PLANNER_CONTEXT.getTypeOperators());
-        NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(
-                new InMemoryNodeManager(),
-                new NodeSchedulerConfig().setIncludeCoordinator(true),
-                new NodeTaskMap(finalizerService)));
-        NodePartitioningManager nodePartitioningManager = new NodePartitioningManager(
-                nodeScheduler,
+        PartitionFunctionProvider partitionFunctionProvider = new PartitionFunctionProvider(
                 PLANNER_CONTEXT.getTypeOperators(),
                 CatalogServiceProvider.fail());
@@ -169,7 +159,7 @@ public static LocalExecutionPlanner createTestingPlanner()
                 Optional.empty(),
                 pageSourceManager,
                 new IndexManager(CatalogServiceProvider.fail()),
-                nodePartitioningManager,
+                partitionFunctionProvider,
                 new PageSinkManager(CatalogServiceProvider.fail()),
                 new MockDirectExchangeClientSupplier(),
                 new ExpressionCompiler(cursorProcessorCompiler, pageFunctionCompiler, columnarFilterCompiler),
diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
index ea4715360693..ac0517a2e20e 100644
--- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
+++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java
@@ -49,6 +49,7 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
@@ -125,7 +126,8 @@ private void testFinalStageInfoInternal()
                 executor,
                 noopTracer(),
                 Span.getInvalid(),
-                new SplitSchedulerStats());
+                new SplitSchedulerStats(),
+                (_, _) -> Optional.empty());

         // add listener that fetches stage info when the final status is available
         SettableFuture finalStageInfo = SettableFuture.create();
@@ -153,6 +155,7 @@ private void testFinalStageInfoInternal()
                     i,
                     0,
                     Optional.empty(),
+                    OptionalInt.empty(),
                     PipelinedOutputBuffers.createInitial(ARBITRARY),
                     initialSplits,
                     ImmutableSet.of(),
@@ -244,6 +247,7 @@ private static PlanFragment createExchangePlanFragment()
                 Optional.empty(),
                 ImmutableList.of(planNode.getId()),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
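The new trailing constructor argument, (_, _) -> Optional.empty(), stubs out the bucket-count provider that stage creation now requires; returning empty simply opts the test out of connector bucketing. A sketch of the shape such a provider has (the interface below is a stand-in with erased parameter types, not the real Trino one):

    import java.util.Optional;

    @FunctionalInterface
    interface BucketCountProvider
    {
        // the real provider takes a Session and a PartitioningHandle; both are erased to Object here
        Optional<Integer> getBucketCount(Object session, Object partitioningHandle);
    }

    class BucketCountProviderDemo
    {
        public static void main(String[] args)
        {
            BucketCountProvider noBuckets = (session, handle) -> Optional.empty();
            System.out.println(noBuckets.getBucketCount(null, null)); // prints Optional.empty
        }
    }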
diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java
index f8012796945b..1ba288f919a0 100644
--- a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java
+++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java
@@ -46,6 +46,7 @@
 import java.time.Instant;
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.ScheduledThreadPoolExecutor;
@@ -398,6 +399,7 @@ private static PlanFragment createValuesPlan()
                 Optional.empty(),
                 ImmutableList.of(valuesNodeId),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java
index 258e6f47752f..a8771672537a 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestMultiSourcePartitionedScheduler.java
@@ -79,6 +79,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.LinkedBlockingQueue;
@@ -581,6 +582,7 @@ private PlanFragment createFragment(TableHandle firstTableHandle, TableHandle se
                 Optional.empty(),
                 ImmutableList.of(TABLE_SCAN_1_NODE_ID, TABLE_SCAN_2_NODE_ID),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
@@ -619,7 +621,8 @@ TABLE_SCAN_2_NODE_ID, new TableInfo(Optional.of("test"), new QualifiedObjectName
                 queryExecutor,
                 noopTracer(),
                 Span.getInvalid(),
-                new SplitSchedulerStats());
+                new SplitSchedulerStats(),
+                (_, _) -> Optional.empty());

         ImmutableMap.Builder outputBuffers = ImmutableMap.builder();
         outputBuffers.put(fragment.getId(), new PartitionedPipelinedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 1));
         fragment.getRemoteSourceNodes().stream()
@@ -632,6 +635,7 @@ TABLE_SCAN_2_NODE_ID, new TableInfo(Optional.of("test"), new QualifiedObjectName
                 new NoOpFailureDetector(),
                 queryExecutor,
                 Optional.of(new int[] {0}),
+                OptionalInt.empty(),
                 0);
     }
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java
index 010989f84784..1d8c2f50419a 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java
@@ -50,6 +50,7 @@
 import java.net.URI;
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.concurrent.atomic.AtomicReference;
@@ -405,6 +406,7 @@ private static PlanFragment createFragment()
                 Optional.empty(),
                 ImmutableList.of(TABLE_SCAN_NODE_ID),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java
index 7458a34a84a9..baeeb450f88d 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSchedulingUtils.java
@@ -42,6 +42,7 @@
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;

 import static com.google.common.collect.ImmutableList.toImmutableList;
 import static io.trino.spi.type.VarcharType.VARCHAR;
@@ -361,6 +362,7 @@ private static SubPlan createSubPlan(String fragmentId, PlanNode plan, List
                 Optional.empty());

         ImmutableMap.Builder outputBuffers = ImmutableMap.builder();
         outputBuffers.put(fragment.getId(), new PartitionedPipelinedOutputBufferManager(FIXED_HASH_DISTRIBUTION, 1));
         fragment.getRemoteSourceNodes().stream()
@@ -776,6 +779,7 @@ private StageExecution createStageExecution(PlanFragment fragment, NodeTaskMap n
                 new NoOpFailureDetector(),
                 queryExecutor,
                 Optional.of(new int[] {0}),
+                OptionalInt.empty(),
                 0);
     }
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java
index 7eeee548fbbb..476ad1b42ab4 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestExponentialGrowthPartitionMemoryEstimator.java
@@ -37,6 +37,7 @@
 import java.net.URI;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.function.Function;

 import static io.airlift.units.DataSize.Unit.GIGABYTE;
@@ -267,6 +268,7 @@ private static PlanFragment getPlanFragment(PartitioningHandle partitioningHandl
                 Optional.empty(),
                 ImmutableList.of(),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java
index 57520afd8a46..f6f4fa0be040 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/faulttolerant/TestNoMemoryAwarePartitionMemoryEstimator.java
@@ -46,6 +46,7 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
+import java.util.OptionalInt;
 import java.util.OptionalLong;
 import java.util.function.Function;
 import java.util.stream.Stream;
@@ -200,6 +201,7 @@ private static PlanFragment getParentFragment(PlanFragment... childFragments)
                 Optional.empty(),
                 ImmutableList.of(),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
@@ -246,6 +248,7 @@ private static PlanFragment tableScanPlanFragment(String fragmentId, TableHandle
                 Optional.empty(),
                 ImmutableList.of(),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of()),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java
index c4703c992405..3e2812f8876c 100644
--- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java
+++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java
@@ -34,6 +34,7 @@
 import io.trino.testing.TestingMetadata;

 import java.util.Optional;
+import java.util.OptionalInt;

 import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION;
 import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION;
@@ -192,6 +193,7 @@ private static PlanFragment createFragment(PlanNode planNode)
                 Optional.empty(),
                 ImmutableList.of(planNode.getId()),
                 new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()),
+                OptionalInt.empty(),
                 StatsAndCosts.empty(),
                 ImmutableList.of(),
                 ImmutableMap.of(),
diff --git a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
index a1c20c828588..224c0f97358d 100644
--- a/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
+++ b/core/trino-main/src/test/java/io/trino/operator/exchange/TestLocalExchange.java
@@ -19,11 +19,6 @@
 import io.trino.SequencePageBuilder;
 import io.trino.Session;
 import io.trino.block.BlockAssertions;
-import io.trino.execution.NodeTaskMap;
-import io.trino.execution.scheduler.NodeScheduler;
-import io.trino.execution.scheduler.NodeSchedulerConfig;
-import io.trino.execution.scheduler.UniformNodeSelectorFactory;
-import io.trino.metadata.InMemoryNodeManager;
 import io.trino.operator.PageAssertions;
 import io.trino.operator.exchange.LocalExchange.LocalExchangeSinkFactory;
 import io.trino.spi.Page;
@@ -37,10 +32,9 @@
 import io.trino.spi.connector.ConnectorTransactionHandle;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.PartitioningHandle;
 import io.trino.testing.TestingTransactionHandle;
-import io.trino.util.FinalizerService;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
@@ -89,20 +83,16 @@ public class TestLocalExchange
     private static final Session SESSION = testSessionBuilder().build();
     private static final DataSize WRITER_SCALING_MIN_DATA_PROCESSED = DataSize.of(32, MEGABYTE);
     private static final Supplier TOTAL_MEMORY_USED = () -> 0L;
+    private static final Optional<Integer> BUCKET_COUNT = Optional.of(8);

     private final ConcurrentMap partitionManagers = new ConcurrentHashMap<>();
-    private NodePartitioningManager nodePartitioningManager;
+    private PartitionFunctionProvider functionProvider;
     private final PartitioningHandle customScalingPartitioningHandle = getCustomScalingPartitioningHandle();

     @BeforeEach
     public void setUp()
     {
-        NodeScheduler nodeScheduler = new NodeScheduler(new UniformNodeSelectorFactory(
-                new InMemoryNodeManager(),
-                new NodeSchedulerConfig().setIncludeCoordinator(true),
-                new NodeTaskMap(new FinalizerService())));
-        nodePartitioningManager = new NodePartitioningManager(
-                nodeScheduler,
+        functionProvider = new PartitionFunctionProvider(
                 new TypeOperators(),
                 catalogHandle -> {
                     ConnectorNodePartitioningProvider result = partitionManagers.get(catalogHandle);
@@ -115,10 +105,11 @@ public void setUp()
     public void testGatherSingleWriter()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 8,
                 SINGLE_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(retainedSizeOfPages(99)),
@@ -188,10 +179,11 @@ public void testGatherSingleWriter()
     public void testRandom()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 2,
                 FIXED_ARBITRARY_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
@@ -237,10 +229,11 @@ public void testRandom()
     public void testScaleWriter()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 3,
                 SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(retainedSizeOfPages(4)),
@@ -296,10 +289,11 @@ public void testScaleWriter()
     public void testNoWriterScalingWhenOnlyBufferSizeLimitIsExceeded()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 3,
                 SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(retainedSizeOfPages(4)),
@@ -343,13 +337,14 @@ public void testScalingWithTwoDifferentPartitions()
     private void testScalingWithTwoDifferentPartitions(PartitioningHandle partitioningHandle)
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "256MB")
                         .build(),
                 4,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.ofBytes(retainedSizeOfPages(2)),
@@ -452,12 +447,13 @@ public void testScaledWriterRoundRobinExchangerWhenTotalMemoryUsedIsGreaterThanL
     {
         AtomicLong totalMemoryUsed = new AtomicLong();
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "11MB")
                         .build(),
                 3,
                 SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(retainedSizeOfPages(4)),
@@ -497,10 +493,11 @@ public void testScaledWriterRoundRobinExchangerWhenTotalMemoryUsedIsGreaterThanL
     public void testNoWriterScalingWhenOnlyWriterScalingMinDataProcessedLimitIsExceeded()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 3,
                 SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(retainedSizeOfPages(20)),
@@ -547,12 +544,13 @@ public void testScalingForSkewedWriters()
     private void testScalingForSkewedWriters(PartitioningHandle partitioningHandle)
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .build(),
                 4,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.ofBytes(retainedSizeOfPages(2)),
@@ -642,12 +640,13 @@ public void testNoScalingWhenDataWrittenIsLessThanMinFileSize()
     private void testNoScalingWhenDataWrittenIsLessThanMinFileSize(PartitioningHandle partitioningHandle)
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .build(),
                 4,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.ofBytes(retainedSizeOfPages(2)),
@@ -711,12 +710,13 @@ public void testNoScalingWhenBufferUtilizationIsLessThanLimit()
     private void testNoScalingWhenBufferUtilizationIsLessThanLimit(PartitioningHandle partitioningHandle)
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .build(),
                 4,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.of(50, MEGABYTE),
@@ -781,13 +781,14 @@ private void testNoScalingWhenTotalMemoryUsedIsGreaterThanLimit(PartitioningHand
     {
         AtomicLong totalMemoryUsed = new AtomicLong();
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "20MB")
                         .build(),
                 4,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.ofBytes(retainedSizeOfPages(2)),
@@ -867,13 +868,14 @@ private void testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(PartitioningHandl
     {
         AtomicLong totalMemoryUsed = new AtomicLong();
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .setSystemProperty(QUERY_MAX_MEMORY_PER_NODE, "20MB")
                         .build(),
                 4,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.ofBytes(retainedSizeOfPages(2)),
@@ -960,12 +962,13 @@ private void testDoNotUpdateScalingStateWhenMemoryIsAboveLimit(PartitioningHandl
     public void testNoScalingWhenNoWriterSkewness()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 testSessionBuilder()
                         .setSystemProperty(SKEWED_PARTITION_MIN_DATA_PROCESSED_REBALANCE_THRESHOLD, "20kB")
                         .build(),
                 2,
                 SCALED_WRITER_HASH_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 DataSize.ofBytes(retainedSizeOfPages(2)),
@@ -1009,10 +1012,11 @@ public void testNoScalingWhenNoWriterSkewness()
     public void testPassthrough()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 2,
                 FIXED_PASSTHROUGH_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(retainedSizeOfPages(1)),
@@ -1076,10 +1080,11 @@ public void testPassthrough()
     public void testPartition()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 2,
                 FIXED_HASH_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(0),
                 TYPES,
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
@@ -1172,10 +1177,11 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa
                 Optional.of(TestingTransactionHandle.create()),
                 connectorPartitioningHandle);
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 2,
                 partitioningHandle,
+                BUCKET_COUNT,
                 ImmutableList.of(1),
                 ImmutableList.of(BIGINT),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
@@ -1223,10 +1229,11 @@ public BucketFunction getBucketFunction(ConnectorTransactionHandle transactionHa
     public void writeUnblockWhenAllReadersFinish()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 2,
                 FIXED_ARBITRARY_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 LOCAL_EXCHANGE_MAX_BUFFERED_BYTES,
@@ -1270,10 +1277,11 @@ public void writeUnblockWhenAllReadersFinish()
     public void writeUnblockWhenAllReadersFinishAndPagesConsumed()
     {
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                functionProvider,
                 SESSION,
                 2,
                 FIXED_PASSTHROUGH_DISTRIBUTION,
+                BUCKET_COUNT,
                 ImmutableList.of(),
                 ImmutableList.of(),
                 DataSize.ofBytes(2),
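Every LocalExchange in the test above now receives its bucket count explicitly (Optional.of(8) here; the join-test utilities below pass Optional.empty() for the system FIXED_HASH_DISTRIBUTION), since the exchange no longer has a NodePartitioningManager to derive it from. A toy illustration of why the caller supplies it, assuming a bucket-then-partition folding scheme that this patch does not itself spell out:

    import java.util.Optional;
    import java.util.function.IntUnaryOperator;

    final class ToyPartitionFunction
    {
        static IntUnaryOperator create(int partitionCount, Optional<Integer> bucketCount)
        {
            // with an explicit bucket count, hash values land in a bucket first and
            // buckets are folded onto partitions; otherwise hash straight to a partition
            return bucketCount
                    .map(buckets -> (IntUnaryOperator) value -> (value % buckets) % partitionCount)
                    .orElse(value -> value % partitionCount);
        }

        public static void main(String[] args)
        {
            IntUnaryOperator function = create(2, Optional.of(8));
            System.out.println(function.applyAsInt(13)); // bucket 5 -> partition 1
        }
    }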
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java
index d4c63a3a32e3..c24c2b8a7a59 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/JoinTestUtils.java
@@ -39,7 +39,7 @@
 import io.trino.spiller.SingleStreamSpiller;
 import io.trino.spiller.SingleStreamSpillerFactory;
 import io.trino.sql.gen.JoinFilterFunctionCompiler;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.plan.PlanNodeId;

 import java.util.ArrayList;
@@ -120,7 +120,7 @@ public static void instantiateBuildDrivers(BuildSideSetup buildSideSetup, TaskCo
     }

     public static BuildSideSetup setupBuildSide(
-            NodePartitioningManager nodePartitioningManager,
+            PartitionFunctionProvider partitionFunctionProvider,
             boolean parallelBuild,
             TaskContext taskContext,
             RowPagesBuilder buildPages,
@@ -138,10 +138,11 @@ public static BuildSideSetup setupBuildSide(
                 .map(types::get)
                 .collect(toImmutableList());
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                partitionFunctionProvider,
                 taskContext.getSession(),
                 partitionCount,
                 FIXED_HASH_DISTRIBUTION,
+                Optional.empty(),
                 hashChannels,
                 hashChannelTypes,
                 DataSize.of(32, DataSize.Unit.MEGABYTE),
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java
index b871b7eea072..6c63fb0644f6 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/TestHashJoinOperator.java
@@ -24,15 +24,10 @@
 import io.trino.ExceededMemoryLimitException;
 import io.trino.RowPagesBuilder;
 import io.trino.connector.CatalogServiceProvider;
-import io.trino.execution.NodeTaskMap;
 import io.trino.execution.StageId;
 import io.trino.execution.TaskId;
 import io.trino.execution.TaskStateMachine;
-import io.trino.execution.scheduler.NodeScheduler;
-import io.trino.execution.scheduler.NodeSchedulerConfig;
-import io.trino.execution.scheduler.UniformNodeSelectorFactory;
 import io.trino.memory.context.LocalMemoryContext;
-import io.trino.metadata.InMemoryNodeManager;
 import io.trino.operator.Driver;
 import io.trino.operator.DriverContext;
 import io.trino.operator.Operator;
@@ -57,11 +52,10 @@
 import io.trino.spiller.GenericPartitioningSpillerFactory;
 import io.trino.spiller.PartitioningSpillerFactory;
 import io.trino.spiller.SingleStreamSpillerFactory;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.plan.PlanNodeId;
 import io.trino.testing.MaterializedResult;
 import io.trino.testing.TestingTaskContext;
-import io.trino.util.FinalizerService;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
@@ -127,11 +121,7 @@ public class TestHashJoinOperator
     private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
     private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));

-    private final NodePartitioningManager nodePartitioningManager = new NodePartitioningManager(
-            new NodeScheduler(new UniformNodeSelectorFactory(
-                    new InMemoryNodeManager(),
-                    new NodeSchedulerConfig().setIncludeCoordinator(true),
-                    new NodeTaskMap(new FinalizerService()))),
+    private final PartitionFunctionProvider partitionFunctionProvider = new PartitionFunctionProvider(
             TYPE_OPERATORS,
             CatalogServiceProvider.fail());
@@ -156,7 +146,7 @@ private void testInnerJoin(boolean parallelBuild)
         // build factory
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT))
                 .addSequencePage(10, 20, 30, 40);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -196,7 +186,7 @@ public void testInnerJoinWithRunLengthEncodedProbe()
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR))
                 .addSequencePage(10, 20)
                 .addSequencePage(10, 21);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, false, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, false, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -246,7 +236,7 @@ public void testYield()
         int entries = 40;
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(BIGINT))
                 .addSequencePage(entries, 42);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, true, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, true, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe matching the above 40 entries
@@ -391,7 +381,7 @@ private void innerJoinWithSpill(List whenSpill, SingleStreamSpillerFa
                 .addSequencePage(4, 30, 300)
                 .addSequencePage(4, 40, 400);

-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, true, taskContext, buildPages, Optional.of(filterFunction), true, buildSpillerFactory);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, true, taskContext, buildPages, Optional.of(filterFunction), true, buildSpillerFactory);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -562,7 +552,7 @@ public void testBuildGracefulSpill()
         DummySpillerFactory buildSpillerFactory = new DummySpillerFactory();

-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, true, taskContext, buildPages, Optional.empty(), true, buildSpillerFactory);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, true, taskContext, buildPages, Optional.empty(), true, buildSpillerFactory);
         instantiateBuildDrivers(buildSideSetup, taskContext);

         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();
@@ -611,7 +601,7 @@ private void testInnerJoinWithNullProbe(boolean parallelBuild)
                 .row("a")
                 .row("b")
                 .row("c");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -656,7 +646,7 @@ private void testInnerJoinWithOutputSingleMatch(boolean parallelBuild)
                 .row("a")
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -701,7 +691,7 @@ private void testInnerJoinWithNullBuild(boolean parallelBuild)
                 .row((String) null)
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -747,7 +737,7 @@ private void testInnerJoinWithNullOnBothSides(boolean parallelBuild)
                 .row((String) null)
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -790,7 +780,7 @@ private void testProbeOuterJoin(boolean parallelBuild)
         List buildTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT))
                 .addSequencePage(10, 20, 30, 40);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -845,7 +835,7 @@ private void testProbeOuterJoinWithFilterFunction(boolean parallelBuild)
         List buildTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT))
                 .addSequencePage(10, 20, 30, 40);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -899,7 +889,7 @@ private void testOuterJoinWithNullProbe(boolean parallelBuild)
                 .row("a")
                 .row("b")
                 .row("c");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -950,7 +940,7 @@ private void testOuterJoinWithNullProbeAndFilterFunction(boolean parallelBuild)
                 .row("a")
                 .row("b")
                 .row("c");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1000,7 +990,7 @@ private void testOuterJoinWithNullBuild(boolean parallelBuild)
                 .row((String) null)
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1051,7 +1041,7 @@ private void testOuterJoinWithNullBuildAndFilterFunction(boolean parallelBuild)
                 .row((String) null)
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1097,7 +1087,7 @@ private void testOuterJoinWithNullOnBothSides(boolean parallelBuild)
                 .row((String) null)
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1149,7 +1139,7 @@ private void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelB
                 .row((String) null)
                 .row("a")
                 .row("b");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1192,7 +1182,7 @@ private void testMemoryLimit(boolean parallelBuild)
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT))
                 .addSequencePage(10, 20, 30, 40);

-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         instantiateBuildDrivers(buildSideSetup, taskContext);

         assertThatThrownBy(() -> buildLookupSource(executor, buildSideSetup))
@@ -1410,7 +1400,7 @@ private void testInnerJoinWithEmptyLookupSource(boolean parallelBuild)
         // build factory
         List buildTypes = ImmutableList.of(VARCHAR);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1453,7 +1443,7 @@ private void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild)
         // build factory
         List buildTypes = ImmutableList.of(VARCHAR);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1496,7 +1486,7 @@ private void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild)
         // build factory
         List buildTypes = ImmutableList.of(VARCHAR);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1548,7 +1538,7 @@ private void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild)
         // build factory
         List buildTypes = ImmutableList.of(VARCHAR);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1604,7 +1594,7 @@ private void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallel
                 .row("b")
                 .row((String) null)
                 .row("c");
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -1729,7 +1719,7 @@ private OperatorFactory createJoinOperatorFactoryWithBlockingLookupSource(TaskCo
         // build factory
         List buildTypes = ImmutableList.of(VARCHAR);
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), false, SINGLE_STREAM_SPILLER_FACTORY);
         JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java
index 04d9d128d914..c72bb458c4ba 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/JoinTestUtils.java
@@ -36,7 +36,7 @@
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
 import io.trino.sql.gen.JoinFilterFunctionCompiler;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.plan.PlanNodeId;

 import java.util.ArrayList;
@@ -109,17 +109,17 @@ public static void instantiateBuildDrivers(BuildSideSetup buildSideSetup, TaskCo
     }

     public static BuildSideSetup setupBuildSide(
-            NodePartitioningManager nodePartitioningManager,
+            PartitionFunctionProvider partitionFunctionProvider,
             boolean parallelBuild,
             TaskContext taskContext,
             RowPagesBuilder buildPages,
             Optional filterFunction)
     {
-        return setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, filterFunction, true);
+        return setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, filterFunction, true);
     }

     public static BuildSideSetup setupBuildSide(
-            NodePartitioningManager nodePartitioningManager,
+            PartitionFunctionProvider partitionFunctionProvider,
             boolean parallelBuild,
             TaskContext taskContext,
             RowPagesBuilder buildPages,
@@ -136,10 +136,11 @@ public static BuildSideSetup setupBuildSide(
                 .map(types::get)
                 .collect(toImmutableList());
         LocalExchange localExchange = new LocalExchange(
-                nodePartitioningManager,
+                partitionFunctionProvider,
                 taskContext.getSession(),
                 partitionCount,
                 FIXED_HASH_DISTRIBUTION,
+                Optional.empty(),
                 hashChannels,
                 hashChannelTypes,
                 DataSize.of(32, DataSize.Unit.MEGABYTE),
diff --git a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java
index 69a5a6aec020..31c1e72b4001 100644
--- a/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java
+++ b/core/trino-main/src/test/java/io/trino/operator/join/unspilled/TestHashJoinOperator.java
@@ -22,11 +22,6 @@
 import io.trino.ExceededMemoryLimitException;
 import io.trino.RowPagesBuilder;
 import io.trino.connector.CatalogServiceProvider;
-import io.trino.execution.NodeTaskMap;
-import io.trino.execution.scheduler.NodeScheduler;
-import io.trino.execution.scheduler.NodeSchedulerConfig;
-import io.trino.execution.scheduler.UniformNodeSelectorFactory;
-import io.trino.metadata.InMemoryNodeManager;
 import io.trino.operator.DriverContext;
 import io.trino.operator.JoinOperatorType;
 import io.trino.operator.Operator;
@@ -43,11 +38,10 @@
 import io.trino.spi.block.VariableWidthBlock;
 import io.trino.spi.type.Type;
 import io.trino.spi.type.TypeOperators;
-import io.trino.sql.planner.NodePartitioningManager;
+import io.trino.sql.planner.PartitionFunctionProvider;
 import io.trino.sql.planner.plan.PlanNodeId;
 import io.trino.testing.MaterializedResult;
 import io.trino.testing.TestingTaskContext;
-import io.trino.util.FinalizerService;
 import org.junit.jupiter.api.AfterAll;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.TestInstance;
@@ -93,11 +87,7 @@ public class TestHashJoinOperator
     private final ExecutorService executor = newCachedThreadPool(daemonThreadsNamed("test-executor-%s"));
     private final ScheduledExecutorService scheduledExecutor = newScheduledThreadPool(2, daemonThreadsNamed(getClass().getSimpleName() + "-scheduledExecutor-%s"));

-    private final NodePartitioningManager nodePartitioningManager = new NodePartitioningManager(
-            new NodeScheduler(new UniformNodeSelectorFactory(
-                    new InMemoryNodeManager(),
-                    new NodeSchedulerConfig().setIncludeCoordinator(true),
-                    new NodeTaskMap(new FinalizerService()))),
+    private final PartitionFunctionProvider partitionFunctionProvider = new PartitionFunctionProvider(
             TYPE_OPERATORS,
             CatalogServiceProvider.fail());
@@ -122,7 +112,7 @@ private void testInnerJoin(boolean parallelBuild)
         // build factory
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT))
                 .addSequencePage(10, 20, 30, 40);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty());
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty());
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -171,7 +161,7 @@ private void testInnerJoinWithRunLengthEncodedProbe(boolean withFilter, boolean
                 .row("20", 1L)
                 .row("21", 2L)
                 .row("21", 3L);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, false, taskContext, buildPages, Optional.empty(), singleBigintLookupSource);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, false, taskContext, buildPages, Optional.empty(), singleBigintLookupSource);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -254,7 +244,7 @@ private void testYield(boolean singleBigintLookupSource)
         int entries = 40;
         RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(BIGINT))
                 .addSequencePage(entries, 42);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, true, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, true, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe matching the above 40 entries
@@ -320,7 +310,7 @@ private void testInnerJoinWithNullProbe(boolean parallelBuild, boolean singleBig
                 .row(1L)
                 .row(2L)
                 .row(3L);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource);
         JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager();

         // probe factory
@@ -367,7 +357,7 @@ private void testInnerJoinWithOutputSingleMatch(boolean parallelBuild, boolean s
                 .row(1L)
                 .row(1L)
                 .row(2L);
-        BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource);
+        BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages,
Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -412,7 +402,7 @@ private void testInnerJoinWithNullBuild(boolean parallelBuild) .row((String) null) .row(1L) .row(2L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty()); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty()); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -458,7 +448,7 @@ private void testInnerJoinWithNullOnBothSides(boolean parallelBuild) .row((String) null) .row(1L) .row(2L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty()); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty()); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -501,7 +491,7 @@ private void testProbeOuterJoin(boolean parallelBuild) List buildTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT)) .addSequencePage(10, 20, 30, 40); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty()); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty()); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -556,7 +546,7 @@ private void testProbeOuterJoinWithFilterFunction(boolean parallelBuild) List buildTypes = ImmutableList.of(VARCHAR, BIGINT, BIGINT); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT)) .addSequencePage(10, 20, 30, 40); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction)); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction)); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -612,7 +602,7 @@ private void testOuterJoinWithNullProbe(boolean parallelBuild, boolean singleBig .row(1L) .row(2L) .row(3L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -665,7 +655,7 @@ private void testOuterJoinWithNullProbeAndFilterFunction(boolean parallelBuild, .row(1L) .row(2L) .row(3L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = 
buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -717,7 +707,7 @@ private void testOuterJoinWithNullBuild(boolean parallelBuild, boolean singleBig .row((String) null) .row(1L) .row(2L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -770,7 +760,7 @@ private void testOuterJoinWithNullBuildAndFilterFunction(boolean parallelBuild, .row((String) null) .row(1L) .row(2L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -818,7 +808,7 @@ private void testOuterJoinWithNullOnBothSides(boolean parallelBuild, boolean sin .row((String) null) .row(1L) .row(2L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -872,7 +862,7 @@ private void testOuterJoinWithNullOnBothSidesAndFilterFunction(boolean parallelB .row((String) null) .row(1L) .row(2L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.of(filterFunction), singleBigintLookupSource); JoinBridgeManager lookupSourceFactory = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -915,7 +905,7 @@ private void testMemoryLimit(boolean parallelBuild) RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), ImmutableList.of(VARCHAR, BIGINT, BIGINT)) .addSequencePage(10, 20, 30, 40); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty()); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty()); instantiateBuildDrivers(buildSideSetup, taskContext); assertThatThrownBy(() -> buildLookupSource(executor, buildSideSetup)) @@ -939,7 +929,7 @@ private void testInnerJoinWithEmptyLookupSource(boolean parallelBuild, boolean s // build factory List buildTypes = ImmutableList.of(BIGINT); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager 
lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -982,7 +972,7 @@ private void testLookupOuterJoinWithEmptyLookupSource(boolean parallelBuild, boo // build factory List buildTypes = ImmutableList.of(BIGINT); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -1025,7 +1015,7 @@ private void testProbeOuterJoinWithEmptyLookupSource(boolean parallelBuild, bool // build factory List buildTypes = ImmutableList.of(BIGINT); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -1077,7 +1067,7 @@ private void testFullOuterJoinWithEmptyLookupSource(boolean parallelBuild, boole // build factory List buildTypes = ImmutableList.of(BIGINT); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -1133,7 +1123,7 @@ private void testInnerJoinWithNonEmptyLookupSourceAndEmptyProbe(boolean parallel .row(2L) .row((String) null) .row(3L); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty(), singleBigintLookupSource); JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); // probe factory @@ -1256,7 +1246,7 @@ private OperatorFactory createJoinOperatorFactoryWithBlockingLookupSource(TaskCo // build factory List buildTypes = ImmutableList.of(VARCHAR); RowPagesBuilder buildPages = rowPagesBuilder(Ints.asList(0), buildTypes); - BuildSideSetup buildSideSetup = setupBuildSide(nodePartitioningManager, parallelBuild, taskContext, buildPages, Optional.empty()); + BuildSideSetup buildSideSetup = setupBuildSide(partitionFunctionProvider, parallelBuild, taskContext, buildPages, Optional.empty()); JoinBridgeManager lookupSourceFactoryManager = buildSideSetup.getLookupSourceFactoryManager(); // probe factory diff --git a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java index b394053400e5..b5c2a10b19b7 100644 --- 
a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java +++ b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java @@ -54,6 +54,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.OptionalInt; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; @@ -1062,6 +1063,7 @@ private static PlanFragment createPlan( Optional.empty(), ImmutableList.of(tableScanNodeId), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), + OptionalInt.empty(), StatsAndCosts.empty(), ImmutableList.of(), ImmutableMap.of(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java index a1f1cc071e3c..a3243e3f8a06 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTopologicalOrderSubPlanVisitor.java @@ -35,6 +35,7 @@ import java.util.List; import java.util.Optional; +import java.util.OptionalInt; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.spi.type.VarcharType.VARCHAR; @@ -192,6 +193,7 @@ private static SubPlan createSubPlan(String fragmentId, PlanNode plan, List outp ImmutableList.copyOf(outputSymbols), false, Optional.empty(), + Optional.empty(), Optional.of(partitionCount))); } @@ -929,6 +930,7 @@ public ExchangeBuilder fixedArbitraryDistributionPartitioningScheme(List ImmutableList.copyOf(outputSymbols), false, Optional.empty(), + Optional.empty(), Optional.of(partitionCount))); }