diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java index 29327a968b02..ab8108ecfbf1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlStage.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlStage.java @@ -86,7 +86,7 @@ public static SqlStage createSqlStage( { requireNonNull(stageId, "stageId is null"); requireNonNull(fragment, "fragment is null"); - checkArgument(fragment.getPartitioningScheme().getBucketToPartition().isEmpty(), "bucket to partition is not expected to be set at this point"); + checkArgument(fragment.getOutputPartitioningScheme().getBucketToPartition().isEmpty(), "bucket to partition is not expected to be set at this point"); requireNonNull(tables, "tables is null"); requireNonNull(remoteTaskFactory, "remoteTaskFactory is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java index 65650abe4da0..dd89d6e1cf78 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java @@ -78,7 +78,7 @@ public SqlTaskExecution create( taskContext, fragment.getRoot(), TypeProvider.copyOf(fragment.getSymbols()), - fragment.getPartitioningScheme(), + fragment.getOutputPartitioningScheme(), fragment.getPartitionedSources(), outputBuffer); } diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java index 9913528031b2..3cf04981d607 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java @@ -802,7 +802,7 @@ private void createStageExecution(SubPlan subPlan, boolean rootFragment, int sch stage::recordGetSplitTime, outputDataSizeEstimates.buildOrThrow())); - FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get(fragment.getPartitioningScheme().getPartitioning().getHandle()); + FaultTolerantPartitioningScheme sinkPartitioningScheme = partitioningSchemeFactory.get(fragment.getOutputPartitioningScheme().getPartitioning().getHandle()); ExchangeContext exchangeContext = new ExchangeContext(queryStateMachine.getQueryId(), new ExchangeId("external-exchange-" + stage.getStageId().getId())); boolean preserveOrderWithinPartition = rootFragment && stage.getFragment().getPartitioning().equals(SINGLE_DISTRIBUTION); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java index 1f74a93315d9..32ab66b9f5cf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java @@ -581,7 +581,7 @@ private static Map createOutputBuf private static PipelinedOutputBufferManager createSingleStreamOutputBuffer(SqlStage stage) { - PartitioningHandle partitioningHandle = stage.getFragment().getPartitioningScheme().getPartitioning().getHandle(); + PartitioningHandle partitioningHandle = stage.getFragment().getOutputPartitioningScheme().getPartitioning().getHandle(); checkArgument(partitioningHandle.isSingleNode(), "partitioning is expected to be single node: " + partitioningHandle); return new PartitionedPipelinedOutputBufferManager(partitioningHandle, 1); } @@ -946,7 +946,7 @@ private static Map> createBucketToPartitionMap( partitioningCache, fragment.getRoot(), fragment.getRemoteSourceNodes(), - fragment.getPartitioningScheme().getPartitionCount()); + fragment.getPartitionCount()); for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) { result.put(childStage.getFragment().getId(), bucketToPartition); } @@ -989,7 +989,7 @@ private static Map createOutputBuf for (SqlStage parentStage : stageManager.getDistributedStagesInTopologicalOrder()) { for (SqlStage childStage : stageManager.getChildren(parentStage.getStageId())) { PlanFragmentId fragmentId = childStage.getFragment().getId(); - PartitioningHandle partitioningHandle = childStage.getFragment().getPartitioningScheme().getPartitioning().getHandle(); + PartitioningHandle partitioningHandle = childStage.getFragment().getOutputPartitioningScheme().getPartitioning().getHandle(); PipelinedOutputBufferManager outputBufferManager; if (partitioningHandle.equals(FIXED_BROADCAST_DISTRIBUTION)) { @@ -1026,7 +1026,7 @@ private static StageScheduler createStageScheduler( Session session = queryStateMachine.getSession(); PlanFragment fragment = stageExecution.getFragment(); PartitioningHandle partitioningHandle = fragment.getPartitioning(); - Optional partitionCount = fragment.getPartitioningScheme().getPartitionCount(); + Optional partitionCount = fragment.getPartitionCount(); Map splitSources = splitSourceFactory.createSplitSources(session, fragment); if (!splitSources.isEmpty()) { queryStateMachine.addStateChangeListener(new StateChangeListener<>() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java index 474564f1dd43..da1c27ac7e39 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragment.java @@ -44,12 +44,13 @@ public class PlanFragment private final PlanNode root; private final Map symbols; private final PartitioningHandle partitioning; + private final Optional partitionCount; private final List partitionedSources; private final Set partitionedSourcesSet; private final List types; private final Set partitionedSourceNodes; private final List remoteSourceNodes; - private final PartitioningScheme partitioningScheme; + private final PartitioningScheme outputPartitioningScheme; private final StatsAndCosts statsAndCosts; private final List activeCatalogs; private final Optional jsonRepresentation; @@ -60,12 +61,13 @@ private PlanFragment( PlanNode root, Map symbols, PartitioningHandle partitioning, + Optional partitionCount, List partitionedSources, Set partitionedSourcesSet, List types, Set partitionedSourceNodes, List remoteSourceNodes, - PartitioningScheme partitioningScheme, + PartitioningScheme outputPartitioningScheme, StatsAndCosts statsAndCosts, List activeCatalogs) { @@ -73,12 +75,13 @@ private PlanFragment( this.root = requireNonNull(root, "root is null"); this.symbols = requireNonNull(symbols, "symbols is null"); this.partitioning = requireNonNull(partitioning, "partitioning is null"); + this.partitionCount = requireNonNull(partitionCount, "partitionCount is null"); this.partitionedSources = requireNonNull(partitionedSources, "partitionedSources is null"); this.partitionedSourcesSet = requireNonNull(partitionedSourcesSet, "partitionedSourcesSet is null"); this.types = requireNonNull(types, "types is null"); this.partitionedSourceNodes = requireNonNull(partitionedSourceNodes, "partitionedSourceNodes is null"); this.remoteSourceNodes = requireNonNull(remoteSourceNodes, "remoteSourceNodes is null"); - this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + this.outputPartitioningScheme = requireNonNull(outputPartitioningScheme, "outputPartitioningScheme is null"); this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null"); this.jsonRepresentation = Optional.empty(); @@ -90,8 +93,9 @@ public PlanFragment( @JsonProperty("root") PlanNode root, @JsonProperty("symbols") Map symbols, @JsonProperty("partitioning") PartitioningHandle partitioning, + @JsonProperty("partitionCount") Optional partitionCount, @JsonProperty("partitionedSources") List partitionedSources, - @JsonProperty("partitioningScheme") PartitioningScheme partitioningScheme, + @JsonProperty("outputPartitioningScheme") PartitioningScheme outputPartitioningScheme, @JsonProperty("statsAndCosts") StatsAndCosts statsAndCosts, @JsonProperty("activeCatalogs") List activeCatalogs, @JsonProperty("jsonRepresentation") Optional jsonRepresentation) @@ -100,17 +104,22 @@ public PlanFragment( this.root = requireNonNull(root, "root is null"); this.symbols = requireNonNull(symbols, "symbols is null"); this.partitioning = requireNonNull(partitioning, "partitioning is null"); + this.partitionCount = requireNonNull(partitionCount, "partitionCount is null"); this.partitionedSources = ImmutableList.copyOf(requireNonNull(partitionedSources, "partitionedSources is null")); this.partitionedSourcesSet = ImmutableSet.copyOf(partitionedSources); this.statsAndCosts = requireNonNull(statsAndCosts, "statsAndCosts is null"); this.activeCatalogs = requireNonNull(activeCatalogs, "activeCatalogs is null"); this.jsonRepresentation = requireNonNull(jsonRepresentation, "jsonRepresentation is null"); + checkArgument( + partitionCount.isEmpty() || partitioning.getConnectorHandle() instanceof SystemPartitioningHandle, + "Connector partitioning handle should be of type system partitioning when partitionCount is present"); + checkArgument(partitionedSourcesSet.size() == partitionedSources.size(), "partitionedSources contains duplicates"); - checkArgument(ImmutableSet.copyOf(root.getOutputSymbols()).containsAll(partitioningScheme.getOutputLayout()), - "Root node outputs (%s) does not include all fragment outputs (%s)", root.getOutputSymbols(), partitioningScheme.getOutputLayout()); + checkArgument(ImmutableSet.copyOf(root.getOutputSymbols()).containsAll(outputPartitioningScheme.getOutputLayout()), + "Root node outputs (%s) does not include all fragment outputs (%s)", root.getOutputSymbols(), outputPartitioningScheme.getOutputLayout()); - types = partitioningScheme.getOutputLayout().stream() + types = outputPartitioningScheme.getOutputLayout().stream() .map(symbols::get) .collect(toImmutableList()); @@ -120,7 +129,7 @@ public PlanFragment( findRemoteSourceNodes(root, remoteSourceNodes); this.remoteSourceNodes = remoteSourceNodes.build(); - this.partitioningScheme = requireNonNull(partitioningScheme, "partitioningScheme is null"); + this.outputPartitioningScheme = requireNonNull(outputPartitioningScheme, "partitioningScheme is null"); } @JsonProperty @@ -147,6 +156,12 @@ public PartitioningHandle getPartitioning() return partitioning; } + @JsonProperty + public Optional getPartitionCount() + { + return partitionCount; + } + @JsonProperty public List getPartitionedSources() { @@ -159,9 +174,9 @@ public boolean isPartitionedSources(PlanNodeId nodeId) } @JsonProperty - public PartitioningScheme getPartitioningScheme() + public PartitioningScheme getOutputPartitioningScheme() { - return partitioningScheme; + return outputPartitioningScheme; } @JsonProperty @@ -194,12 +209,13 @@ public PlanFragment withoutEmbeddedJsonRepresentation() this.root, this.symbols, this.partitioning, + this.partitionCount, this.partitionedSources, this.partitionedSourcesSet, this.types, this.partitionedSourceNodes, this.remoteSourceNodes, - this.partitioningScheme, + this.outputPartitioningScheme, this.statsAndCosts, this.activeCatalogs); } @@ -255,7 +271,7 @@ private static void findRemoteSourceNodes(PlanNode node, ImmutableList.Builder bucketToPartition) { - return new PlanFragment(id, root, symbols, partitioning, partitionedSources, partitioningScheme.withBucketToPartition(bucketToPartition), statsAndCosts, activeCatalogs, jsonRepresentation); + return new PlanFragment(id, root, symbols, partitioning, partitionCount, partitionedSources, outputPartitioningScheme.withBucketToPartition(bucketToPartition), statsAndCosts, activeCatalogs, jsonRepresentation); } @Override @@ -264,8 +280,9 @@ public String toString() return toStringHelper(this) .add("id", id) .add("partitioning", partitioning) + .add("partitionCount", partitionCount) .add("partitionedSource", partitionedSources) - .add("partitionFunction", partitioningScheme) + .add("outputPartitioningScheme", outputPartitioningScheme) .toString(); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java index 512a4ac4ae07..d1276e808c2e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanFragmenter.java @@ -171,7 +171,7 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub PartitioningHandleReassigner partitioningHandleReassigner = new PartitioningHandleReassigner(fragment.getPartitioning(), metadata, session); newRoot = SimplePlanRewriter.rewriteWith(partitioningHandleReassigner, newRoot); } - PartitioningScheme outputPartitioningScheme = fragment.getPartitioningScheme(); + PartitioningScheme outputPartitioningScheme = fragment.getOutputPartitioningScheme(); Partitioning newOutputPartitioning = outputPartitioningScheme.getPartitioning(); if (outputPartitioningScheme.getPartitioning().getHandle().getCatalogHandle().isPresent()) { // Do not replace the handle if the source's output handle is a system one, e.g. broadcast. @@ -182,6 +182,7 @@ private SubPlan reassignPartitioningHandleIfNecessaryHelper(Session session, Sub newRoot, fragment.getSymbols(), fragment.getPartitioning(), + fragment.getPartitionCount(), fragment.getPartitionedSources(), new PartitioningScheme( newOutputPartitioning, @@ -249,6 +250,7 @@ private SubPlan buildFragment(PlanNode root, FragmentProperties properties, Plan root, symbols, properties.getPartitioningHandle(), + properties.getPartitionCount(), schedulingOrder, properties.getPartitioningScheme(), statsAndCosts.getForSubplan(root), @@ -326,21 +328,33 @@ public PlanNode visitRefreshMaterializedView(RefreshMaterializedViewNode node, R @Override public PlanNode visitTableWriter(TableWriterNode node, RewriteContext context) { - node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution(scheme.getPartitioning().getHandle(), metadata, session)); + node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution( + scheme.getPartitioning().getHandle(), + scheme.getPartitionCount(), + metadata, + session)); return context.defaultRewrite(node, context.get()); } @Override public PlanNode visitTableExecute(TableExecuteNode node, RewriteContext context) { - node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution(scheme.getPartitioning().getHandle(), metadata, session)); + node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution( + scheme.getPartitioning().getHandle(), + scheme.getPartitionCount(), + metadata, + session)); return context.defaultRewrite(node, context.get()); } @Override public PlanNode visitMergeWriter(MergeWriterNode node, RewriteContext context) { - node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution(scheme.getPartitioning().getHandle(), metadata, session)); + node.getPartitioningScheme().ifPresent(scheme -> context.get().setDistribution( + scheme.getPartitioning().getHandle(), + scheme.getPartitionCount(), + metadata, + session)); return context.defaultRewrite(node, context.get()); } @@ -368,7 +382,11 @@ public PlanNode visitExchange(ExchangeNode exchange, RewriteContext childrenProperties = ImmutableList.builder(); @@ -427,6 +445,7 @@ private static class FragmentProperties private final PartitioningScheme partitioningScheme; private Optional partitioningHandle = Optional.empty(); + private Optional partitionCount = Optional.empty(); private final Set partitionedSources = new HashSet<>(); public FragmentProperties(PartitioningScheme partitioningScheme) @@ -461,10 +480,15 @@ public FragmentProperties setSingleNodeDistribution() return this; } - public FragmentProperties setDistribution(PartitioningHandle distribution, Metadata metadata, Session session) + public FragmentProperties setDistribution( + PartitioningHandle distribution, + Optional partitionCount, + Metadata metadata, + Session session) { if (partitioningHandle.isEmpty()) { partitioningHandle = Optional.of(distribution); + this.partitionCount = partitionCount; return this; } @@ -485,6 +509,7 @@ public FragmentProperties setDistribution(PartitioningHandle distribution, Metad if (isCompatibleScaledWriterPartitioning(currentPartitioning, distribution)) { this.partitioningHandle = Optional.of(distribution); + this.partitionCount = partitionCount; return this; } @@ -597,6 +622,11 @@ public PartitioningHandle getPartitioningHandle() return partitioningHandle.get(); } + public Optional getPartitionCount() + { + return partitionCount; + } + public Set getPartitionedSources() { return partitionedSources; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 5a579b77ab6c..5daba32cbf4a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -550,7 +550,7 @@ private static String formatFragment( } } - PartitioningScheme partitioningScheme = fragment.getPartitioningScheme(); + PartitioningScheme partitioningScheme = fragment.getOutputPartitioningScheme(); List layout = partitioningScheme.getOutputLayout().stream() .map(anonymizer::anonymize) .collect(toImmutableList()); @@ -584,7 +584,7 @@ private static String formatFragment( hashColumn)); } - partitioningScheme.getPartitionCount().ifPresent(partitionCount -> builder.append(format("Partition count: %s\n", partitionCount))); + fragment.getPartitionCount().ifPresent(partitionCount -> builder.append(format("Partition count: %s\n", partitionCount))); builder.append( new PlanPrinter( @@ -633,6 +633,7 @@ public static String graphvizLogicalPlan(PlanNode plan, TypeProvider types) plan, types.allTypes(), SINGLE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(plan.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), plan.getOutputSymbols()), StatsAndCosts.empty(), diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index bd8de3b69d69..8a97e251632c 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -986,7 +986,7 @@ private List createDrivers(Session session, Plan plan, OutputFactory out LocalExecutionPlan localExecutionPlan = executionPlanner.plan( taskContext, subplan.getFragment().getRoot(), - subplan.getFragment().getPartitioningScheme().getOutputLayout(), + subplan.getFragment().getOutputPartitioningScheme().getOutputLayout(), plan.getTypes(), subplan.getFragment().getPartitionedSources(), outputFactory); diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 4ef6ccb09dbc..986422e7d8d9 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -118,6 +118,7 @@ public MockRemoteTask createTableScanTask(TaskId taskId, InternalNode newNode, L Optional.empty()), ImmutableMap.of(symbol, VARCHAR), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(sourceId), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java index a00d7ef1d0c0..c3c8cc8b2b57 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/TaskTestUtils.java @@ -96,6 +96,7 @@ private TaskTestUtils() {} Optional.empty()), ImmutableMap.of(SYMBOL, VARCHAR), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(SYMBOL)) .withBucketToPartition(Optional.of(new int[1])), @@ -119,6 +120,7 @@ private TaskTestUtils() {} ImmutableMap.of(DYNAMIC_FILTER_SOURCE_ID, SYMBOL)), ImmutableMap.of(SYMBOL, VARCHAR), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(SYMBOL)) .withBucketToPartition(Optional.of(new int[1])), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java index 8764a125932c..d2bc8c895139 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java @@ -180,6 +180,7 @@ private static PlanFragment createExchangePlanFragment() planNode, types.buildOrThrow(), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(planNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java index 451da665020a..195036d70b8c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestStageStateMachine.java @@ -250,6 +250,7 @@ private static PlanFragment createValuesPlan() ImmutableList.of(new Row(ImmutableList.of(new StringLiteral("foo"))))), ImmutableMap.of(symbol, VARCHAR), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(valuesNodeId), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java index 85d50277f5de..b85590be8206 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestScaledWriterScheduler.java @@ -358,6 +358,7 @@ private static PlanFragment createFragment() tableScan, ImmutableMap.of(symbol, VARCHAR), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java index a62f91458cf8..bdf008a1b559 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/TestSourcePartitionedScheduler.java @@ -705,6 +705,7 @@ private static PlanFragment createFragment() Optional.empty()), ImmutableMap.of(symbol, VARCHAR), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(TABLE_SCAN_NODE_ID), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java index 31c9eb38c506..423640e5888e 100644 --- a/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/scheduler/policy/PlanUtils.java @@ -195,6 +195,7 @@ private static PlanFragment createFragment(PlanNode planNode) planNode, types.buildOrThrow(), SOURCE_DISTRIBUTION, + Optional.empty(), ImmutableList.of(planNode.getId()), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), planNode.getOutputSymbols()), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java index 14a3d4a234e0..9ea42011ca79 100644 --- a/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java +++ b/core/trino-main/src/test/java/io/trino/server/TestDynamicFilterService.java @@ -1084,6 +1084,7 @@ private static PlanFragment createPlan( Optional.empty()), ImmutableMap.of(symbol, VARCHAR), stagePartitioning, + Optional.empty(), ImmutableList.of(tableScanNodeId), new PartitioningScheme(Partitioning.create(SINGLE_DISTRIBUTION, ImmutableList.of()), ImmutableList.of(symbol)), StatsAndCosts.empty(), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java new file mode 100644 index 000000000000..af67cb4e21e7 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestPlanFragmentPartitionCount.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.graph.Traverser; +import io.trino.Session; +import io.trino.cost.StatsAndCosts; +import io.trino.execution.QueryManagerConfig; +import io.trino.execution.warnings.WarningCollector; +import io.trino.plugin.tpch.TpchConnectorFactory; +import io.trino.security.AllowAllAccessControl; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.OutputNode; +import io.trino.sql.planner.plan.PlanFragmentId; +import io.trino.testing.LocalQueryRunner; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Optional; +import java.util.function.Function; + +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.testing.TestingHandles.TEST_CATALOG_NAME; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.transaction.TransactionBuilder.transaction; +import static org.assertj.core.api.Assertions.assertThat; + +public class TestPlanFragmentPartitionCount +{ + private PlanFragmenter planFragmenter; + private Session session; + private LocalQueryRunner localQueryRunner; + + @BeforeClass + public void setUp() + { + session = testSessionBuilder().setCatalog(TEST_CATALOG_NAME).build(); + localQueryRunner = LocalQueryRunner.create(session); + localQueryRunner.createCatalog(TEST_CATALOG_NAME, new TpchConnectorFactory(), ImmutableMap.of()); + + planFragmenter = new PlanFragmenter( + localQueryRunner.getMetadata(), + localQueryRunner.getFunctionManager(), + localQueryRunner.getTransactionManager(), + localQueryRunner.getCatalogManager(), + new QueryManagerConfig()); + } + + @AfterClass(alwaysRun = true) + public void tearDown() + { + planFragmenter = null; + session = null; + localQueryRunner.close(); + localQueryRunner = null; + } + + @Test + public void testPartitionCountInPlanFragment() + { + PlanBuilder p = new PlanBuilder(new PlanNodeIdAllocator(), localQueryRunner.getMetadata(), session); + Symbol a = p.symbol("a", VARCHAR); + Symbol b = p.symbol("b", VARCHAR); + Symbol c = p.symbol("c", VARCHAR); + Symbol d = p.symbol("d", VARCHAR); + Symbol f = p.symbol("f", VARCHAR); + Symbol g = p.symbol("g", VARCHAR); + Symbol h = p.symbol("h", VARCHAR); + Symbol i = p.symbol("i", VARCHAR); + + OutputNode output = p.output(o -> o + .source( + p.exchange(e -> e + .type(REPARTITION) + .addSource( + p.exchange(exc -> exc + .type(REPARTITION) + .addSource( + p.join( + INNER, + p.exchange(ex -> ex + .type(REPARTITION) + .addSource(p.values(a, b)) + .addInputsSet(a, b) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(a, b), ImmutableList.of(b), 5)), + p.exchange(ex -> ex + .type(REPARTITION) + .addSource(p.values(c, d)) + .addInputsSet(c, d) + .fixedHashDistributionPartitioningScheme(ImmutableList.of(c, d), ImmutableList.of(d), 5)), + new JoinNode.EquiJoinClause(b, d))) + .addInputsSet(a, b, c, d) + .fixedArbitraryDistributionPartitioningScheme(ImmutableList.of(a, b, c, d), 2))) + .addSource(p.values(f, g, h, i)) + .addInputsSet(a, b, c, d) + .addInputsSet(f, g, h, i) + .fixedHashDistributionPartitioningScheme( + ImmutableList.of(a, b, c, d), + ImmutableList.of(b), + 3)))); + + Plan plan = new Plan(output, p.getTypes(), StatsAndCosts.empty()); + SubPlan rootSubPlan = fragment(plan); + ImmutableMap.Builder> actualPartitionCount = ImmutableMap.builder(); + Traverser.forTree(SubPlan::getChildren).depthFirstPreOrder(rootSubPlan).forEach(subPlan -> + actualPartitionCount.put(subPlan.getFragment().getId(), subPlan.getFragment().getPartitionCount())); + + Map> expectedPartitionCount = ImmutableMap.of( + // for output fragment + new PlanFragmentId("0"), Optional.of(3), + // for union exchange fragment + new PlanFragmentId("1"), Optional.of(2), + // for join fragment + new PlanFragmentId("2"), Optional.of(5), + // for all other fragments partitionCount should be empty + new PlanFragmentId("3"), Optional.empty(), + new PlanFragmentId("4"), Optional.empty(), + new PlanFragmentId("5"), Optional.empty()); + + assertThat(expectedPartitionCount).isEqualTo(actualPartitionCount.buildOrThrow()); + } + + private SubPlan fragment(Plan plan) + { + return inTransaction(session -> planFragmenter.createSubPlans(session, plan, false, WarningCollector.NOOP)); + } + + private T inTransaction(Function transactionSessionConsumer) + { + return transaction(localQueryRunner.getTransactionManager(), new AllowAllAccessControl()) + .singleStatement() + .execute(session, session -> { + // metadata.getCatalogHandle() registers the catalog for the transaction + session.getCatalog().ifPresent(catalog -> localQueryRunner.getMetadata().getCatalogHandle(session, catalog)); + return transactionSessionConsumer.apply(session); + }); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 718a01134b67..44db23d7772b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -135,6 +135,7 @@ import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.IntegerType.INTEGER; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_ARBITRARY_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.FIXED_HASH_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; @@ -889,6 +890,30 @@ public ExchangeBuilder fixedHashDistributionPartitioningScheme(List outp Optional.of(hashSymbol))); } + public ExchangeBuilder fixedHashDistributionPartitioningScheme(List outputSymbols, List partitioningSymbols, int partitionCount) + { + return partitioningScheme(new PartitioningScheme(Partitioning.create( + FIXED_HASH_DISTRIBUTION, + ImmutableList.copyOf(partitioningSymbols)), + ImmutableList.copyOf(outputSymbols), + Optional.empty(), + false, + Optional.empty(), + Optional.of(partitionCount))); + } + + public ExchangeBuilder fixedArbitraryDistributionPartitioningScheme(List outputSymbols, int partitionCount) + { + return partitioningScheme(new PartitioningScheme(Partitioning.create( + FIXED_ARBITRARY_DISTRIBUTION, + ImmutableList.of()), + ImmutableList.copyOf(outputSymbols), + Optional.empty(), + false, + Optional.empty(), + Optional.of(partitionCount))); + } + public ExchangeBuilder partitioningScheme(PartitioningScheme partitioningScheme) { this.partitioningScheme = partitioningScheme;