diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java index 2bc0b3fccb65..6d2885f7d69f 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.ptf.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -24,16 +25,19 @@ public class TableFunctionHandle { private final CatalogHandle catalogHandle; + private final SchemaFunctionName schemaFunctionName; private final ConnectorTableFunctionHandle functionHandle; private final ConnectorTransactionHandle transactionHandle; @JsonCreator public TableFunctionHandle( @JsonProperty("catalogHandle") CatalogHandle catalogHandle, + @JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName, @JsonProperty("functionHandle") ConnectorTableFunctionHandle functionHandle, @JsonProperty("transactionHandle") ConnectorTransactionHandle transactionHandle) { this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); } @@ -44,6 +48,12 @@ public CatalogHandle getCatalogHandle() return catalogHandle; } + @JsonProperty + public SchemaFunctionName getSchemaFunctionName() + { + return schemaFunctionName; + } + @JsonProperty public ConnectorTableFunctionHandle getFunctionHandle() { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java 
b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index bc5bfffcda39..8f47959127fc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -2230,6 +2230,7 @@ public TableArgumentAnalysis build() public static class TableFunctionInvocationAnalysis { private final CatalogHandle catalogHandle; + private final String schemaName; private final String functionName; private final Map arguments; private final List tableArgumentAnalyses; @@ -2241,6 +2242,7 @@ public static class TableFunctionInvocationAnalysis public TableFunctionInvocationAnalysis( CatalogHandle catalogHandle, + String schemaName, String functionName, Map arguments, List tableArgumentAnalyses, @@ -2251,6 +2253,7 @@ public TableFunctionInvocationAnalysis( ConnectorTransactionHandle transactionHandle) { this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.functionName = requireNonNull(functionName, "functionName is null"); this.arguments = ImmutableMap.copyOf(arguments); this.tableArgumentAnalyses = ImmutableList.copyOf(tableArgumentAnalyses); @@ -2267,6 +2270,11 @@ public CatalogHandle getCatalogHandle() return catalogHandle; } + public String getSchemaName() + { + return schemaName; + } + public String getFunctionName() { return functionName; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index cf5a99c76fac..6cb9cce5f858 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -1671,6 +1671,7 @@ else if (argument.getPartitionBy().isPresent()) { analysis.setTableFunctionAnalysis(node, new TableFunctionInvocationAnalysis( catalogHandle, + function.getSchema(), 
function.getName(), argumentsAnalysis.getPassedArguments(), orderedTableArguments.build(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 2b7adc365b8a..b440835f90ea 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -221,6 +221,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; @@ -1658,6 +1659,12 @@ public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecuti throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); } + @Override + public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index e548c2305023..b3a8f1c6b2de 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -61,6 +61,7 @@ import io.trino.sql.planner.iterative.rule.ImplementIntersectDistinctAsUnion; import io.trino.sql.planner.iterative.rule.ImplementLimitWithTies; import 
io.trino.sql.planner.iterative.rule.ImplementOffset; +import io.trino.sql.planner.iterative.rule.ImplementTableFunctionSource; import io.trino.sql.planner.iterative.rule.InlineProjectIntoFilter; import io.trino.sql.planner.iterative.rule.InlineProjections; import io.trino.sql.planner.iterative.rule.MergeExcept; @@ -629,7 +630,11 @@ public PlanOptimizers( costCalculator, // Temporary hack: separate optimizer step to avoid the sample node being replaced by filter before pushing // it to table scan node - ImmutableSet.of(new ImplementBernoulliSampleAsFilter(metadata))), + ImmutableSet.of( + new ImplementBernoulliSampleAsFilter(metadata), + // Must run after RewriteTableFunctionToTableScan because that rule applies to TableFunctionNode. + // While the node gets rewritten to TableFunctionProcessorNode, we can no longer pushdown the function to the connector. + new ImplementTableFunctionSource(metadata))), columnPruningOptimizer, new IterativeOptimizer( plannerContext, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 756841691a25..1886c6679228 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -22,6 +22,7 @@ import io.trino.metadata.TableFunctionHandle; import io.trino.metadata.TableHandle; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.ExpressionUtils; @@ -47,6 +48,8 @@ import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.SampleNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import 
io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.UnionNode; @@ -428,27 +431,38 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); } - sources.add(sourcePlanBuilder.getRoot()); - sourceProperties.add(new TableArgumentProperties( - tableArgument.getArgumentName(), - tableArgument.isRowSemantics(), - tableArgument.isPruneWhenEmpty(), - tableArgument.isPassThroughColumns(), - requiredColumns, - specification)); - // add output symbols passed from the table argument + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); if (tableArgument.isPassThroughColumns()) { // the original output symbols from the source node, not coerced // note: hidden columns are included. They are present in sourcePlan.fieldMappings outputSymbols.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(symbol -> new PassThroughColumn(symbol, partitionBy.contains(symbol))) + .forEach(passThroughColumns::add); } else if (tableArgument.getPartitionBy().isPresent()) { tableArgument.getPartitionBy().get().stream() // the original symbols for partitioning columns, not coerced .map(sourcePlanBuilder::translate) - .forEach(outputSymbols::add); + .forEach(symbol -> { + outputSymbols.add(symbol); + passThroughColumns.add(new PassThroughColumn(symbol, true)); + }); } + + sources.add(sourcePlanBuilder.getRoot()); + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + new PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), + requiredColumns, + specification)); } 
PlanNode root = new TableFunctionNode( @@ -459,7 +473,11 @@ else if (tableArgument.getPartitionBy().isPresent()) { sources.build(), sourceProperties.build(), functionAnalysis.getCopartitioningLists(), - new TableFunctionHandle(functionAnalysis.getCatalogHandle(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); + new TableFunctionHandle( + functionAnalysis.getCatalogHandle(), + new SchemaFunctionName(functionAnalysis.getSchemaName(), functionAnalysis.getFunctionName()), + functionAnalysis.getConnectorTableFunctionHandle(), + functionAnalysis.getTransactionHandle())); return new RelationPlan(root, analysis.getScope(node), outputSymbols.build(), outerContext); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java new file mode 100644 index 000000000000..4d18d1f03dfd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java @@ -0,0 +1,756 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.metadata.Metadata; +import io.trino.metadata.ResolvedFunction; +import io.trino.spi.type.Type; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; +import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.WindowNode.Frame; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.CoalesceExpression; +import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.GenericLiteral; +import io.trino.sql.tree.IfExpression; +import io.trino.sql.tree.LogicalExpression; +import io.trino.sql.tree.NotExpression; +import io.trino.sql.tree.NullLiteral; +import io.trino.sql.tree.QualifiedName; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static 
com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.JoinNode.Type.FULL; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; +import static io.trino.sql.planner.plan.JoinNode.Type.RIGHT; +import static io.trino.sql.planner.plan.Patterns.tableFunction; +import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; +import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; +import static io.trino.sql.tree.LogicalExpression.Operator.AND; +import static io.trino.sql.tree.LogicalExpression.Operator.OR; +import static io.trino.sql.tree.WindowFrame.Type.ROWS; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *
+ * <p>
+ * It rewrites TableFunctionNode with potentially many sources
+ * into a TableFunctionProcessorNode. The new node has one
+ * source being a combination of the original sources.
+ * <p>
+ * The original sources are combined with joins. The join
+ * conditions depend on the prune when empty property, and on
+ * the co-partitioning of sources.
+ * <p>
+ * The resulting source should be partitioned and ordered
+ * according to combined schemas from the component sources.
+ * <p>
+ * Example transformation for two sources, both with set semantics
+ * and KEEP WHEN EMPTY property:
+ * <pre>
+ * - TableFunction foo
+ *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
+ *      - source T2(a2, b2) PARTITION BY a2
+ * </pre>
+ * Is transformed into:
+ * <pre>
+ * - TableFunctionProcessor foo
+ *      PARTITION BY (a1, a2), ORDER BY combined_row_number
+ *      - Project
+ *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
+ *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
+ *          - Project
+ *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
+ *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
+ *              - FULL Join
+ *                  [table1_row_number = table2_row_number OR
+ *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
+ *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
+ *                  - Window [PARTITION BY a1 ORDER BY b1]
+ *                      table1_row_number <= row_number()
+ *                      table1_partition_size <= count()
+ *                          - source T1(a1, b1)
+ *                  - Window [PARTITION BY a2]
+ *                      table2_row_number <= row_number()
+ *                      table2_partition_size <= count()
+ *                          - source T2(a2, b2)
+ * </pre>
+ */ +public class ImplementTableFunctionSource + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public ImplementTableFunctionSource(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. 
+ TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.isPruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.getSpecification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + ResolvedFunction rowNumberFunction = metadata.resolveFunction(context.getSession(), QualifiedName.of("row_number"), ImmutableList.of()); + ResolvedFunction countFunction = metadata.resolveFunction(context.getSession(), QualifiedName.of("count"), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithSymbols finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + 
else { + NodeWithSymbols first = intermediateResultSources.get(0); + NodeWithSymbols second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithSymbols joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. + // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + Symbol finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.symbolToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. 
+ // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. + Optional finalOrderBy = Optional.of(new OrderingScheme(ImmutableList.of(finalRowNumberSymbol), ImmutableMap.of(finalRowNumberSymbol, ASC_NULLS_LAST))); + + // derive the prune when empty property + boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); + + // Combine the pass through specifications from all sources + List passThroughSpecifications = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .collect(toImmutableList()); + + // Combine the required symbols from all sources + List> requiredSymbols = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getRequiredColumns) + .collect(toImmutableList()); + + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(marked.node()), + pruneWhenEmpty, + passThroughSpecifications, + requiredSymbols, + Optional.of(markerSymbols), + Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + private static Map mapSourcesByName(List sources, List properties) + { + return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) + .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); + } + + private static NodeWithSymbols planWindowFunctionsForSource( + PlanNode source, + TableArgumentProperties argumentProperties, + ResolvedFunction rowNumberFunction, + ResolvedFunction countFunction, + Context context) + { + String argumentName = 
argumentProperties.getArgumentName(); + + Symbol rowNumber = context.getSymbolAllocator().newSymbol(argumentName + "_row_number", BIGINT); + Map rowNumberSymbolMapping = source.getOutputSymbols().stream() + .collect(toImmutableMap(identity(), symbol -> rowNumber)); + + Symbol partitionSize = context.getSymbolAllocator().newSymbol(argumentName + "_partition_size", BIGINT); + + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // If the source has row semantics, its specification is empty. Currently, such source is processed + // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. + DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); + + PlanNode window = new WindowNode( + context.getIdAllocator().getNextId(), + source, + specification, + ImmutableMap.of( + rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), FULL_FRAME, false), + partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + + return new NodeWithSymbols(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); + } + + private static NodeWithSymbols copartition( + List sourceList, + ResolvedFunction rowNumberFunction, + ResolvedFunction countFunction, + Context context) + { + checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); + + // Reorder the co-partitioned sources to process the sources with prune when empty property first. + // It allows to use inner or side joins instead of outer joins. + sourceList = sourceList.stream() + .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? 
-1 : 1)) + .collect(toImmutableList()); + + NodeWithSymbols first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); + NodeWithSymbols second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); + JoinedNodes copartitioned = copartition(first, second, context); + + for (int i = 2; i < sourceList.size(); i++) { + NodeWithSymbols copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context); + NodeWithSymbols next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); + copartitioned = copartition(copartitionedWithSymbols, next, context); + } + + return appendHelperSymbolsForCopartitionedNodes(copartitioned, context); + } + + private static JoinedNodes copartition(NodeWithSymbols left, NodeWithSymbols right, Context context) + { + checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); + + // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. + // Co-partitioning tables with empty partition by would be ineffective. 
+ checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); + + Expression leftRowNumber = left.rowNumber().toSymbolReference(); + Expression leftPartitionSize = left.partitionSize().toSymbolReference(); + List leftPartitionBy = left.partitionBy().stream() + .map(Symbol::toSymbolReference) + .collect(toImmutableList()); + Expression rightRowNumber = right.rowNumber().toSymbolReference(); + Expression rightPartitionSize = right.partitionSize().toSymbolReference(); + List rightPartitionBy = right.partitionBy().stream() + .map(Symbol::toSymbolReference) + .collect(toImmutableList()); + + List copartitionConjuncts = Streams.zip( + leftPartitionBy.stream(), + rightPartitionBy.stream(), + (leftColumn, rightColumn) -> new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, leftColumn, rightColumn))) + .collect(toImmutableList()); + + // Align matching partitions (co-partitions) from left and right source, according to row number. + // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. + // It preserves the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. + // + // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... 
+ // AND ( + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1)) + Expression joinCondition = new LogicalExpression( + AND, + ImmutableList.builder() + .addAll(copartitionConjuncts) + .add(new LogicalExpression(OR, ImmutableList.of( + new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), + new ComparisonExpression(EQUAL, rightRowNumber, new GenericLiteral("BIGINT", "1")))), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), + new ComparisonExpression(EQUAL, leftRowNumber, new GenericLiteral("BIGINT", "1"))))))) + .build()); + + // The join type depends on the prune when empty property of the sources. + // If a source is prune when empty, we should not process any co-partition which is not present in this source, + // so effectively the other source becomes inner side of the join. 
+ // + // example: + // table T1 partition by P1 table T2 partition by P2 + // P1 C1 P2 C2 + // ---------- ---------- + // 1 'a' 2 'c' + // 2 'b' 3 'd' + // + // co-partitioning results: + // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // + // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // null null 3 'd' + // + // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // + // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // null null 3 'd' + JoinNode.Type joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + left.node().getOutputSymbols(), + right.node().getOutputSymbols(), + false, + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithSymbols appendHelperSymbolsForCopartitionedNodes( + JoinedNodes copartitionedNodes, + Context context) + { + checkArgument(copartitionedNodes.leftPartitionBy().size() == 
copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + Expression leftRowNumber = copartitionedNodes.leftRowNumber().toSymbolReference(); + Expression leftPartitionSize = copartitionedNodes.leftPartitionSize().toSymbolReference(); + Expression rightRowNumber = copartitionedNodes.rightRowNumber().toSymbolReference(); + Expression rightPartitionSize = copartitionedNodes.rightPartitionSize().toSymbolReference(); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); + Expression rowNumberExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftRowNumber, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightRowNumber, new GenericLiteral("BIGINT", "-1"))), + leftRowNumber, + rightRowNumber); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. + Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); + Expression partitionSizeExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftPartitionSize, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightPartitionSize, new GenericLiteral("BIGINT", "-1"))), + leftPartitionSize, + rightPartitionSize); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. 
+ ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + Symbol leftColumn = copartitionedNodes.leftPartitionBy().get(i); + Symbol rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getSymbolAllocator().getTypes().get(leftColumn); + + Symbol joinedColumn = context.getSymbolAllocator().newSymbol("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, new CoalesceExpression(leftColumn.toSymbolReference(), rightColumn.toSymbolReference())); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putIdentities(copartitionedNodes.joinedNode().getOutputSymbols()) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithSymbols(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithSymbols left, NodeWithSymbols right, Context context) + { + Expression leftRowNumber = left.rowNumber().toSymbolReference(); + Expression leftPartitionSize = left.partitionSize().toSymbolReference(); + Expression rightRowNumber = right.rowNumber().toSymbolReference(); + Expression rightPartitionSize = right.partitionSize().toSymbolReference(); + + // Align rows from left and right source 
according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. + // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + Expression joinCondition = new LogicalExpression(OR, ImmutableList.of( + new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), + new ComparisonExpression(EQUAL, rightRowNumber, new GenericLiteral("BIGINT", "1")))), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), + new ComparisonExpression(EQUAL, leftRowNumber, new GenericLiteral("BIGINT", "1")))))); + + JoinNode.Type joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + left.node().getOutputSymbols(), + right.node().getOutputSymbols(), + false, + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private 
static NodeWithSymbols appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context) + { + Expression leftRowNumber = joinedNodes.leftRowNumber().toSymbolReference(); + Expression leftPartitionSize = joinedNodes.leftPartitionSize().toSymbolReference(); + Expression rightRowNumber = joinedNodes.rightRowNumber().toSymbolReference(); + Expression rightPartitionSize = joinedNodes.rightPartitionSize().toSymbolReference(); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); + Expression rowNumberExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftRowNumber, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightRowNumber, new GenericLiteral("BIGINT", "-1"))), + leftRowNumber, + rightRowNumber); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); + Expression partitionSizeExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftPartitionSize, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightPartitionSize, new GenericLiteral("BIGINT", "-1"))), + leftPartitionSize, + rightPartitionSize); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putIdentities(joinedNodes.joinedNode().getOutputSymbols()) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithSymbols(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set symbols, Symbol referenceSymbol, Context context) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putIdentities(node.getOutputSymbols()); + + ImmutableMap.Builder symbolsToMarkers = ImmutableMap.builder(); + + for (Symbol symbol : symbols) { + Symbol marker = context.getSymbolAllocator().newSymbol("marker", BIGINT); + symbolsToMarkers.put(symbol, marker); + Expression actual = symbol.toSymbolReference(); + Expression reference = referenceSymbol.toSymbolReference(); + assignments.put(marker, new IfExpression(new ComparisonExpression(EQUAL, actual, reference), actual, new 
Cast(new NullLiteral(), toSqlType(BIGINT)))); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + node, + assignments.build()); + + return new NodeWithMarkers(project, symbolsToMarkers.buildOrThrow()); + } + + private record SourceWithProperties(PlanNode source, TableArgumentProperties properties) + { + SourceWithProperties + { + requireNonNull(source, "source is null"); + requireNonNull(properties, "properties is null"); + } + } + + private record NodeWithSymbols(PlanNode node, Symbol rowNumber, Symbol partitionSize, List partitionBy, boolean pruneWhenEmpty, Map rowNumberSymbolsMapping) + { + NodeWithSymbols + { + requireNonNull(node, "node is null"); + requireNonNull(rowNumber, "rowNumber is null"); + requireNonNull(partitionSize, "partitionSize is null"); + partitionBy = ImmutableList.copyOf(partitionBy); + rowNumberSymbolsMapping = ImmutableMap.copyOf(rowNumberSymbolsMapping); + } + } + + private record JoinedNodes( + PlanNode joinedNode, + Symbol leftRowNumber, + Symbol leftPartitionSize, + List leftPartitionBy, + boolean leftPruneWhenEmpty, + Map leftRowNumberSymbolsMapping, + Symbol rightRowNumber, + Symbol rightPartitionSize, + List rightPartitionBy, + boolean rightPruneWhenEmpty, + Map rightRowNumberSymbolsMapping) + { + JoinedNodes + { + requireNonNull(joinedNode, "joinedNode is null"); + requireNonNull(leftRowNumber, "leftRowNumber is null"); + requireNonNull(leftPartitionSize, "leftPartitionSize is null"); + leftPartitionBy = ImmutableList.copyOf(leftPartitionBy); + leftRowNumberSymbolsMapping = ImmutableMap.copyOf(leftRowNumberSymbolsMapping); + requireNonNull(rightRowNumber, "rightRowNumber is null"); + requireNonNull(rightPartitionSize, "rightPartitionSize is null"); + rightPartitionBy = ImmutableList.copyOf(rightPartitionBy); + rightRowNumberSymbolsMapping = ImmutableMap.copyOf(rightRowNumberSymbolsMapping); + } + } + + private record NodeWithMarkers(PlanNode node, Map symbolToMarker) + { + NodeWithMarkers + { + 
requireNonNull(node, "node is null"); + symbolToMarker = ImmutableMap.copyOf(symbolToMarker); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 667ab168b37c..9afa4735e3e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -74,6 +74,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -372,6 +373,12 @@ public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredPr throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); } + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, PreferredProperties preferredProperties) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + @Override public PlanWithProperties visitRowNumber(RowNumberNode node, PreferredProperties preferredProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 46a823049e9c..7dbe89b5c813 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -37,6 +37,9 @@ import io.trino.sql.planner.plan.StatisticsWriterNode; import 
io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.TopNRankingNode; @@ -56,6 +59,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -64,6 +68,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.sql.planner.plan.AggregationNode.groupingSets; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; public class SymbolMapper { @@ -214,16 +219,18 @@ public WindowNode map(WindowNode node, PlanNode source) newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls())); }); + SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix()); + return new WindowNode( node.getId(), source, - mapAndDistinct(node.getSpecification()), + newSpecification.specification(), newFunctions.buildOrThrow(), node.getHashSymbol().map(this::map), node.getPrePartitionedInputs().stream() .map(this::map) .collect(toImmutableSet()), - node.getPreSortedOrderPrefix()); + newSpecification.preSorted()); } private WindowNode.Frame map(WindowNode.Frame frame) @@ -240,6 +247,18 @@ private WindowNode.Frame map(WindowNode.Frame frame) frame.getOriginalEndValue()); } + private SpecificationWithPreSortedPrefix mapAndDistinct(DataOrganizationSpecification specification, int preSorted) + { + Optional newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, 
preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinct(specification.getPartitionBy()), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::orderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::preSorted).orElse(preSorted)); + } + public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) { return new DataOrganizationSpecification( @@ -249,6 +268,8 @@ public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecificatio public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) { + SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix()); + ImmutableMap.Builder newFunctions = ImmutableMap.builder(); node.getWindowFunctions().forEach((symbol, function) -> { List newArguments = function.getArguments().stream() @@ -271,12 +292,12 @@ public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) return new PatternRecognitionNode( node.getId(), source, - mapAndDistinct(node.getSpecification()), + newSpecification.specification(), node.getHashSymbol().map(this::map), node.getPrePartitionedInputs().stream() .map(this::map) .collect(toImmutableSet()), - node.getPreSortedOrderPrefix(), + newSpecification.preSorted(), newFunctions.buildOrThrow(), newMeasures.buildOrThrow(), node.getCommonBaseFrame().map(this::map), @@ -339,6 +360,46 @@ public Expression rewriteSymbolReference(SymbolReference node, Void context, Exp expressionAndValuePointers.getMatchNumberSymbols()); } + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + List newPassThroughSpecifications = node.getPassThroughSpecifications().stream() + .map(specification -> new PassThroughSpecification( + specification.declaredAsPassThrough(), + specification.columns().stream() + .map(column -> new PassThroughColumn( + map(column.symbol()), + 
column.isPartitioningColumn())) + .collect(toImmutableList()))) + .collect(toImmutableList()); + + List> newRequiredSymbols = node.getRequiredSymbols().stream() + .map(this::map) + .collect(toImmutableList()); + + Optional> newMarkerSymbols = node.getMarkerSymbols() + .map(mapping -> mapping.entrySet().stream() + .collect(toMap(entry -> map(entry.getKey()), entry -> map(entry.getValue())))); + + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + map(node.getProperOutputs()), + Optional.of(source), + node.isPruneWhenEmpty(), + newPassThroughSpecifications, + newRequiredSymbols, + newMarkerSymbols, + newSpecification.map(SpecificationWithPreSortedPrefix::specification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + newSpecification.map(SpecificationWithPreSortedPrefix::preSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), + node.getHandle()); + } + public LimitNode map(LimitNode node, PlanNode source) { return new LimitNode( @@ -352,6 +413,29 @@ public LimitNode map(LimitNode node, PlanNode source) .collect(toImmutableList())); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newSymbols = ImmutableList.builder(); + ImmutableMap.Builder newOrderings = ImmutableMap.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + Symbol symbol = orderingScheme.getOrderBy().get(i); + Symbol canonical = map(symbol); + if (added.add(canonical)) { + newSymbols.add(canonical); + newOrderings.put(canonical, orderingScheme.getOrdering(symbol)); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new 
OrderingScheme(newSymbols.build(), newOrderings.buildOrThrow()), newPreSorted); + } + public OrderingScheme map(OrderingScheme orderingScheme) { ImmutableList.Builder newSymbols = ImmutableList.builder(); @@ -542,6 +626,24 @@ public TopNNode map(TopNNode node, PlanNode source, PlanNodeId nodeId) node.getStep()); } + private record OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + private OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + } + + private record SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + private SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + } + public static Builder builder() { return new Builder(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index a9d9f52e466d..7d92beb9e1c4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -76,7 +76,10 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import 
io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -336,11 +339,18 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); TableArgumentProperties properties = node.getTableArgumentProperties().get(i); Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + PassThroughSpecification newPassThroughSpecification = new PassThroughSpecification( + properties.getPassThroughSpecification().declaredAsPassThrough(), + properties.getPassThroughSpecification().columns().stream() + .map(column -> new PassThroughColumn( + inputMapper.map(column.symbol()), + column.isPartitioningColumn())) + .collect(toImmutableList())); newTableArgumentProperties.add(new TableArgumentProperties( properties.getArgumentName(), properties.isRowSemantics(), properties.isPruneWhenEmpty(), - properties.isPassThroughColumns(), + newPassThroughSpecification, inputMapper.map(properties.getRequiredColumns()), newSpecification)); } @@ -358,6 +368,39 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext mapping); } + @Override + public PlanAndMappings visitTableFunctionProcessor(TableFunctionProcessorNode node, UnaliasContext context) + { + if (node.getSource().isEmpty()) { + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); + return new PlanAndMappings( + new TableFunctionProcessorNode( + node.getId(), + node.getName(), + mapper.map(node.getProperOutputs()), + Optional.empty(), + node.isPruneWhenEmpty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()), + mapping); + } + + PlanAndMappings rewrittenSource = node.getSource().orElseThrow().accept(this, context); + Map mapping = new 
HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + TableFunctionProcessorNode rewrittenTableFunctionProcessor = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(rewrittenTableFunctionProcessor, mapping); + } + @Override public PlanAndMappings visitTableScan(TableScanNode node, UnaliasContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index 5691480683d4..013d8fe1f79e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -69,6 +69,7 @@ @JsonSubTypes.Type(value = StatisticsWriterNode.class, name = "statisticsWriterNode"), @JsonSubTypes.Type(value = PatternRecognitionNode.class, name = "patternRecognition"), @JsonSubTypes.Type(value = TableFunctionNode.class, name = "tableFunction"), + @JsonSubTypes.Type(value = TableFunctionProcessorNode.class, name = "tableFunctionProcessor"), }) public abstract class PlanNode { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 18ea400ac26e..2b4b6bfbcfe5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -248,4 +248,9 @@ public R visitTableFunction(TableFunctionNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 3386a2280334..b3c2a9eac356 100644 --- 
a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -23,6 +23,7 @@ import javax.annotation.concurrent.Immutable; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -116,17 +117,12 @@ public List getOutputSymbols() symbols.addAll(properOutputs); - for (int i = 0; i < sources.size(); i++) { - TableArgumentProperties sourceProperties = tableArgumentProperties.get(i); - if (sourceProperties.isPassThroughColumns()) { - symbols.addAll(sources.get(i).getOutputSymbols()); - } - else { - sourceProperties.getSpecification() - .map(DataOrganizationSpecification::getPartitionBy) - .ifPresent(symbols::addAll); - } - } + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(PassThroughColumn::symbol) + .forEach(symbols::add); return symbols.build(); } @@ -149,7 +145,7 @@ public static class TableArgumentProperties private final String argumentName; private final boolean rowSemantics; private final boolean pruneWhenEmpty; - private final boolean passThroughColumns; + private final PassThroughSpecification passThroughSpecification; private final List requiredColumns; private final Optional specification; @@ -158,14 +154,14 @@ public TableArgumentProperties( @JsonProperty("argumentName") String argumentName, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, - @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { this.argumentName = requireNonNull(argumentName, "argumentName is null"); this.rowSemantics = rowSemantics; 
this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughColumns = passThroughColumns; + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } @@ -189,9 +185,9 @@ public boolean isPruneWhenEmpty() } @JsonProperty - public boolean isPassThroughColumns() + public PassThroughSpecification getPassThroughSpecification() { - return passThroughColumns; + return passThroughSpecification; } @JsonProperty @@ -206,4 +202,23 @@ public Optional getSpecification() return specification; } } + + public record PassThroughSpecification(boolean declaredAsPassThrough, List columns) + { + public PassThroughSpecification + { + columns = ImmutableList.copyOf(columns); + checkArgument( + declaredAsPassThrough || columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + } + + public record PassThroughColumn(Symbol symbol, boolean isPartitioningColumn) + { + public PassThroughColumn + { + requireNonNull(symbol, "symbol is null"); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 000000000000..e04e5450a599 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java @@ -0,0 +1,226 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.metadata.TableFunctionHandle; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorNode + extends PlanNode +{ + private final String name; + + // symbols produced by the function + private final List properOutputs; + + // pre-planned sources + private final Optional source; + // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution? 
+ + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + + // all source symbols to be produced on output, ordered as table argument specifications + private final List passThroughSpecifications; + + // symbols required from each source, ordered as table argument specifications + private final List> requiredSymbols; + + // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid + // for processing or for pass-through. null value in the marker column indicates that the value at the same + // position in the source column should not be processed or passed-through. + // the mapping is only present if there are two or more sources. + private final Optional> markerSymbols; + + // partitioning and ordering combined from sources + private final Optional specification; + private final Set prePartitioned; + private final int preSorted; + private final Optional hashSymbol; + + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("source") Optional source, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, + @JsonProperty("passThroughSpecifications") List passThroughSpecifications, + @JsonProperty("requiredSymbols") List> requiredSymbols, + @JsonProperty("markerSymbols") Optional> markerSymbols, + @JsonProperty("specification") Optional specification, + @JsonProperty("prePartitioned") Set prePartitioned, + @JsonProperty("preSorted") int preSorted, + @JsonProperty("hashSymbol") Optional hashSymbol, + @JsonProperty("handle") TableFunctionHandle handle) + { + super(id); + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); 
+ this.source = requireNonNull(source, "source is null"); + this.pruneWhenEmpty = pruneWhenEmpty; + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.requiredSymbols = requiredSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.prePartitioned = ImmutableSet.copyOf(prePartitioned); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); + this.preSorted = preSorted; + checkArgument( + specification + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .map(OrderingScheme::getOrderBy) + .map(List::size) + .orElse(0) >= preSorted, + "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); + checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public Optional getSource() + { + return source; + } + + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + + @JsonProperty + public List getPassThroughSpecifications() + { + return passThroughSpecifications; + } + + @JsonProperty + public List> getRequiredSymbols() + { + return requiredSymbols; + } + + @JsonProperty + public Optional> getMarkerSymbols() + { + return markerSymbols; + } + + @JsonProperty + public 
Optional getSpecification() + { + return specification; + } + + @JsonProperty + public Set getPrePartitioned() + { + return prePartitioned; + } + + @JsonProperty + public int getPreSorted() + { + return preSorted; + } + + @JsonProperty + public Optional getHashSymbol() + { + return hashSymbol; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return source.map(ImmutableList::of).orElse(ImmutableList.of()); + } + + @Override + public List getOutputSymbols() + { + ImmutableList.Builder symbols = ImmutableList.builder(); + + symbols.addAll(properOutputs); + + passThroughSpecifications.stream() + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::symbol) + .forEach(symbols::add); + + return symbols.build(); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitTableFunctionProcessor(this, context); + } + + @Override + public PlanNode replaceChildren(List newSources) + { + Optional newSource = newSources.isEmpty() ? 
Optional.empty() : Optional.of(getOnlyElement(newSources)); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredSymbols, markerSymbols, specification, prePartitioned, preSorted, hashSymbol, handle); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 9fe79a604a41..5bd1c626f1b6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -107,6 +107,7 @@ import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -1792,54 +1793,109 @@ public Void visitTableFunction(TableFunctionNode node, Context context) private String formatArgument(String argumentName, Argument argument, Map tableArguments) { if (argument instanceof ScalarArgument scalarArgument) { - return format( - "%s => ScalarArgument{type=%s, value=%s}", - argumentName, - scalarArgument.getType().getDisplayName(), - anonymizer.anonymize( - scalarArgument.getType(), - valuePrinter.castToVarchar(scalarArgument.getType(), scalarArgument.getValue()))); + return formatScalarArgument(argumentName, scalarArgument); } if (argument instanceof DescriptorArgument descriptorArgument) { - String descriptor; - if (descriptorArgument.equals(NULL_DESCRIPTOR)) { - descriptor = "NULL"; - } - else { - descriptor = descriptorArgument.getDescriptor().orElseThrow().getFields().stream() - .map(field -> anonymizer.anonymizeColumn(field.getName()) + field.getType().map(type -> " " + 
type.getDisplayName()).orElse("")) - .collect(joining(", ", "(", ")")); - } - return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + return formatDescriptorArgument(argumentName, descriptorArgument); } else { TableArgumentProperties argumentProperties = tableArguments.get(argumentName); - StringBuilder properties = new StringBuilder(); - if (argumentProperties.isRowSemantics()) { - properties.append("row semantics"); - } - argumentProperties.getSpecification().ifPresent(specification -> { + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + anonymizer.anonymize( + argument.getType(), + valuePrinter.castToVarchar(argument.getType(), argument.getValue()))); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow().getFields().stream() + .map(field -> anonymizer.anonymizeColumn(field.getName()) + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(joining(", ", "(", ")")); + } + return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + StringBuilder properties = new StringBuilder(); + if (argumentProperties.isRowSemantics()) { + properties.append("row semantics"); + } + argumentProperties.getSpecification().ifPresent(specification -> { + properties + .append("partition by: [") + .append(Joiner.on(", ").join(anonymize(specification.getPartitionBy()))) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { properties - .append("partition by: [") - 
.append(Joiner.on(", ").join(anonymize(specification.getPartitionBy()))) - .append("]"); - specification.getOrderingScheme().ifPresent(orderingScheme -> { - properties - .append(", order by: ") - .append(formatOrderingScheme(orderingScheme)); - }); + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); }); - properties.append("required columns: [") - .append(Joiner.on(", ").join(anonymize(argumentProperties.getRequiredColumns()))) - .append("]"); - if (argumentProperties.isPruneWhenEmpty()) { - properties.append(", prune when empty"); - } - if (argumentProperties.isPassThroughColumns()) { - properties.append(", pass through columns"); - } - return format("%s => TableArgument{%s}", argumentName, properties); + }); + properties.append("required columns: [") + .append(Joiner.on(", ").join(anonymize(argumentProperties.getRequiredColumns()))) + .append("]"); + if (argumentProperties.isPruneWhenEmpty()) { + properties.append(", prune when empty"); + } + if (argumentProperties.getPassThroughSpecification().declaredAsPassThrough()) { + properties.append(", pass through columns"); } + return format("%s => TableArgument{%s}", argumentName, properties); + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) + { + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + + descriptor.put("name", node.getName()); + + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(anonymize(node.getProperOutputs())))); + + node.getSpecification().ifPresent(specification -> { + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + .collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if 
(!prePartitioned.isEmpty()) { + builder.append(anonymize(prePartitioned).stream() + .collect(joining(", ", "<", ">"))); + if (!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(anonymize(notPrePartitioned))); + } + descriptor.put("partitionBy", format("[%s]", builder)); + } + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); + }); + + addNode(node, "TableFunctionProcessor", descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow(), context.tag()); + + return processChildren(node, new Context()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 04564844a5af..f9475b2f4fd1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -65,6 +65,9 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -249,8 +252,87 @@ public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) source.getOutputSymbols()); }); }); + Set passThroughSymbols = argumentProperties.getPassThroughSpecification().columns().stream() + .map(PassThroughColumn::symbol) + .collect(toImmutableSet()); + checkDependencies( + inputs, + 
passThroughSymbols, + "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + passThroughSymbols, + source.getOutputSymbols()); + } + + return null; + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundSymbols) + { + if (node.getSource().isEmpty()) { + return null; } + PlanNode source = node.getSource().orElseThrow(); + source.accept(this, boundSymbols); + + Set inputs = createInputs(source, boundSymbols); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(PassThroughColumn::symbol) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputSymbols()); + + Set requiredSymbols = node.getRequiredSymbols().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputSymbols()); + + node.getMarkerSymbols().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputSymbols()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputSymbols()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. 
Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputSymbols()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderBy(), + "Invalid node. Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderBy(), + source.getOutputSymbols()); + }); + }); + return null; } diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index 4febbe36409a..3a94bacd9509 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -158,6 +158,7 @@ public TableArgumentFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -234,9 +235,11 @@ public TwoTableArgumentsFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT1") + .keepWhenEmpty() .build(), TableArgumentSpecification.builder() .name("INPUT2") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -264,6 +267,7 @@ public OnlyPassThroughFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), ONLY_PASS_THROUGH); } @@ -308,6 +312,7 @@ public PolymorphicStaticReturnTypeFunction() "polymorphic_static_return_type_function", ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( ImmutableList.of("a", "b"), @@ -332,6 +337,7 @@ public PassThroughFunction() ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .passThroughColumns() + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( ImmutableList.of("x"), @@ -357,6 +363,7 @@ public DifferentArgumentTypesFunction() TableArgumentSpecification.builder() 
.name("INPUT_1") .passThroughColumns() + .keepWhenEmpty() .build(), DescriptorArgumentSpecification.builder() .name("LAYOUT") @@ -401,6 +408,7 @@ public RequiredColumnsFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java index 94eed0fe15a4..8baca4810b01 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java @@ -22,10 +22,12 @@ import io.trino.connector.TestingTableFunctions.DifferentArgumentTypesFunction; import io.trino.connector.TestingTableFunctions.TestingTableFunctionHandle; import io.trino.connector.TestingTableFunctions.TwoScalarArgumentsFunction; +import io.trino.connector.TestingTableFunctions.TwoTableArgumentsFunction; import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.ptf.Descriptor; import io.trino.spi.ptf.Descriptor.Field; import io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.LongLiteral; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -58,7 +60,8 @@ public final void setup() .withTableFunctions(ImmutableSet.of( new DifferentArgumentTypesFunction(), new TwoScalarArgumentsFunction(), - new DescriptorArgumentFunction())) + new DescriptorArgumentFunction(), + new TwoTableArgumentsFunction())) .withApplyTableFunction((session, handle) -> { if (handle instanceof TestingTableFunctionHandle functionHandle) { return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow())); @@ -89,17 +92,20 @@ public void testTableFunctionInitialPlan() "INPUT_1", tableArgument(0) 
.specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) - .passThroughColumns()) + .passThroughColumns() + .passThroughSymbols(ImmutableSet.of("c1"))) .addTableArgument( "INPUT_3", tableArgument(2) .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) - .pruneWhenEmpty()) + .pruneWhenEmpty() + .passThroughSymbols(ImmutableSet.of("c3"))) .addTableArgument( "INPUT_2", tableArgument(1) .rowSemantics() - .passThroughColumns()) + .passThroughColumns() + .passThroughSymbols(ImmutableSet.of("c2"))) .addScalarArgument("ID", 2001L) .addDescriptorArgument( "LAYOUT", @@ -113,6 +119,36 @@ public void testTableFunctionInitialPlan() anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); } + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan( + """ + SELECT * FROM TABLE(mock.system.two_table_arguments_function( + INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1, + INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 + COPARTITION (t1, t2))) t + """, + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughSymbols(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), ImmutableList.of(), ImmutableMap.of())) + .passThroughSymbols(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + .properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new GenericLiteral("SMALLINT", "1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new 
GenericLiteral("INTEGER", "2")))))))); + } + @Test public void testNullScalarArgument() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 4d52921ed152..a7c2b50384c9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -128,7 +128,7 @@ public static PlanMatchPattern any(PlanMatchPattern... sources) * Matches to any tree of nodes with children matching to given source matchers. * anyTree(tableScan("nation")) - will match to any plan which all leafs contain * any node containing table scan from nation table. - * + *

* Note: anyTree does not match zero nodes. E.g. output(anyTree(tableScan)) will NOT match TableScan node followed by OutputNode. */ public static PlanMatchPattern anyTree(PlanMatchPattern... sources) @@ -848,6 +848,20 @@ public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java index 322c9fbc1c87..88454654f413 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java @@ -24,9 +24,11 @@ import io.trino.spi.ptf.DescriptorArgument; import io.trino.spi.ptf.ScalarArgument; import io.trino.spi.ptf.TableArgument; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.tree.SymbolReference; @@ -35,10 +37,12 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static com.google.common.base.MoreObjects.toStringHelper; import static 
com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; import static io.trino.sql.planner.assertions.MatchResult.match; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; @@ -115,7 +119,7 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { } if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || - expectedTableArgument.passThroughColumns() != argumentProperties.isPassThroughColumns()) { + expectedTableArgument.passThroughColumns() != argumentProperties.getPassThroughSpecification().declaredAsPassThrough()) { return NO_MATCH; } boolean specificationMatches = expectedTableArgument.specification() @@ -124,6 +128,16 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { if (!specificationMatches) { return NO_MATCH; } + Set expectedPassThrough = expectedTableArgument.passThroughSymbols().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = argumentProperties.getPassThroughSpecification().columns().stream() + .map(PassThroughColumn::symbol) + .map(Symbol::toSymbolReference) + .collect(toImmutableSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } } } @@ -246,16 +260,24 @@ public record TableArgumentValue( boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, - Optional> specification) + Optional> specification, + Set passThroughSymbols) implements ArgumentValue { - public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification) + public TableArgumentValue( + int sourceIndex, + boolean rowSemantics, + boolean pruneWhenEmpty, + boolean 
passThroughColumns, + Optional> specification, + Set passThroughSymbols) { this.sourceIndex = sourceIndex; this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; this.specification = requireNonNull(specification, "specification is null"); + this.passThroughSymbols = ImmutableSet.copyOf(passThroughSymbols); } public static class Builder @@ -265,6 +287,7 @@ public static class Builder private boolean pruneWhenEmpty; private boolean passThroughColumns; private Optional> specification = Optional.empty(); + private Set passThroughSymbols = ImmutableSet.of(); private Builder(int sourceIndex) { @@ -301,9 +324,15 @@ public Builder specification(ExpectedValueProvider symbols) + { + this.passThroughSymbols = symbols; + return this; + } + private TableArgumentValue build() { - return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification); + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughSymbols); } } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 index 000000000000..e9279d35fbef --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,204 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.assertions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; +import io.trino.sql.tree.SymbolReference; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; +import static io.trino.sql.planner.assertions.MatchResult.match; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List properOutputs; + private final Set passThroughSymbols; + private final Optional> markerSymbols; + private final Optional> specification; + + private TableFunctionProcessorMatcher( + String name, + List properOutputs, + Set passThroughSymbols, + Optional> markerSymbols, + Optional> specification) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + 
this.passThroughSymbols = ImmutableSet.copyOf(passThroughSymbols); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = (TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + Set expectedPassThrough = passThroughSymbols.stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(PassThroughColumn::symbol) + .map(Symbol::toSymbolReference) + .collect(toImmutableSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerSymbols().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map actualMapping = tableFunctionProcessorNode.getMarkerSymbols().orElseThrow().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().toSymbolReference(), entry -> entry.getValue().toSymbolReference())); + if (!expectedMapping.equals(actualMapping)) { + return 
NO_MATCH; + } + } + + if (specification.isPresent() != tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!specification.get().getExpectedValue(symbolAliases).equals(tableFunctionProcessorNode.getSpecification().orElseThrow())) { + return NO_MATCH; + } + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), tableFunctionProcessorNode.getProperOutputs().get(i).toSymbolReference()); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .toString(); + } + + public static class Builder + { + private final Optional source; + private String name; + private List properOutputs = ImmutableList.of(); + private Set passThroughSymbols = ImmutableSet.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional> specification = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(Set passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + 
this.specification = Optional.of(specification); + return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, markerSymbols, specification)); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java new file mode 100644 index 000000000000..a3314fe90f2f --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java @@ -0,0 +1,1360 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; +import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.functionCall; +import static io.trino.sql.planner.assertions.PlanMatchPattern.join; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.assertions.PlanMatchPattern.window; +import static io.trino.sql.planner.plan.JoinNode.Type.FULL; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; + +public class TestImplementTableFunctionSource + extends BaseRuleTest +{ + @Test + public void 
testNoSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.symbol("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + true, + true, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")), + values("c"))); + + // pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + true, + true, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c")), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + 
tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(d), ImmutableMap.of(d, ASC_NULLS_LAST))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(d), ImmutableMap.of(d, ASC_NULLS_LAST))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", 
ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + FULL, + joinBuilder -> joinBuilder + .filter(""" + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper 
symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + Symbol g = p.symbol("g"); + Symbol h = p.symbol("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(h), ImmutableMap.of(h, DESC_NULLS_FIRST))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + 
.markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join nodes using helper symbols + FULL, + joinBuilder -> joinBuilder + .filter(""" + combined_row_number_1_2 = input_3_row_number OR + combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR + input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1' + """) + .left(project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes 
using helper symbols + FULL, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))) + .right(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("g", "h")))))))); + } + + @Test + public void testTwoCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + 
new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(f), ImmutableMap.of(f, DESC_NULLS_FIRST))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + 
.filter(""" + NOT (c IS DISTINCT FROM e) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + INNER, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + 
.addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))))); + + // only the left source is prune when empty, so sources are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + 
ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. 
they are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), input_2_row_number, input_1_row_number)"), + "combined_partition_size", 
expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (d IS DISTINCT FROM c) + AND ( + input_2_row_number = input_1_row_number OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))) + .right(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), 
Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + FULL, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + 
builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", 
"input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (combined_partition_column_1_2 IS DISTINCT FROM e) + AND ( + combined_row_number_1_2 = input_3_row_number OR + combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR + input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1') + """) + .left(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + 
"combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + INNER, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))) + .right(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e")))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new 
ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + Symbol g = p.symbol("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(g), ImmutableMap.of(g, DESC_NULLS_FIRST))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", 
"marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, CAST(null AS bigint))"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + "combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + LEFT, + joinBuilder -> joinBuilder + .filter(""" + combined_row_number_1_2 = combined_row_number_3_4 OR + combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR + combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1' + """) + .left(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT 
'-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + INNER, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))) + .right(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + FULL, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (e IS DISTINCT FROM f) + AND ( + input_3_row_number = input_4_row_number OR + 
input_3_row_number > input_4_partition_size AND input_4_row_number = BIGINT '1' OR + input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e"))) + .right(window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + values("f", "g"))))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + 
false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, input_1_partition_size)")), + join(// join nodes using helper symbols + INNER, + joinBuilder -> joinBuilder + .filter(""" + combined_row_number_2_3 = input_1_row_number OR + combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR + input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1' + """) + 
.left(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + LEFT, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (d IS DISTINCT FROM e) + AND ( + input_2_row_number = input_3_row_number OR + input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR + input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))) + .right(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e")))))) + .right(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + 
values("c")))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c", TINYINT); + Symbol cCoerced = p.symbol("c_coerced", INTEGER); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e", INTEGER); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + Assignments.builder() + .put(c, PlanBuilder.expression("c")) + .put(d, PlanBuilder.expression("d")) + .put(cCoerced, PlanBuilder.expression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(f), ImmutableMap.of(f, DESC_NULLS_FIRST))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", 
ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (c_coerced IS DISTINCT FROM e) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d")))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", 
ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS 
bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM e) + AND NOT (d IS DISTINCT FROM f) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = 
p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + FULL, + joinBuilder -> joinBuilder + 
.filter(""" + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 55e7c65e0d0e..718a01134b67 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -26,13 +26,16 @@ import io.trino.metadata.OutputTableHandle; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableExecuteHandle; +import io.trino.metadata.TableFunctionHandle; import io.trino.metadata.TableHandle; import io.trino.operator.RetryPolicy; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.predicate.TupleDomain; +import 
io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.type.Type; import io.trino.sql.ExpressionUtils; import io.trino.sql.analyzer.TypeSignatureProvider; @@ -87,6 +90,8 @@ import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.CreateTarget; @@ -1174,6 +1179,25 @@ public TableExecuteNode tableExecute( preferredPartitioningScheme); } + public TableFunctionNode tableFunction( + String name, + List<Symbol> properOutputs, + List<PlanNode> sources, + List<TableArgumentProperties> tableArgumentProperties, + List<List<String>> copartitioningLists) + + { + return new TableFunctionNode( + idAllocator.getNextId(), + name, + ImmutableMap.of(), + properOutputs, + sources, + tableArgumentProperties, + copartitioningLists, + new TableFunctionHandle(TEST_CATALOG_HANDLE, new SchemaFunctionName("system", name), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create())); + } + public PartitioningScheme partitioningScheme(List<Symbol> outputSymbols, List<Symbol> partitioningSymbols, Symbol hashSymbol) { return new PartitioningScheme(Partitioning.create( diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java b/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java index eeeb48517c11..1f7e9c33869e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java @@ -13,6 +13,8 @@ */ package io.trino.spi.function; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.Experimental; import java.util.Objects; @@ -25,7 +27,8
@@ public final class SchemaFunctionName private final String schemaName; private final String functionName; - public SchemaFunctionName(String schemaName, String functionName) + @JsonCreator + public SchemaFunctionName(@JsonProperty("schemaName") String schemaName, @JsonProperty("functionName") String functionName) { this.schemaName = requireNonNull(schemaName, "schemaName is null"); if (schemaName.isEmpty()) { @@ -37,11 +40,13 @@ public SchemaFunctionName(String schemaName, String functionName) } } + @JsonProperty public String getSchemaName() { return schemaName; } + @JsonProperty public String getFunctionName() { return functionName; diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java index c9a1aa16c519..bdd209d57e84 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java @@ -15,6 +15,9 @@ import io.trino.spi.Experimental; +import static io.trino.spi.ptf.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + @Experimental(eta = "2022-10-31") public class TableArgumentSpecification extends ArgumentSpecification @@ -23,10 +26,13 @@ public class TableArgumentSpecification private final boolean pruneWhenEmpty; private final boolean passThroughColumns; - private TableArgumentSpecification(String name, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns) + private TableArgumentSpecification(String name, boolean rowSemantics, Boolean pruneWhenEmpty, boolean passThroughColumns) { super(name, true, null); + requireNonNull(pruneWhenEmpty, "The pruneWhenEmpty property is not set"); + checkArgument(!rowSemantics || pruneWhenEmpty, "Cannot set the KEEP WHEN EMPTY property for a table argument with row semantics"); + this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; 
this.passThroughColumns = passThroughColumns; @@ -56,7 +62,7 @@ public static final class Builder { private String name; private boolean rowSemantics; - private boolean pruneWhenEmpty; + private Boolean pruneWhenEmpty; private boolean passThroughColumns; private Builder() {} @@ -80,6 +86,12 @@ public Builder pruneWhenEmpty() return this; } + public Builder keepWhenEmpty() + { + this.pruneWhenEmpty = false; + return this; + } + public Builder passThroughColumns() { this.passThroughColumns = true;