From 8028d3d218534322007d1729fa5a399e3806ee09 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Wed, 25 Jan 2023 13:28:54 +0100 Subject: [PATCH 1/9] Require explicit specification of KEEP / PRUNE WHEN EMPTY Change the builder for TableArgumentSpecification so that there is no default for the empty behavior. --- .../trino/connector/TestingTableFunctions.java | 8 ++++++++ .../spi/ptf/TableArgumentSpecification.java | 16 ++++++++++++++-- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index 4febbe36409a..3a94bacd9509 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -158,6 +158,7 @@ public TableArgumentFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -234,9 +235,11 @@ public TwoTableArgumentsFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT1") + .keepWhenEmpty() .build(), TableArgumentSpecification.builder() .name("INPUT2") + .keepWhenEmpty() .build()), GENERIC_TABLE); } @@ -264,6 +267,7 @@ public OnlyPassThroughFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), ONLY_PASS_THROUGH); } @@ -308,6 +312,7 @@ public PolymorphicStaticReturnTypeFunction() "polymorphic_static_return_type_function", ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( ImmutableList.of("a", "b"), @@ -332,6 +337,7 @@ public PassThroughFunction() ImmutableList.of(TableArgumentSpecification.builder() .name("INPUT") .passThroughColumns() + .keepWhenEmpty() .build()), new DescribedTable(Descriptor.descriptor( 
ImmutableList.of("x"), @@ -357,6 +363,7 @@ public DifferentArgumentTypesFunction() TableArgumentSpecification.builder() .name("INPUT_1") .passThroughColumns() + .keepWhenEmpty() .build(), DescriptorArgumentSpecification.builder() .name("LAYOUT") @@ -401,6 +408,7 @@ public RequiredColumnsFunction() ImmutableList.of( TableArgumentSpecification.builder() .name("INPUT") + .keepWhenEmpty() .build()), GENERIC_TABLE); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java index c9a1aa16c519..bdd209d57e84 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableArgumentSpecification.java @@ -15,6 +15,9 @@ import io.trino.spi.Experimental; +import static io.trino.spi.ptf.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + @Experimental(eta = "2022-10-31") public class TableArgumentSpecification extends ArgumentSpecification @@ -23,10 +26,13 @@ public class TableArgumentSpecification private final boolean pruneWhenEmpty; private final boolean passThroughColumns; - private TableArgumentSpecification(String name, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns) + private TableArgumentSpecification(String name, boolean rowSemantics, Boolean pruneWhenEmpty, boolean passThroughColumns) { super(name, true, null); + requireNonNull(pruneWhenEmpty, "The pruneWhenEmpty property is not set"); + checkArgument(!rowSemantics || pruneWhenEmpty, "Cannot set the KEEP WHEN EMPTY property for a table argument with row semantics"); + this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; @@ -56,7 +62,7 @@ public static final class Builder { private String name; private boolean rowSemantics; - private boolean pruneWhenEmpty; + private Boolean pruneWhenEmpty; private boolean 
passThroughColumns; private Builder() {} @@ -80,6 +86,12 @@ public Builder pruneWhenEmpty() return this; } + public Builder keepWhenEmpty() + { + this.pruneWhenEmpty = false; + return this; + } + public Builder passThroughColumns() { this.passThroughColumns = true; From 05535286323470132dd5df0e88ee1e6aa49f4c34 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Fri, 7 Oct 2022 11:20:05 +0200 Subject: [PATCH 2/9] Pass uncoerced partitioning columns from table function source for output If a table function has multiple partitioned inputs, they might be co-partitioned. Co-partitioning might require coercing of the corresponding partitioning columns to common supertype. Additionally, each partitioning source column should be produced on table function's output in its original (uncoerced) form. Before this change, the coerced columns were incorrectly passed to output. After this change, the uncoerced columns are preserved for output independently from the Specification, which is planned with regard to the required co-partitioning coercions. 
--- .../io/trino/sql/planner/RelationPlanner.java | 26 ++++++----- .../UnaliasSymbolReferences.java | 1 + .../sql/planner/plan/TableFunctionNode.java | 23 +++++----- .../sanity/ValidateDependenciesChecker.java | 7 +++ .../planner/TestTableFunctionInvocation.java | 44 +++++++++++++++++-- .../assertions/TableFunctionMatcher.java | 33 ++++++++++++-- 6 files changed, 106 insertions(+), 28 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 756841691a25..62736da52a84 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -428,27 +428,33 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node specification = Optional.of(new DataOrganizationSpecification(partitionBy, orderBy)); } - sources.add(sourcePlanBuilder.getRoot()); - sourceProperties.add(new TableArgumentProperties( - tableArgument.getArgumentName(), - tableArgument.isRowSemantics(), - tableArgument.isPruneWhenEmpty(), - tableArgument.isPassThroughColumns(), - requiredColumns, - specification)); - // add output symbols passed from the table argument + ImmutableList.Builder passThroughSymbols = ImmutableList.builder(); if (tableArgument.isPassThroughColumns()) { // the original output symbols from the source node, not coerced // note: hidden columns are included. 
They are present in sourcePlan.fieldMappings outputSymbols.addAll(sourcePlan.getFieldMappings()); + passThroughSymbols.addAll(sourcePlan.getFieldMappings()); } else if (tableArgument.getPartitionBy().isPresent()) { tableArgument.getPartitionBy().get().stream() // the original symbols for partitioning columns, not coerced .map(sourcePlanBuilder::translate) - .forEach(outputSymbols::add); + .forEach(symbol -> { + outputSymbols.add(symbol); + passThroughSymbols.add(symbol); + }); } + + sources.add(sourcePlanBuilder.getRoot()); + sourceProperties.add(new TableArgumentProperties( + tableArgument.getArgumentName(), + tableArgument.isRowSemantics(), + tableArgument.isPruneWhenEmpty(), + tableArgument.isPassThroughColumns(), + passThroughSymbols.build(), + requiredColumns, + specification)); } PlanNode root = new TableFunctionNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index a9d9f52e466d..702c47bb1d2a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -341,6 +341,7 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext properties.isRowSemantics(), properties.isPruneWhenEmpty(), properties.isPassThroughColumns(), + inputMapper.map(properties.getPassThroughSymbols()), inputMapper.map(properties.getRequiredColumns()), newSpecification)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 3386a2280334..05f949bd799d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -116,17 
+116,9 @@ public List getOutputSymbols() symbols.addAll(properOutputs); - for (int i = 0; i < sources.size(); i++) { - TableArgumentProperties sourceProperties = tableArgumentProperties.get(i); - if (sourceProperties.isPassThroughColumns()) { - symbols.addAll(sources.get(i).getOutputSymbols()); - } - else { - sourceProperties.getSpecification() - .map(DataOrganizationSpecification::getPartitionBy) - .ifPresent(symbols::addAll); - } - } + tableArgumentProperties.stream() + .map(TableArgumentProperties::getPassThroughSymbols) + .forEach(symbols::addAll); return symbols.build(); } @@ -150,6 +142,7 @@ public static class TableArgumentProperties private final boolean rowSemantics; private final boolean pruneWhenEmpty; private final boolean passThroughColumns; + private final List passThroughSymbols; private final List requiredColumns; private final Optional specification; @@ -159,6 +152,7 @@ public TableArgumentProperties( @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("passThroughSymbols") List passThroughSymbols, @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { @@ -166,6 +160,7 @@ public TableArgumentProperties( this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; + this.passThroughSymbols = ImmutableList.copyOf(passThroughSymbols); this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } @@ -194,6 +189,12 @@ public boolean isPassThroughColumns() return passThroughColumns; } + @JsonProperty + public List getPassThroughSymbols() + { + return passThroughSymbols; + } + @JsonProperty public List getRequiredColumns() { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java 
b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 04564844a5af..122091ac1350 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -249,6 +249,13 @@ public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) source.getOutputSymbols()); }); }); + checkDependencies( + inputs, + argumentProperties.getPassThroughSymbols(), + "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", + argumentProperties.getArgumentName(), + argumentProperties.getPassThroughSymbols(), + source.getOutputSymbols()); } return null; diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java index 94eed0fe15a4..8baca4810b01 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestTableFunctionInvocation.java @@ -22,10 +22,12 @@ import io.trino.connector.TestingTableFunctions.DifferentArgumentTypesFunction; import io.trino.connector.TestingTableFunctions.TestingTableFunctionHandle; import io.trino.connector.TestingTableFunctions.TwoScalarArgumentsFunction; +import io.trino.connector.TestingTableFunctions.TwoTableArgumentsFunction; import io.trino.spi.connector.TableFunctionApplicationResult; import io.trino.spi.ptf.Descriptor; import io.trino.spi.ptf.Descriptor.Field; import io.trino.sql.planner.assertions.BasePlanTest; +import io.trino.sql.tree.GenericLiteral; import io.trino.sql.tree.LongLiteral; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; @@ -58,7 +60,8 @@ public final void setup() .withTableFunctions(ImmutableSet.of( new DifferentArgumentTypesFunction(), new TwoScalarArgumentsFunction(), - new 
DescriptorArgumentFunction())) + new DescriptorArgumentFunction(), + new TwoTableArgumentsFunction())) .withApplyTableFunction((session, handle) -> { if (handle instanceof TestingTableFunctionHandle functionHandle) { return Optional.of(new TableFunctionApplicationResult<>(functionHandle.getTableHandle(), functionHandle.getTableHandle().getColumns().orElseThrow())); @@ -89,17 +92,20 @@ public void testTableFunctionInitialPlan() "INPUT_1", tableArgument(0) .specification(specification(ImmutableList.of("c1"), ImmutableList.of("c1"), ImmutableMap.of("c1", ASC_NULLS_LAST))) - .passThroughColumns()) + .passThroughColumns() + .passThroughSymbols(ImmutableSet.of("c1"))) .addTableArgument( "INPUT_3", tableArgument(2) .specification(specification(ImmutableList.of("c3"), ImmutableList.of(), ImmutableMap.of())) - .pruneWhenEmpty()) + .pruneWhenEmpty() + .passThroughSymbols(ImmutableSet.of("c3"))) .addTableArgument( "INPUT_2", tableArgument(1) .rowSemantics() - .passThroughColumns()) + .passThroughColumns() + .passThroughSymbols(ImmutableSet.of("c2"))) .addScalarArgument("ID", 2001L) .addDescriptorArgument( "LAYOUT", @@ -113,6 +119,36 @@ public void testTableFunctionInitialPlan() anyTree(project(ImmutableMap.of("c3", expression("'b'")), values(1)))))); } + @Test + public void testTableFunctionInitialPlanWithCoercionForCopartitioning() + { + assertPlan( + """ + SELECT * FROM TABLE(mock.system.two_table_arguments_function( + INPUT1 => TABLE(VALUES SMALLINT '1') t1(c1) PARTITION BY c1, + INPUT2 => TABLE(VALUES INTEGER '2') t2(c2) PARTITION BY c2 + COPARTITION (t1, t2))) t + """, + CREATED, + anyTree(tableFunction(builder -> builder + .name("two_table_arguments_function") + .addTableArgument( + "INPUT1", + tableArgument(0) + .specification(specification(ImmutableList.of("c1_coerced"), ImmutableList.of(), ImmutableMap.of())) + .passThroughSymbols(ImmutableSet.of("c1"))) + .addTableArgument( + "INPUT2", + tableArgument(1) + .specification(specification(ImmutableList.of("c2"), 
ImmutableList.of(), ImmutableMap.of())) + .passThroughSymbols(ImmutableSet.of("c2"))) + .addCopartitioning(ImmutableList.of("INPUT1", "INPUT2")) + .properOutputs(ImmutableList.of("COLUMN")), + project(ImmutableMap.of("c1_coerced", expression("CAST(c1 AS INTEGER)")), + anyTree(values(ImmutableList.of("c1"), ImmutableList.of(ImmutableList.of(new GenericLiteral("SMALLINT", "1")))))), + anyTree(values(ImmutableList.of("c2"), ImmutableList.of(ImmutableList.of(new GenericLiteral("INTEGER", "2")))))))); + } + @Test public void testNullScalarArgument() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java index 322c9fbc1c87..932209e4ff0b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java @@ -24,6 +24,7 @@ import io.trino.spi.ptf.DescriptorArgument; import io.trino.spi.ptf.ScalarArgument; import io.trino.spi.ptf.TableArgument; +import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.TableFunctionNode; @@ -35,10 +36,12 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; import static io.trino.sql.planner.assertions.MatchResult.match; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; @@ -124,6 +127,15 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { if 
(!specificationMatches) { return NO_MATCH; } + Set expectedPassThrough = expectedTableArgument.passThroughSymbols().stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = argumentProperties.getPassThroughSymbols().stream() + .map(Symbol::toSymbolReference) + .collect(toImmutableSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } } } @@ -246,16 +258,24 @@ public record TableArgumentValue( boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, - Optional> specification) + Optional> specification, + Set passThroughSymbols) implements ArgumentValue { - public TableArgumentValue(int sourceIndex, boolean rowSemantics, boolean pruneWhenEmpty, boolean passThroughColumns, Optional> specification) + public TableArgumentValue( + int sourceIndex, + boolean rowSemantics, + boolean pruneWhenEmpty, + boolean passThroughColumns, + Optional> specification, + Set passThroughSymbols) { this.sourceIndex = sourceIndex; this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; this.specification = requireNonNull(specification, "specification is null"); + this.passThroughSymbols = ImmutableSet.copyOf(passThroughSymbols); } public static class Builder @@ -265,6 +285,7 @@ public static class Builder private boolean pruneWhenEmpty; private boolean passThroughColumns; private Optional> specification = Optional.empty(); + private Set passThroughSymbols = ImmutableSet.of(); private Builder(int sourceIndex) { @@ -301,9 +322,15 @@ public Builder specification(ExpectedValueProvider symbols) + { + this.passThroughSymbols = symbols; + return this; + } + private TableArgumentValue build() { - return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification); + return new TableArgumentValue(sourceIndex, rowSemantics, pruneWhenEmpty, passThroughColumns, specification, passThroughSymbols); } } } From 
cf0bd23dc457ba0cf6483e086cc342b071c32efb Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Sun, 25 Dec 2022 16:08:32 +0100 Subject: [PATCH 3/9] Add table function's name to TableFunctionHandle Adds the SchemaFunctionName, which can act as a unique identifier of the function for a given catalog. --- .../java/io/trino/metadata/TableFunctionHandle.java | 10 ++++++++++ .../src/main/java/io/trino/sql/analyzer/Analysis.java | 8 ++++++++ .../java/io/trino/sql/analyzer/StatementAnalyzer.java | 1 + .../java/io/trino/sql/planner/RelationPlanner.java | 7 ++++++- .../java/io/trino/spi/function/SchemaFunctionName.java | 7 ++++++- 5 files changed, 31 insertions(+), 2 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java index 2bc0b3fccb65..6d2885f7d69f 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java +++ b/core/trino-main/src/main/java/io/trino/metadata/TableFunctionHandle.java @@ -17,6 +17,7 @@ import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ConnectorTransactionHandle; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.ptf.ConnectorTableFunctionHandle; import static java.util.Objects.requireNonNull; @@ -24,16 +25,19 @@ public class TableFunctionHandle { private final CatalogHandle catalogHandle; + private final SchemaFunctionName schemaFunctionName; private final ConnectorTableFunctionHandle functionHandle; private final ConnectorTransactionHandle transactionHandle; @JsonCreator public TableFunctionHandle( @JsonProperty("catalogHandle") CatalogHandle catalogHandle, + @JsonProperty("schemaFunctionName") SchemaFunctionName schemaFunctionName, @JsonProperty("functionHandle") ConnectorTableFunctionHandle functionHandle, @JsonProperty("transactionHandle") ConnectorTransactionHandle 
transactionHandle) { this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + this.schemaFunctionName = requireNonNull(schemaFunctionName, "schemaFunctionName is null"); this.functionHandle = requireNonNull(functionHandle, "functionHandle is null"); this.transactionHandle = requireNonNull(transactionHandle, "transactionHandle is null"); } @@ -44,6 +48,12 @@ public CatalogHandle getCatalogHandle() return catalogHandle; } + @JsonProperty + public SchemaFunctionName getSchemaFunctionName() + { + return schemaFunctionName; + } + @JsonProperty public ConnectorTableFunctionHandle getFunctionHandle() { diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index bc5bfffcda39..8f47959127fc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -2230,6 +2230,7 @@ public TableArgumentAnalysis build() public static class TableFunctionInvocationAnalysis { private final CatalogHandle catalogHandle; + private final String schemaName; private final String functionName; private final Map arguments; private final List tableArgumentAnalyses; @@ -2241,6 +2242,7 @@ public static class TableFunctionInvocationAnalysis public TableFunctionInvocationAnalysis( CatalogHandle catalogHandle, + String schemaName, String functionName, Map arguments, List tableArgumentAnalyses, @@ -2251,6 +2253,7 @@ public TableFunctionInvocationAnalysis( ConnectorTransactionHandle transactionHandle) { this.catalogHandle = requireNonNull(catalogHandle, "catalogHandle is null"); + this.schemaName = requireNonNull(schemaName, "schemaName is null"); this.functionName = requireNonNull(functionName, "functionName is null"); this.arguments = ImmutableMap.copyOf(arguments); this.tableArgumentAnalyses = ImmutableList.copyOf(tableArgumentAnalyses); @@ -2267,6 +2270,11 @@ public CatalogHandle getCatalogHandle() 
return catalogHandle; } + public String getSchemaName() + { + return schemaName; + } + public String getFunctionName() { return functionName; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index cf5a99c76fac..6cb9cce5f858 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -1671,6 +1671,7 @@ else if (argument.getPartitionBy().isPresent()) { analysis.setTableFunctionAnalysis(node, new TableFunctionInvocationAnalysis( catalogHandle, + function.getSchema(), function.getName(), argumentsAnalysis.getPassedArguments(), orderedTableArguments.build(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index 62736da52a84..fedee402a9f7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -22,6 +22,7 @@ import io.trino.metadata.TableFunctionHandle; import io.trino.metadata.TableHandle; import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.sql.ExpressionUtils; @@ -465,7 +466,11 @@ else if (tableArgument.getPartitionBy().isPresent()) { sources.build(), sourceProperties.build(), functionAnalysis.getCopartitioningLists(), - new TableFunctionHandle(functionAnalysis.getCatalogHandle(), functionAnalysis.getConnectorTableFunctionHandle(), functionAnalysis.getTransactionHandle())); + new TableFunctionHandle( + functionAnalysis.getCatalogHandle(), + new SchemaFunctionName(functionAnalysis.getSchemaName(), functionAnalysis.getFunctionName()), + functionAnalysis.getConnectorTableFunctionHandle(), + 
functionAnalysis.getTransactionHandle())); return new RelationPlan(root, analysis.getScope(node), outputSymbols.build(), outerContext); } diff --git a/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java b/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java index eeeb48517c11..1f7e9c33869e 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java +++ b/core/trino-spi/src/main/java/io/trino/spi/function/SchemaFunctionName.java @@ -13,6 +13,8 @@ */ package io.trino.spi.function; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; import io.trino.spi.Experimental; import java.util.Objects; @@ -25,7 +27,8 @@ public final class SchemaFunctionName private final String schemaName; private final String functionName; - public SchemaFunctionName(String schemaName, String functionName) + @JsonCreator + public SchemaFunctionName(@JsonProperty("schemaName") String schemaName, @JsonProperty("functionName") String functionName) { this.schemaName = requireNonNull(schemaName, "schemaName is null"); if (schemaName.isEmpty()) { @@ -37,11 +40,13 @@ public SchemaFunctionName(String schemaName, String functionName) } } + @JsonProperty public String getSchemaName() { return schemaName; } + @JsonProperty public String getFunctionName() { return functionName; From 795f46d93e3039c4872cfd2be842aeaa306f39f8 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 20 Dec 2022 10:30:35 +0100 Subject: [PATCH 4/9] Pass all necessary pass-through information the necessary pass-through information for a table function's source includes: - whether the source was declared as pass-through - an ordered list of pass-through columns - for each column, information whether it is a partitioning column --- .../io/trino/sql/planner/RelationPlanner.java | 17 +++++-- .../UnaliasSymbolReferences.java | 12 ++++- 
.../sql/planner/plan/TableFunctionNode.java | 46 ++++++++++++------- .../sql/planner/planprinter/PlanPrinter.java | 2 +- .../sanity/ValidateDependenciesChecker.java | 8 +++- .../assertions/TableFunctionMatcher.java | 6 ++- 6 files changed, 63 insertions(+), 28 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index fedee402a9f7..1886c6679228 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -48,6 +48,8 @@ import io.trino.sql.planner.plan.ProjectNode; import io.trino.sql.planner.plan.SampleNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.UnionNode; @@ -430,12 +432,18 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node } // add output symbols passed from the table argument - ImmutableList.Builder passThroughSymbols = ImmutableList.builder(); + ImmutableList.Builder passThroughColumns = ImmutableList.builder(); if (tableArgument.isPassThroughColumns()) { // the original output symbols from the source node, not coerced // note: hidden columns are included. 
They are present in sourcePlan.fieldMappings outputSymbols.addAll(sourcePlan.getFieldMappings()); - passThroughSymbols.addAll(sourcePlan.getFieldMappings()); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + sourcePlan.getFieldMappings().stream() + .map(symbol -> new PassThroughColumn(symbol, partitionBy.contains(symbol))) + .forEach(passThroughColumns::add); } else if (tableArgument.getPartitionBy().isPresent()) { tableArgument.getPartitionBy().get().stream() @@ -443,7 +451,7 @@ else if (tableArgument.getPartitionBy().isPresent()) { .map(sourcePlanBuilder::translate) .forEach(symbol -> { outputSymbols.add(symbol); - passThroughSymbols.add(symbol); + passThroughColumns.add(new PassThroughColumn(symbol, true)); }); } @@ -452,8 +460,7 @@ else if (tableArgument.getPartitionBy().isPresent()) { tableArgument.getArgumentName(), tableArgument.isRowSemantics(), tableArgument.isPruneWhenEmpty(), - tableArgument.isPassThroughColumns(), - passThroughSymbols.build(), + new PassThroughSpecification(tableArgument.isPassThroughColumns(), passThroughColumns.build()), requiredColumns, specification)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 702c47bb1d2a..08e7c4b150e7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -76,6 +76,8 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import 
io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; @@ -336,12 +338,18 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); TableArgumentProperties properties = node.getTableArgumentProperties().get(i); Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); + PassThroughSpecification newPassThroughSpecification = new PassThroughSpecification( + properties.getPassThroughSpecification().declaredAsPassThrough(), + properties.getPassThroughSpecification().columns().stream() + .map(column -> new PassThroughColumn( + inputMapper.map(column.symbol()), + column.isPartitioningColumn())) + .collect(toImmutableList())); newTableArgumentProperties.add(new TableArgumentProperties( properties.getArgumentName(), properties.isRowSemantics(), properties.isPruneWhenEmpty(), - properties.isPassThroughColumns(), - inputMapper.map(properties.getPassThroughSymbols()), + newPassThroughSpecification, inputMapper.map(properties.getRequiredColumns()), newSpecification)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 05f949bd799d..b3c2a9eac356 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -23,6 +23,7 @@ import javax.annotation.concurrent.Immutable; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Optional; @@ -117,8 +118,11 @@ public List getOutputSymbols() symbols.addAll(properOutputs); tableArgumentProperties.stream() - .map(TableArgumentProperties::getPassThroughSymbols) - .forEach(symbols::addAll); + 
.map(TableArgumentProperties::getPassThroughSpecification) + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(PassThroughColumn::symbol) + .forEach(symbols::add); return symbols.build(); } @@ -141,8 +145,7 @@ public static class TableArgumentProperties private final String argumentName; private final boolean rowSemantics; private final boolean pruneWhenEmpty; - private final boolean passThroughColumns; - private final List passThroughSymbols; + private final PassThroughSpecification passThroughSpecification; private final List requiredColumns; private final Optional specification; @@ -151,16 +154,14 @@ public TableArgumentProperties( @JsonProperty("argumentName") String argumentName, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, - @JsonProperty("passThroughColumns") boolean passThroughColumns, - @JsonProperty("passThroughSymbols") List passThroughSymbols, + @JsonProperty("passThroughSpecification") PassThroughSpecification passThroughSpecification, @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { this.argumentName = requireNonNull(argumentName, "argumentName is null"); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; - this.passThroughColumns = passThroughColumns; - this.passThroughSymbols = ImmutableList.copyOf(passThroughSymbols); + this.passThroughSpecification = requireNonNull(passThroughSpecification, "passThroughSpecification is null"); this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } @@ -184,15 +185,9 @@ public boolean isPruneWhenEmpty() } @JsonProperty - public boolean isPassThroughColumns() + public PassThroughSpecification getPassThroughSpecification() { - return passThroughColumns; - } - - @JsonProperty - public List getPassThroughSymbols() - { - return passThroughSymbols; + return 
passThroughSpecification; } @JsonProperty @@ -207,4 +202,23 @@ public Optional getSpecification() return specification; } } + + public record PassThroughSpecification(boolean declaredAsPassThrough, List columns) + { + public PassThroughSpecification + { + columns = ImmutableList.copyOf(columns); + checkArgument( + declaredAsPassThrough || columns.stream().allMatch(PassThroughColumn::isPartitioningColumn), + "non-partitioning pass-through column for non-pass-through source of a table function"); + } + } + + public record PassThroughColumn(Symbol symbol, boolean isPartitioningColumn) + { + public PassThroughColumn + { + requireNonNull(symbol, "symbol is null"); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 9fe79a604a41..127a62718f5f 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -1835,7 +1835,7 @@ private String formatArgument(String argumentName, Argument argument, Map TableArgument{%s}", argumentName, properties); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 122091ac1350..7f1d9565eb64 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -65,6 +65,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import 
io.trino.sql.planner.plan.TopNNode; @@ -249,12 +250,15 @@ public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) source.getOutputSymbols()); }); }); + Set passThroughSymbols = argumentProperties.getPassThroughSpecification().columns().stream() + .map(PassThroughColumn::symbol) + .collect(toImmutableSet()); checkDependencies( inputs, - argumentProperties.getPassThroughSymbols(), + passThroughSymbols, "Invalid node. Pass-through symbols for source %s (%s) not in source plan output (%s)", argumentProperties.getArgumentName(), - argumentProperties.getPassThroughSymbols(), + passThroughSymbols, source.getOutputSymbols()); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java index 932209e4ff0b..88454654f413 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionMatcher.java @@ -28,6 +28,7 @@ import io.trino.sql.planner.plan.DataOrganizationSpecification; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.tree.SymbolReference; @@ -118,7 +119,7 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { } if (expectedTableArgument.rowSemantics() != argumentProperties.isRowSemantics() || expectedTableArgument.pruneWhenEmpty() != argumentProperties.isPruneWhenEmpty() || - expectedTableArgument.passThroughColumns() != argumentProperties.isPassThroughColumns()) { + expectedTableArgument.passThroughColumns() != argumentProperties.getPassThroughSpecification().declaredAsPassThrough()) { return NO_MATCH; } boolean specificationMatches = expectedTableArgument.specification() @@ -130,7 
+131,8 @@ else if (expected instanceof ScalarArgumentValue expectedScalar) { Set expectedPassThrough = expectedTableArgument.passThroughSymbols().stream() .map(symbolAliases::get) .collect(toImmutableSet()); - Set actualPassThrough = argumentProperties.getPassThroughSymbols().stream() + Set actualPassThrough = argumentProperties.getPassThroughSpecification().columns().stream() + .map(PassThroughColumn::symbol) .map(Symbol::toSymbolReference) .collect(toImmutableSet()); if (!expectedPassThrough.equals(actualPassThrough)) { From af620aaf426745e47f4a4d458cbc6b71d2684ddb Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Mon, 28 Nov 2022 15:07:54 +0100 Subject: [PATCH 5/9] Fix handling of pre-sorted prefix in SymbolMapper When the ordering scheme is mapped, some symbols might be removed due to de-duplication. If de-duplication happens within the pre-sorted prefix, the length of the prefix should be updated. This causes no issues currently, since UnaliasSymbolReferences, and so the SymbolMapper, is never called after AddLocalExchanges, where the pre-sorted symbols are determined. 
--- .../planner/optimizations/SymbolMapper.java | 66 +++++++++++++++++-- 1 file changed, 62 insertions(+), 4 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 46a823049e9c..c6b46622e1da 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -56,6 +56,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.Set; import java.util.function.Function; @@ -214,16 +215,18 @@ public WindowNode map(WindowNode node, PlanNode source) newFunctions.put(map(symbol), new WindowNode.Function(function.getResolvedFunction(), newArguments, newFrame, function.isIgnoreNulls())); }); + SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix()); + return new WindowNode( node.getId(), source, - mapAndDistinct(node.getSpecification()), + newSpecification.specification(), newFunctions.buildOrThrow(), node.getHashSymbol().map(this::map), node.getPrePartitionedInputs().stream() .map(this::map) .collect(toImmutableSet()), - node.getPreSortedOrderPrefix()); + newSpecification.preSorted()); } private WindowNode.Frame map(WindowNode.Frame frame) @@ -240,6 +243,18 @@ private WindowNode.Frame map(WindowNode.Frame frame) frame.getOriginalEndValue()); } + private SpecificationWithPreSortedPrefix mapAndDistinct(DataOrganizationSpecification specification, int preSorted) + { + Optional newOrderingScheme = specification.getOrderingScheme() + .map(orderingScheme -> map(orderingScheme, preSorted)); + + return new SpecificationWithPreSortedPrefix( + new DataOrganizationSpecification( + mapAndDistinct(specification.getPartitionBy()), + 
newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::orderingScheme)), + newOrderingScheme.map(OrderingSchemeWithPreSortedPrefix::preSorted).orElse(preSorted)); + } + public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecification specification) { return new DataOrganizationSpecification( @@ -249,6 +264,8 @@ public DataOrganizationSpecification mapAndDistinct(DataOrganizationSpecificatio public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) { + SpecificationWithPreSortedPrefix newSpecification = mapAndDistinct(node.getSpecification(), node.getPreSortedOrderPrefix()); + ImmutableMap.Builder newFunctions = ImmutableMap.builder(); node.getWindowFunctions().forEach((symbol, function) -> { List newArguments = function.getArguments().stream() @@ -271,12 +288,12 @@ public PatternRecognitionNode map(PatternRecognitionNode node, PlanNode source) return new PatternRecognitionNode( node.getId(), source, - mapAndDistinct(node.getSpecification()), + newSpecification.specification(), node.getHashSymbol().map(this::map), node.getPrePartitionedInputs().stream() .map(this::map) .collect(toImmutableSet()), - node.getPreSortedOrderPrefix(), + newSpecification.preSorted(), newFunctions.buildOrThrow(), newMeasures.buildOrThrow(), node.getCommonBaseFrame().map(this::map), @@ -352,6 +369,29 @@ public LimitNode map(LimitNode node, PlanNode source) .collect(toImmutableList())); } + public OrderingSchemeWithPreSortedPrefix map(OrderingScheme orderingScheme, int preSorted) + { + ImmutableList.Builder newSymbols = ImmutableList.builder(); + ImmutableMap.Builder newOrderings = ImmutableMap.builder(); + int newPreSorted = preSorted; + + Set added = new HashSet<>(orderingScheme.getOrderBy().size()); + + for (int i = 0; i < orderingScheme.getOrderBy().size(); i++) { + Symbol symbol = orderingScheme.getOrderBy().get(i); + Symbol canonical = map(symbol); + if (added.add(canonical)) { + newSymbols.add(canonical); + newOrderings.put(canonical, 
orderingScheme.getOrdering(symbol)); + } + else if (i < preSorted) { + newPreSorted--; + } + } + + return new OrderingSchemeWithPreSortedPrefix(new OrderingScheme(newSymbols.build(), newOrderings.buildOrThrow()), newPreSorted); + } + public OrderingScheme map(OrderingScheme orderingScheme) { ImmutableList.Builder newSymbols = ImmutableList.builder(); @@ -542,6 +582,24 @@ public TopNNode map(TopNNode node, PlanNode source, PlanNodeId nodeId) node.getStep()); } + private record OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + private OrderingSchemeWithPreSortedPrefix(OrderingScheme orderingScheme, int preSorted) + { + this.orderingScheme = requireNonNull(orderingScheme, "orderingScheme is null"); + this.preSorted = preSorted; + } + } + + private record SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + private SpecificationWithPreSortedPrefix(DataOrganizationSpecification specification, int preSorted) + { + this.specification = requireNonNull(specification, "specification is null"); + this.preSorted = preSorted; + } + } + public static Builder builder() { return new Builder(); From 96be405333dd55ccd789aa5133799c492cf06dcf Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 4 Oct 2022 14:18:21 +0200 Subject: [PATCH 6/9] Introduce a PlanNode for table function invocation with prepared source --- .../sql/planner/LocalExecutionPlanner.java | 7 + .../planner/optimizations/AddExchanges.java | 7 + .../planner/optimizations/SymbolMapper.java | 36 ++++ .../UnaliasSymbolReferences.java | 13 ++ .../io/trino/sql/planner/plan/PlanNode.java | 1 + .../trino/sql/planner/plan/PlanVisitor.java | 5 + .../plan/TableFunctionProcessorNode.java | 169 ++++++++++++++++++ .../sql/planner/planprinter/PlanPrinter.java | 115 +++++++----- .../sanity/ValidateDependenciesChecker.java | 67 +++++++ 9 files changed, 380 insertions(+), 40 deletions(-) create mode 100644 
core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java index 2b7adc365b8a..b440835f90ea 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LocalExecutionPlanner.java @@ -221,6 +221,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.MergeTarget; @@ -1658,6 +1659,12 @@ public PhysicalOperation visitTableFunction(TableFunctionNode node, LocalExecuti throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); } + @Override + public PhysicalOperation visitTableFunctionProcessor(TableFunctionProcessorNode node, LocalExecutionPlanContext context) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + @Override public PhysicalOperation visitTopN(TopNNode node, LocalExecutionPlanContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 667ab168b37c..9afa4735e3e3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -74,6 +74,7 @@ import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; import 
io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -372,6 +373,12 @@ public PlanWithProperties visitTableFunction(TableFunctionNode node, PreferredPr throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); } + @Override + public PlanWithProperties visitTableFunctionProcessor(TableFunctionProcessorNode node, PreferredProperties preferredProperties) + { + throw new UnsupportedOperationException("execution by operator is not yet implemented for table function " + node.getName()); + } + @Override public PlanWithProperties visitRowNumber(RowNumberNode node, PreferredProperties preferredProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index c6b46622e1da..c862a41a1a79 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -37,6 +37,9 @@ import io.trino.sql.planner.plan.StatisticsWriterNode; import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; import io.trino.sql.planner.plan.TopNRankingNode; @@ -65,6 +68,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.trino.sql.planner.plan.AggregationNode.groupingSets; import static 
java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; public class SymbolMapper { @@ -356,6 +360,38 @@ public Expression rewriteSymbolReference(SymbolReference node, Void context, Exp expressionAndValuePointers.getMatchNumberSymbols()); } + public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode source) + { + List newPassThroughSpecifications = node.getPassThroughSpecifications().stream() + .map(specification -> new PassThroughSpecification( + specification.declaredAsPassThrough(), + specification.columns().stream() + .map(column -> new PassThroughColumn( + map(column.symbol()), + column.isPartitioningColumn())) + .collect(toImmutableList()))) + .collect(toImmutableList()); + + List> newRequiredSymbols = node.getRequiredSymbols().stream() + .map(this::map) + .collect(toImmutableList()); + + Optional> newMarkerSymbols = node.getMarkerSymbols() + .map(mapping -> mapping.entrySet().stream() + .collect(toMap(entry -> map(entry.getKey()), entry -> map(entry.getValue())))); + + return new TableFunctionProcessorNode( + node.getId(), + node.getName(), + map(node.getProperOutputs()), + source, + newPassThroughSpecifications, + newRequiredSymbols, + newMarkerSymbols, + node.getSpecification().map(this::mapAndDistinct), // TODO handle pre-partitioned, pre-sorted properly when we add them + node.getHandle()); + } + public LimitNode map(LimitNode node, PlanNode source) { return new LimitNode( diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 08e7c4b150e7..0e76e19a7aa5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -79,6 +79,7 @@ import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; import 
io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -367,6 +368,18 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext mapping); } + @Override + public PlanAndMappings visitTableFunctionProcessor(TableFunctionProcessorNode node, UnaliasContext context) + { + PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + Map mapping = new HashMap<>(rewrittenSource.getMappings()); + SymbolMapper mapper = symbolMapper(mapping); + + TableFunctionProcessorNode rewrittenTableFunctionProcessor = mapper.map(node, rewrittenSource.getRoot()); + + return new PlanAndMappings(rewrittenTableFunctionProcessor, mapping); + } + @Override public PlanAndMappings visitTableScan(TableScanNode node, UnaliasContext context) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java index 5691480683d4..013d8fe1f79e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanNode.java @@ -69,6 +69,7 @@ @JsonSubTypes.Type(value = StatisticsWriterNode.class, name = "statisticsWriterNode"), @JsonSubTypes.Type(value = PatternRecognitionNode.class, name = "patternRecognition"), @JsonSubTypes.Type(value = TableFunctionNode.class, name = "tableFunction"), + @JsonSubTypes.Type(value = TableFunctionProcessorNode.class, name = "tableFunctionProcessor"), }) public abstract class PlanNode { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java index 
18ea400ac26e..2b4b6bfbcfe5 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/PlanVisitor.java @@ -248,4 +248,9 @@ public R visitTableFunction(TableFunctionNode node, C context) { return visitPlan(node, context); } + + public R visitTableFunctionProcessor(TableFunctionProcessorNode node, C context) + { + return visitPlan(node, context); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java new file mode 100644 index 000000000000..802730a47849 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java @@ -0,0 +1,169 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.plan; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.metadata.TableFunctionHandle; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorNode + extends PlanNode +{ + private final String name; + + // symbols produced by the function + private final List properOutputs; + + // pre-planned sources + private final PlanNode source; + // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution? + + // all source symbols to be produced on output, ordered as table argument specifications + private final List passThroughSpecifications; + + // symbols required from each source, ordered as table argument specifications + private final List> requiredSymbols; + + // mapping from source symbol to helper "marker" symbol which indicates whether the source value is valid + // for processing or for pass-through. null value in the marker column indicates that the value at the same + // position in the source column should not be processed or passed-through. + // the mapping is only present if there are two or more sources. 
+ private final Optional> markerSymbols; + + // partitioning and ordering combined from sources + private final Optional specification; // TODO add pre-partitioned, pre-sorted + + private final TableFunctionHandle handle; + + @JsonCreator + public TableFunctionProcessorNode( + @JsonProperty("id") PlanNodeId id, + @JsonProperty("name") String name, + @JsonProperty("properOutputs") List properOutputs, + @JsonProperty("source") PlanNode source, + @JsonProperty("passThroughSpecifications") List passThroughSpecifications, + @JsonProperty("requiredSymbols") List> requiredSymbols, + @JsonProperty("markerSymbols") Optional> markerSymbols, + @JsonProperty("specification") Optional specification, + @JsonProperty("handle") TableFunctionHandle handle) + { + super(id); + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.source = requireNonNull(source, "source is null"); + this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); + this.requiredSymbols = requiredSymbols.stream() + .map(ImmutableList::copyOf) + .collect(toImmutableList()); + this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + this.handle = requireNonNull(handle, "handle is null"); + } + + @JsonProperty + public String getName() + { + return name; + } + + @JsonProperty + public List getProperOutputs() + { + return properOutputs; + } + + @JsonProperty + public PlanNode getSource() + { + return source; + } + + @JsonProperty + public List getPassThroughSpecifications() + { + return passThroughSpecifications; + } + + @JsonProperty + public List> getRequiredSymbols() + { + return requiredSymbols; + } + + @JsonProperty + public Optional> getMarkerSymbols() + { + return markerSymbols; + } + + @JsonProperty + public Optional getSpecification() + { + return specification; + } + + @JsonProperty + public TableFunctionHandle getHandle() + { + 
return handle; + } + + @JsonProperty + @Override + public List getSources() + { + return ImmutableList.of(source); + } + + @Override + public List getOutputSymbols() + { + ImmutableList.Builder symbols = ImmutableList.builder(); + + symbols.addAll(properOutputs); + + passThroughSpecifications.stream() + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(TableFunctionNode.PassThroughColumn::symbol) + .forEach(symbols::add); + + return symbols.build(); + } + + @Override + public R accept(PlanVisitor visitor, C context) + { + return visitor.visitTableFunctionProcessor(this, context); + } + + @Override + public PlanNode replaceChildren(List newSources) + { + return new TableFunctionProcessorNode(getId(), name, properOutputs, getOnlyElement(newSources), passThroughSpecifications, requiredSymbols, markerSymbols, specification, handle); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 127a62718f5f..90b6f0d07943 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -107,6 +107,7 @@ import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -1792,54 +1793,88 @@ public Void visitTableFunction(TableFunctionNode node, Context context) private String formatArgument(String argumentName, Argument argument, Map tableArguments) { if (argument instanceof ScalarArgument scalarArgument) { - return format( - "%s => ScalarArgument{type=%s, value=%s}", - argumentName, - 
scalarArgument.getType().getDisplayName(), - anonymizer.anonymize( - scalarArgument.getType(), - valuePrinter.castToVarchar(scalarArgument.getType(), scalarArgument.getValue()))); + return formatScalarArgument(argumentName, scalarArgument); } if (argument instanceof DescriptorArgument descriptorArgument) { - String descriptor; - if (descriptorArgument.equals(NULL_DESCRIPTOR)) { - descriptor = "NULL"; - } - else { - descriptor = descriptorArgument.getDescriptor().orElseThrow().getFields().stream() - .map(field -> anonymizer.anonymizeColumn(field.getName()) + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) - .collect(joining(", ", "(", ")")); - } - return format("%s => DescriptorArgument{%s}", argumentName, descriptor); + return formatDescriptorArgument(argumentName, descriptorArgument); } else { TableArgumentProperties argumentProperties = tableArguments.get(argumentName); - StringBuilder properties = new StringBuilder(); - if (argumentProperties.isRowSemantics()) { - properties.append("row semantics"); - } - argumentProperties.getSpecification().ifPresent(specification -> { + return formatTableArgument(argumentName, argumentProperties); + } + } + + private String formatScalarArgument(String argumentName, ScalarArgument argument) + { + return format( + "%s => ScalarArgument{type=%s, value=%s}", + argumentName, + argument.getType().getDisplayName(), + anonymizer.anonymize( + argument.getType(), + valuePrinter.castToVarchar(argument.getType(), argument.getValue()))); + } + + private String formatDescriptorArgument(String argumentName, DescriptorArgument argument) + { + String descriptor; + if (argument.equals(NULL_DESCRIPTOR)) { + descriptor = "NULL"; + } + else { + descriptor = argument.getDescriptor().orElseThrow().getFields().stream() + .map(field -> anonymizer.anonymizeColumn(field.getName()) + field.getType().map(type -> " " + type.getDisplayName()).orElse("")) + .collect(joining(", ", "(", ")")); + } + return format("%s => 
DescriptorArgument{%s}", argumentName, descriptor); + } + + private String formatTableArgument(String argumentName, TableArgumentProperties argumentProperties) + { + StringBuilder properties = new StringBuilder(); + if (argumentProperties.isRowSemantics()) { + properties.append("row semantics"); + } + argumentProperties.getSpecification().ifPresent(specification -> { + properties + .append("partition by: [") + .append(Joiner.on(", ").join(anonymize(specification.getPartitionBy()))) + .append("]"); + specification.getOrderingScheme().ifPresent(orderingScheme -> { properties - .append("partition by: [") - .append(Joiner.on(", ").join(anonymize(specification.getPartitionBy()))) - .append("]"); - specification.getOrderingScheme().ifPresent(orderingScheme -> { - properties - .append(", order by: ") - .append(formatOrderingScheme(orderingScheme)); - }); + .append(", order by: ") + .append(formatOrderingScheme(orderingScheme)); }); - properties.append("required columns: [") - .append(Joiner.on(", ").join(anonymize(argumentProperties.getRequiredColumns()))) - .append("]"); - if (argumentProperties.isPruneWhenEmpty()) { - properties.append(", prune when empty"); - } - if (argumentProperties.getPassThroughSpecification().declaredAsPassThrough()) { - properties.append(", pass through columns"); - } - return format("%s => TableArgument{%s}", argumentName, properties); + }); + properties.append("required columns: [") + .append(Joiner.on(", ").join(anonymize(argumentProperties.getRequiredColumns()))) + .append("]"); + if (argumentProperties.isPruneWhenEmpty()) { + properties.append(", prune when empty"); } + if (argumentProperties.getPassThroughSpecification().declaredAsPassThrough()) { + properties.append(", pass through columns"); + } + return format("%s => TableArgument{%s}", argumentName, properties); + } + + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Context context) + { + ImmutableMap.Builder descriptor = ImmutableMap.builder(); + 
+ descriptor.put("name", node.getName()); + + descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(anonymize(node.getProperOutputs())))); + + node.getSpecification().ifPresent(specification -> { + descriptor.put("partitionBy", format("[%s]", Joiner.on(", ").join(anonymize(specification.getPartitionBy())))); + specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy: ", formatOrderingScheme(orderingScheme))); + }); + + addNode(node, "TableFunctionProcessor", descriptor.buildOrThrow(), context.tag()); + + return processChildren(node, new Context()); } @Override diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index 7f1d9565eb64..a398554bc01b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -66,6 +66,8 @@ import io.trino.sql.planner.plan.TableFinishNode; import io.trino.sql.planner.plan.TableFunctionNode; import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TopNNode; @@ -265,6 +267,71 @@ public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) return null; } + @Override + public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundSymbols) + { + PlanNode source = node.getSource(); + source.accept(this, boundSymbols); + + Set inputs = createInputs(source, boundSymbols); + + Set passThroughSymbols = node.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::columns) + 
.flatMap(Collection::stream) + .map(PassThroughColumn::symbol) + .collect(toImmutableSet()); + checkDependencies( + inputs, + passThroughSymbols, + "Invalid node. Pass-through symbols (%s) not in source plan output (%s)", + passThroughSymbols, + source.getOutputSymbols()); + + Set requiredSymbols = node.getRequiredSymbols().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + checkDependencies( + inputs, + requiredSymbols, + "Invalid node. Required symbols (%s) not in source plan output (%s)", + requiredSymbols, + source.getOutputSymbols()); + + node.getMarkerSymbols().ifPresent(mapping -> { + checkDependencies( + inputs, + mapping.keySet(), + "Invalid node. Source symbols (%s) not in source plan output (%s)", + mapping.keySet(), + source.getOutputSymbols()); + checkDependencies( + inputs, + mapping.values(), + "Invalid node. Source marker symbols (%s) not in source plan output (%s)", + mapping.values(), + source.getOutputSymbols()); + }); + + node.getSpecification().ifPresent(specification -> { + checkDependencies( + inputs, + specification.getPartitionBy(), + "Invalid node. Partition by symbols (%s) not in source plan output (%s)", + specification.getPartitionBy(), + source.getOutputSymbols()); + specification.getOrderingScheme().ifPresent(orderingScheme -> { + checkDependencies( + inputs, + orderingScheme.getOrderBy(), + "Invalid node. 
Order by symbols (%s) not in source plan output (%s)", + orderingScheme.getOrderBy(), + source.getOutputSymbols()); + }); + }); + + return null; + } + @Override public Void visitWindow(WindowNode node, Set boundSymbols) { From 5d5eaf8e6b5737758589247fc106867482fb54d3 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Mon, 28 Nov 2022 15:15:06 +0100 Subject: [PATCH 7/9] Add pre-partitioned and pre-sorted properties to TableFunctionProcessorNode --- .../planner/optimizations/SymbolMapper.java | 9 +++- .../plan/TableFunctionProcessorNode.java | 48 ++++++++++++++++++- .../sql/planner/planprinter/PlanPrinter.java | 27 +++++++++-- 3 files changed, 78 insertions(+), 6 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index c862a41a1a79..ecdc367893e7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -380,6 +380,8 @@ public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode .map(mapping -> mapping.entrySet().stream() .collect(toMap(entry -> map(entry.getKey()), entry -> map(entry.getValue())))); + Optional newSpecification = node.getSpecification().map(specification -> mapAndDistinct(specification, node.getPreSorted())); + return new TableFunctionProcessorNode( node.getId(), node.getName(), @@ -388,7 +390,12 @@ public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode newPassThroughSpecifications, newRequiredSymbols, newMarkerSymbols, - node.getSpecification().map(this::mapAndDistinct), // TODO handle pre-partitioned, pre-sorted properly when we add them + newSpecification.map(SpecificationWithPreSortedPrefix::specification), + node.getPrePartitioned().stream() + .map(this::map) + .collect(toImmutableSet()), + 
newSpecification.map(SpecificationWithPreSortedPrefix::preSorted).orElse(node.getPreSorted()), + node.getHashSymbol().map(this::map), node.getHandle()); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java index 802730a47849..1071a89a2e5a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java @@ -17,7 +17,9 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import io.trino.metadata.TableFunctionHandle; +import io.trino.sql.planner.OrderingScheme; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; @@ -25,7 +27,9 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; import static java.util.Objects.requireNonNull; @@ -55,7 +59,10 @@ public class TableFunctionProcessorNode private final Optional> markerSymbols; // partitioning and ordering combined from sources - private final Optional specification; // TODO add pre-partitioned, pre-sorted + private final Optional specification; + private final Set prePartitioned; + private final int preSorted; + private final Optional hashSymbol; private final TableFunctionHandle handle; @@ -69,6 +76,9 @@ public TableFunctionProcessorNode( @JsonProperty("requiredSymbols") List> requiredSymbols, @JsonProperty("markerSymbols") Optional> markerSymbols, @JsonProperty("specification") Optional specification, + 
@JsonProperty("prePartitioned") Set prePartitioned, + @JsonProperty("preSorted") int preSorted, + @JsonProperty("hashSymbol") Optional hashSymbol, @JsonProperty("handle") TableFunctionHandle handle) { super(id); @@ -81,6 +91,22 @@ public TableFunctionProcessorNode( .collect(toImmutableList()); this.markerSymbols = markerSymbols.map(ImmutableMap::copyOf); this.specification = requireNonNull(specification, "specification is null"); + this.prePartitioned = ImmutableSet.copyOf(prePartitioned); + Set partitionBy = specification + .map(DataOrganizationSpecification::getPartitionBy) + .map(ImmutableSet::copyOf) + .orElse(ImmutableSet.of()); + checkArgument(partitionBy.containsAll(prePartitioned), "all pre-partitioned symbols must be contained in the partitioning list"); + this.preSorted = preSorted; + checkArgument( + specification + .flatMap(DataOrganizationSpecification::getOrderingScheme) + .map(OrderingScheme::getOrderBy) + .map(List::size) + .orElse(0) >= preSorted, + "the number of pre-sorted symbols cannot be greater than the number of all ordering symbols"); + checkArgument(preSorted == 0 || partitionBy.equals(prePartitioned), "to specify pre-sorted symbols, it is required that all partitioning symbols are pre-partitioned"); + this.hashSymbol = requireNonNull(hashSymbol, "hashSymbol is null"); this.handle = requireNonNull(handle, "handle is null"); } @@ -126,6 +152,24 @@ public Optional getSpecification() return specification; } + @JsonProperty + public Set getPrePartitioned() + { + return prePartitioned; + } + + @JsonProperty + public int getPreSorted() + { + return preSorted; + } + + @JsonProperty + public Optional getHashSymbol() + { + return hashSymbol; + } + @JsonProperty public TableFunctionHandle getHandle() { @@ -164,6 +208,6 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newSources) { - return new TableFunctionProcessorNode(getId(), name, properOutputs, getOnlyElement(newSources), 
passThroughSpecifications, requiredSymbols, markerSymbols, specification, handle); + return new TableFunctionProcessorNode(getId(), name, properOutputs, getOnlyElement(newSources), passThroughSpecifications, requiredSymbols, markerSymbols, specification, prePartitioned, preSorted, hashSymbol, handle); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index 90b6f0d07943..5bd1c626f1b6 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -1868,11 +1868,32 @@ public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Context descriptor.put("properOutputs", format("[%s]", Joiner.on(", ").join(anonymize(node.getProperOutputs())))); node.getSpecification().ifPresent(specification -> { - descriptor.put("partitionBy", format("[%s]", Joiner.on(", ").join(anonymize(specification.getPartitionBy())))); - specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy: ", formatOrderingScheme(orderingScheme))); + if (!specification.getPartitionBy().isEmpty()) { + List prePartitioned = specification.getPartitionBy().stream() + .filter(node.getPrePartitioned()::contains) + .collect(toImmutableList()); + + List notPrePartitioned = specification.getPartitionBy().stream() + .filter(column -> !node.getPrePartitioned().contains(column)) + .collect(toImmutableList()); + + StringBuilder builder = new StringBuilder(); + if (!prePartitioned.isEmpty()) { + builder.append(anonymize(prePartitioned).stream() + .collect(joining(", ", "<", ">"))); + if (!notPrePartitioned.isEmpty()) { + builder.append(", "); + } + } + if (!notPrePartitioned.isEmpty()) { + builder.append(Joiner.on(", ").join(anonymize(notPrePartitioned))); + } + descriptor.put("partitionBy", format("[%s]", builder)); + } + 
specification.getOrderingScheme().ifPresent(orderingScheme -> descriptor.put("orderBy", formatOrderingScheme(orderingScheme, node.getPreSorted()))); }); - addNode(node, "TableFunctionProcessor", descriptor.buildOrThrow(), context.tag()); + addNode(node, "TableFunctionProcessor", descriptor.put("hash", formatHash(node.getHashSymbol())).buildOrThrow(), context.tag()); return processChildren(node, new Context()); } From 68bbbca9cc32505badc343adab56f59a685a7187 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 11 Oct 2022 12:23:04 +0200 Subject: [PATCH 8/9] Make source of TableFunctionProcessorNode optional In case when there are no input tables for a table function invocation, the resulting TableFunctionProcessorNode has no sources. Otherwise, it has one source being a combination of all inputs. --- .../planner/optimizations/SymbolMapper.java | 2 +- .../UnaliasSymbolReferences.java | 22 ++++++++++++++++++- .../plan/TableFunctionProcessorNode.java | 11 +++++----- .../sanity/ValidateDependenciesChecker.java | 6 ++++- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index ecdc367893e7..1cf7f027a328 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -386,7 +386,7 @@ public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode node.getId(), node.getName(), map(node.getProperOutputs()), - source, + Optional.of(source), newPassThroughSpecifications, newRequiredSymbols, newMarkerSymbols, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 
0e76e19a7aa5..1a56e85c33e7 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -371,7 +371,27 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext @Override public PlanAndMappings visitTableFunctionProcessor(TableFunctionProcessorNode node, UnaliasContext context) { - PlanAndMappings rewrittenSource = node.getSource().accept(this, context); + if (node.getSource().isEmpty()) { + Map mapping = new HashMap<>(context.getCorrelationMapping()); + SymbolMapper mapper = symbolMapper(mapping); + return new PlanAndMappings( + new TableFunctionProcessorNode( + node.getId(), + node.getName(), + mapper.map(node.getProperOutputs()), + Optional.empty(), + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + node.getHashSymbol().map(mapper::map), + node.getHandle()), + mapping); + } + + PlanAndMappings rewrittenSource = node.getSource().orElseThrow().accept(this, context); Map mapping = new HashMap<>(rewrittenSource.getMappings()); SymbolMapper mapper = symbolMapper(mapping); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java index 1071a89a2e5a..a4a36e2377ce 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java @@ -43,7 +43,7 @@ public class TableFunctionProcessorNode private final List properOutputs; // pre-planned sources - private final PlanNode source; + private final Optional source; // TODO do we need the info of which source has row semantics, or is it already included in the joins / join distribution? 
// all source symbols to be produced on output, ordered as table argument specifications @@ -71,7 +71,7 @@ public TableFunctionProcessorNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("name") String name, @JsonProperty("properOutputs") List properOutputs, - @JsonProperty("source") PlanNode source, + @JsonProperty("source") Optional source, @JsonProperty("passThroughSpecifications") List passThroughSpecifications, @JsonProperty("requiredSymbols") List> requiredSymbols, @JsonProperty("markerSymbols") Optional> markerSymbols, @@ -123,7 +123,7 @@ public List getProperOutputs() } @JsonProperty - public PlanNode getSource() + public Optional getSource() { return source; } @@ -180,7 +180,7 @@ public TableFunctionHandle getHandle() @Override public List getSources() { - return ImmutableList.of(source); + return source.map(ImmutableList::of).orElse(ImmutableList.of()); } @Override @@ -208,6 +208,7 @@ public R accept(PlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newSources) { - return new TableFunctionProcessorNode(getId(), name, properOutputs, getOnlyElement(newSources), passThroughSpecifications, requiredSymbols, markerSymbols, specification, prePartitioned, preSorted, hashSymbol, handle); + Optional newSource = newSources.isEmpty() ? 
Optional.empty() : Optional.of(getOnlyElement(newSources)); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, passThroughSpecifications, requiredSymbols, markerSymbols, specification, prePartitioned, preSorted, hashSymbol, handle); } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java index a398554bc01b..f9475b2f4fd1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/sanity/ValidateDependenciesChecker.java @@ -270,7 +270,11 @@ public Void visitTableFunction(TableFunctionNode node, Set boundSymbols) @Override public Void visitTableFunctionProcessor(TableFunctionProcessorNode node, Set boundSymbols) { - PlanNode source = node.getSource(); + if (node.getSource().isEmpty()) { + return null; + } + + PlanNode source = node.getSource().orElseThrow(); source.accept(this, boundSymbols); Set inputs = createInputs(source, boundSymbols); From ed5cc544fa141dd7bf92984a5765dcd5a38ea25f Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Tue, 11 Oct 2022 12:38:14 +0200 Subject: [PATCH 9/9] Implement table function source Support arbitrary number of sources (including no sources), involving row and set semantics, prune/keep when empty properties, and co-partitioning. 
--- .../io/trino/sql/planner/PlanOptimizers.java | 7 +- .../rule/ImplementTableFunctionSource.java | 756 +++++++++ .../planner/optimizations/SymbolMapper.java | 1 + .../UnaliasSymbolReferences.java | 1 + .../plan/TableFunctionProcessorNode.java | 14 +- .../planner/assertions/PlanMatchPattern.java | 16 +- .../TableFunctionProcessorMatcher.java | 204 +++ .../TestImplementTableFunctionSource.java | 1360 +++++++++++++++++ .../iterative/rule/test/PlanBuilder.java | 24 + 9 files changed, 2380 insertions(+), 3 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java create mode 100644 core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java create mode 100644 core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index e548c2305023..b3a8f1c6b2de 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -61,6 +61,7 @@ import io.trino.sql.planner.iterative.rule.ImplementIntersectDistinctAsUnion; import io.trino.sql.planner.iterative.rule.ImplementLimitWithTies; import io.trino.sql.planner.iterative.rule.ImplementOffset; +import io.trino.sql.planner.iterative.rule.ImplementTableFunctionSource; import io.trino.sql.planner.iterative.rule.InlineProjectIntoFilter; import io.trino.sql.planner.iterative.rule.InlineProjections; import io.trino.sql.planner.iterative.rule.MergeExcept; @@ -629,7 +630,11 @@ public PlanOptimizers( costCalculator, // Temporary hack: separate optimizer step to avoid the sample node being replaced by filter before pushing // it to table scan node - ImmutableSet.of(new ImplementBernoulliSampleAsFilter(metadata))), + ImmutableSet.of( + new 
ImplementBernoulliSampleAsFilter(metadata), + // Must run after RewriteTableFunctionToTableScan because that rule applies to TableFunctionNode. + // While the node gets rewritten to TableFunctionProcessorNode, we can no longer pushdown the function to the connector. + new ImplementTableFunctionSource(metadata))), columnPruningOptimizer, new IterativeOptimizer( plannerContext, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java new file mode 100644 index 000000000000..4d18d1f03dfd --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/ImplementTableFunctionSource.java @@ -0,0 +1,756 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Streams; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.metadata.Metadata; +import io.trino.metadata.ResolvedFunction; +import io.trino.spi.type.Type; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; +import io.trino.sql.planner.plan.WindowNode; +import io.trino.sql.planner.plan.WindowNode.Frame; +import io.trino.sql.tree.Cast; +import io.trino.sql.tree.CoalesceExpression; +import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.Expression; +import io.trino.sql.tree.GenericLiteral; +import io.trino.sql.tree.IfExpression; +import io.trino.sql.tree.LogicalExpression; +import io.trino.sql.tree.NotExpression; +import io.trino.sql.tree.NullLiteral; +import io.trino.sql.tree.QualifiedName; + +import java.util.Collection; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static 
com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureTranslator.toSqlType; +import static io.trino.sql.planner.plan.JoinNode.Type.FULL; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; +import static io.trino.sql.planner.plan.JoinNode.Type.RIGHT; +import static io.trino.sql.planner.plan.Patterns.tableFunction; +import static io.trino.sql.tree.ComparisonExpression.Operator.EQUAL; +import static io.trino.sql.tree.ComparisonExpression.Operator.GREATER_THAN; +import static io.trino.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; +import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_FOLLOWING; +import static io.trino.sql.tree.FrameBound.Type.UNBOUNDED_PRECEDING; +import static io.trino.sql.tree.LogicalExpression.Operator.AND; +import static io.trino.sql.tree.LogicalExpression.Operator.OR; +import static io.trino.sql.tree.WindowFrame.Type.ROWS; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +/** + * This rule prepares cartesian product of partitions + * from all inputs of table function. + *

+ * It rewrites TableFunctionNode with potentially many sources + * into a TableFunctionProcessorNode. The new node has one + * source being a combination of the original sources. + *

+ * The original sources are combined with joins. The join + * conditions depend on the prune when empty property, and on + * the co-partitioning of sources. + *

+ * The resulting source should be partitioned and ordered + * according to combined schemas from the component sources. + *

+ * Example transformation for two sources, both with set semantics + * and KEEP WHEN EMPTY property: + *

+ * - TableFunction foo
+ *      - source T1(a1, b1) PARTITION BY a1 ORDER BY b1
+ *      - source T2(a2, b2) PARTITION BY a2
+ * 
+ * Is transformed into: + *
+ * - TableFunctionProcessor foo
+ *      PARTITION BY (a1, a2), ORDER BY combined_row_number
+ *      - Project
+ *          marker_1 <= IF(table1_row_number = combined_row_number, table1_row_number, CAST(null AS bigint))
+ *          marker_2 <= IF(table2_row_number = combined_row_number, table2_row_number, CAST(null AS bigint))
+ *          - Project
+ *              combined_row_number <= IF(COALESCE(table1_row_number, BIGINT '-1') > COALESCE(table2_row_number, BIGINT '-1'), table1_row_number, table2_row_number)
+ *              combined_partition_size <= IF(COALESCE(table1_partition_size, BIGINT '-1') > COALESCE(table2_partition_size, BIGINT '-1'), table1_partition_size, table2_partition_size)
+ *              - FULL Join
+ *                  [table1_row_number = table2_row_number OR
+ *                   table1_row_number > table2_partition_size AND table2_row_number = BIGINT '1' OR
+ *                   table2_row_number > table1_partition_size AND table1_row_number = BIGINT '1']
+ *                  - Window [PARTITION BY a1 ORDER BY b1]
+ *                      table1_row_number <= row_number()
+ *                      table1_partition_size <= count()
+ *                          - source T1(a1, b1)
+ *                  - Window [PARTITION BY a2]
+ *                      table2_row_number <= row_number()
+ *                      table2_partition_size <= count()
+ *                          - source T2(a2, b2)
+ * 
+ */ +public class ImplementTableFunctionSource + implements Rule +{ + private static final Pattern PATTERN = tableFunction(); + private static final Frame FULL_FRAME = new Frame( + ROWS, + UNBOUNDED_PRECEDING, + Optional.empty(), + Optional.empty(), + UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty()); + private static final DataOrganizationSpecification UNORDERED_SINGLE_PARTITION = new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()); + + private final Metadata metadata; + + public ImplementTableFunctionSource(Metadata metadata) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(TableFunctionNode node, Captures captures, Context context) + { + if (node.getSources().isEmpty()) { + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.empty(), + false, + ImmutableList.of(), + ImmutableList.of(), + Optional.empty(), + Optional.empty(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + if (node.getSources().size() == 1) { + // Single source does not require pre-processing. + // If the source has row semantics, its specification is empty. + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // This property can be used later to choose optimal distribution. 
+ TableArgumentProperties sourceProperties = getOnlyElement(node.getTableArgumentProperties()); + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(getOnlyElement(node.getSources())), + sourceProperties.isPruneWhenEmpty(), + ImmutableList.of(sourceProperties.getPassThroughSpecification()), + ImmutableList.of(sourceProperties.getRequiredColumns()), + Optional.empty(), + sourceProperties.getSpecification(), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + Map sources = mapSourcesByName(node.getSources(), node.getTableArgumentProperties()); + ImmutableList.Builder intermediateResultsBuilder = ImmutableList.builder(); + ResolvedFunction rowNumberFunction = metadata.resolveFunction(context.getSession(), QualifiedName.of("row_number"), ImmutableList.of()); + ResolvedFunction countFunction = metadata.resolveFunction(context.getSession(), QualifiedName.of("count"), ImmutableList.of()); + + // handle co-partitioned sources + for (List copartitioningList : node.getCopartitioningLists()) { + List sourceList = copartitioningList.stream() + .map(sources::get) + .collect(toImmutableList()); + intermediateResultsBuilder.add(copartition(sourceList, rowNumberFunction, countFunction, context)); + } + + // prepare non-co-partitioned sources + Set copartitionedSources = node.getCopartitioningLists().stream() + .flatMap(Collection::stream) + .collect(toImmutableSet()); + sources.entrySet().stream() + .filter(entry -> !copartitionedSources.contains(entry.getKey())) + .map(entry -> planWindowFunctionsForSource(entry.getValue().source(), entry.getValue().properties(), rowNumberFunction, countFunction, context)) + .forEach(intermediateResultsBuilder::add); + + NodeWithSymbols finalResultSource; + + List intermediateResultSources = intermediateResultsBuilder.build(); + if (intermediateResultSources.size() == 1) { + finalResultSource = getOnlyElement(intermediateResultSources); + } + 
else { + NodeWithSymbols first = intermediateResultSources.get(0); + NodeWithSymbols second = intermediateResultSources.get(1); + JoinedNodes joined = join(first, second, context); + + for (int i = 2; i < intermediateResultSources.size(); i++) { + NodeWithSymbols joinedWithSymbols = appendHelperSymbolsForJoinedNodes(joined, context); + joined = join(joinedWithSymbols, intermediateResultSources.get(i), context); + } + + finalResultSource = appendHelperSymbolsForJoinedNodes(joined, context); + } + + // For each source, all source's output symbols are mapped to the source's row number symbol. + // The row number symbol will be later converted to a marker of "real" input rows vs "filler" input rows of the source. + // The "filler" input rows are the rows appended while joining partitions of different lengths, + // to fill the smaller partition up to the bigger partition's size. They are a side effect of the algorithm, + // and should not be processed by the table function. + Map rowNumberSymbols = finalResultSource.rowNumberSymbolsMapping(); + + // The max row number symbol from all joined partitions. + Symbol finalRowNumberSymbol = finalResultSource.rowNumber(); + // Combined partitioning lists from all sources. + List finalPartitionBy = finalResultSource.partitionBy(); + + NodeWithMarkers marked = appendMarkerSymbols(finalResultSource.node(), ImmutableSet.copyOf(rowNumberSymbols.values()), finalRowNumberSymbol, context); + + // Remap the symbol mapping: replace the row number symbol with the corresponding marker symbol. + // In the new map, every source symbol is associated with the corresponding marker symbol. + // Null value of the marker indicates that the source value should be ignored by the table function. + ImmutableMap markerSymbols = rowNumberSymbols.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> marked.symbolToMarker().get(entry.getValue()))); + + // Use the final row number symbol for ordering the combined sources. 
+ // It runs along each partition in the cartesian product, numbering the partition's rows according to the expected ordering / orderings. + // note: ordering is necessary even if all the source tables are not ordered. Thanks to the ordering, the original rows + // of each input table come before the "filler" rows. + Optional finalOrderBy = Optional.of(new OrderingScheme(ImmutableList.of(finalRowNumberSymbol), ImmutableMap.of(finalRowNumberSymbol, ASC_NULLS_LAST))); + + // derive the prune when empty property + boolean pruneWhenEmpty = node.getTableArgumentProperties().stream().anyMatch(TableArgumentProperties::isPruneWhenEmpty); + + // Combine the pass through specifications from all sources + List passThroughSpecifications = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getPassThroughSpecification) + .collect(toImmutableList()); + + // Combine the required symbols from all sources + List> requiredSymbols = node.getTableArgumentProperties().stream() + .map(TableArgumentProperties::getRequiredColumns) + .collect(toImmutableList()); + + return Result.ofPlanNode(new TableFunctionProcessorNode( + node.getId(), + node.getName(), + node.getProperOutputs(), + Optional.of(marked.node()), + pruneWhenEmpty, + passThroughSpecifications, + requiredSymbols, + Optional.of(markerSymbols), + Optional.of(new DataOrganizationSpecification(finalPartitionBy, finalOrderBy)), + ImmutableSet.of(), + 0, + Optional.empty(), + node.getHandle())); + } + + private static Map mapSourcesByName(List sources, List properties) + { + return Streams.zip(sources.stream(), properties.stream(), SourceWithProperties::new) + .collect(toImmutableMap(entry -> entry.properties().getArgumentName(), identity())); + } + + private static NodeWithSymbols planWindowFunctionsForSource( + PlanNode source, + TableArgumentProperties argumentProperties, + ResolvedFunction rowNumberFunction, + ResolvedFunction countFunction, + Context context) + { + String argumentName = 
argumentProperties.getArgumentName(); + + Symbol rowNumber = context.getSymbolAllocator().newSymbol(argumentName + "_row_number", BIGINT); + Map rowNumberSymbolMapping = source.getOutputSymbols().stream() + .collect(toImmutableMap(identity(), symbol -> rowNumber)); + + Symbol partitionSize = context.getSymbolAllocator().newSymbol(argumentName + "_partition_size", BIGINT); + + // If the source has set semantics, its specification is present, even if there is no partitioning or ordering specified. + // If the source has row semantics, its specification is empty. Currently, such source is processed + // as if it was a single partition. Alternatively, it could be split into smaller partitions of arbitrary size. + DataOrganizationSpecification specification = argumentProperties.getSpecification().orElse(UNORDERED_SINGLE_PARTITION); + + PlanNode window = new WindowNode( + context.getIdAllocator().getNextId(), + source, + specification, + ImmutableMap.of( + rowNumber, new WindowNode.Function(rowNumberFunction, ImmutableList.of(), FULL_FRAME, false), + partitionSize, new WindowNode.Function(countFunction, ImmutableList.of(), FULL_FRAME, false)), + Optional.empty(), + ImmutableSet.of(), + 0); + + return new NodeWithSymbols(window, rowNumber, partitionSize, specification.getPartitionBy(), argumentProperties.isPruneWhenEmpty(), rowNumberSymbolMapping); + } + + private static NodeWithSymbols copartition( + List sourceList, + ResolvedFunction rowNumberFunction, + ResolvedFunction countFunction, + Context context) + { + checkArgument(sourceList.size() >= 2, "co-partitioning list should contain at least two tables"); + + // Reorder the co-partitioned sources to process the sources with prune when empty property first. + // It allows to use inner or side joins instead of outer joins. + sourceList = sourceList.stream() + .sorted(Comparator.comparingInt(source -> source.properties().isPruneWhenEmpty() ? 
-1 : 1)) + .collect(toImmutableList()); + + NodeWithSymbols first = planWindowFunctionsForSource(sourceList.get(0).source(), sourceList.get(0).properties(), rowNumberFunction, countFunction, context); + NodeWithSymbols second = planWindowFunctionsForSource(sourceList.get(1).source(), sourceList.get(1).properties(), rowNumberFunction, countFunction, context); + JoinedNodes copartitioned = copartition(first, second, context); + + for (int i = 2; i < sourceList.size(); i++) { + NodeWithSymbols copartitionedWithSymbols = appendHelperSymbolsForCopartitionedNodes(copartitioned, context); + NodeWithSymbols next = planWindowFunctionsForSource(sourceList.get(i).source(), sourceList.get(i).properties(), rowNumberFunction, countFunction, context); + copartitioned = copartition(copartitionedWithSymbols, next, context); + } + + return appendHelperSymbolsForCopartitionedNodes(copartitioned, context); + } + + private static JoinedNodes copartition(NodeWithSymbols left, NodeWithSymbols right, Context context) + { + checkArgument(left.partitionBy().size() == right.partitionBy().size(), "co-partitioning lists do not match"); + + // In StatementAnalyzer we require that co-partitioned tables have non-empty partitioning column lists. + // Co-partitioning tables with empty partition by would be ineffective. 
+ checkState(!left.partitionBy().isEmpty(), "co-partitioned tables must have partitioning columns"); + + Expression leftRowNumber = left.rowNumber().toSymbolReference(); + Expression leftPartitionSize = left.partitionSize().toSymbolReference(); + List leftPartitionBy = left.partitionBy().stream() + .map(Symbol::toSymbolReference) + .collect(toImmutableList()); + Expression rightRowNumber = right.rowNumber().toSymbolReference(); + Expression rightPartitionSize = right.partitionSize().toSymbolReference(); + List rightPartitionBy = right.partitionBy().stream() + .map(Symbol::toSymbolReference) + .collect(toImmutableList()); + + List copartitionConjuncts = Streams.zip( + leftPartitionBy.stream(), + rightPartitionBy.stream(), + (leftColumn, rightColumn) -> new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, leftColumn, rightColumn))) + .collect(toImmutableList()); + + // Align matching partitions (co-partitions) from left and right source, according to row number. + // Matching partitions are identified by their corresponding partitioning columns being NOT DISTINCT from each other. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjuncts in the join condition account for the situation when partitions have different sizes. + // They preserve the outstanding rows from the bigger partition, matching them to the first row from the smaller partition. + // + // (P1_1 IS NOT DISTINCT FROM P2_1) AND (P1_2 IS NOT DISTINCT FROM P2_2) AND ... 
+ // AND ( + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1)) + Expression joinCondition = new LogicalExpression( + AND, + ImmutableList.builder() + .addAll(copartitionConjuncts) + .add(new LogicalExpression(OR, ImmutableList.of( + new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), + new ComparisonExpression(EQUAL, rightRowNumber, new GenericLiteral("BIGINT", "1")))), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), + new ComparisonExpression(EQUAL, leftRowNumber, new GenericLiteral("BIGINT", "1"))))))) + .build()); + + // The join type depends on the prune when empty property of the sources. + // If a source is prune when empty, we should not process any co-partition which is not present in this source, + // so effectively the other source becomes inner side of the join. 
+ // + // example: + // table T1 partition by P1 table T2 partition by P2 + // P1 C1 P2 C2 + // ---------- ---------- + // 1 'a' 2 'c' + // 2 'b' 3 'd' + // + // co-partitioning results: + // 1) T1 is prune when empty: do LEFT JOIN to drop co-partition '3' + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // + // 2) T2 is prune when empty: do RIGHT JOIN to drop co-partition '1' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // null null 3 'd' + // + // 3) T1 and T2 are both prune when empty: do INNER JOIN to drop co-partitions '1' and '3' + // P1 C1 P2 C2 + // ------------------------ + // 2 'b' 2 'c' + // + // 4) neither table is prune when empty: do FULL JOIN to preserve all co-partitions + // P1 C1 P2 C2 + // ------------------------ + // 1 'a' null null + // 2 'b' 2 'c' + // null null 3 'd' + JoinNode.Type joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + left.node().getOutputSymbols(), + right.node().getOutputSymbols(), + false, + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private static NodeWithSymbols appendHelperSymbolsForCopartitionedNodes( + JoinedNodes copartitionedNodes, + Context context) + { + checkArgument(copartitionedNodes.leftPartitionBy().size() == 
copartitionedNodes.rightPartitionBy().size(), "co-partitioning lists do not match"); + + Expression leftRowNumber = copartitionedNodes.leftRowNumber().toSymbolReference(); + Expression leftPartitionSize = copartitionedNodes.leftPartitionSize().toSymbolReference(); + Expression rightRowNumber = copartitionedNodes.rightRowNumber().toSymbolReference(); + Expression rightPartitionSize = copartitionedNodes.rightPartitionSize().toSymbolReference(); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); + Expression rowNumberExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftRowNumber, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightRowNumber, new GenericLiteral("BIGINT", "-1"))), + leftRowNumber, + rightRowNumber); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. + Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); + Expression partitionSizeExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftPartitionSize, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightPartitionSize, new GenericLiteral("BIGINT", "-1"))), + leftPartitionSize, + rightPartitionSize); + + // Derive partitioning columns for joined partitions. + // Either the combined partitioning columns are pairwise NOT DISTINCT (this is the co-partitioning rule), + // or one of them is null as a result of outer join. 
+ ImmutableList.Builder joinedPartitionBy = ImmutableList.builder(); + Assignments.Builder joinedPartitionByAssignments = Assignments.builder(); + for (int i = 0; i < copartitionedNodes.leftPartitionBy().size(); i++) { + Symbol leftColumn = copartitionedNodes.leftPartitionBy().get(i); + Symbol rightColumn = copartitionedNodes.rightPartitionBy().get(i); + Type type = context.getSymbolAllocator().getTypes().get(leftColumn); + + Symbol joinedColumn = context.getSymbolAllocator().newSymbol("combined_partition_column", type); + joinedPartitionByAssignments.put(joinedColumn, new CoalesceExpression(leftColumn.toSymbolReference(), rightColumn.toSymbolReference())); + joinedPartitionBy.add(joinedColumn); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + copartitionedNodes.joinedNode(), + Assignments.builder() + .putIdentities(copartitionedNodes.joinedNode().getOutputSymbols()) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .putAll(joinedPartitionByAssignments.build()) + .build()); + + boolean joinedPruneWhenEmpty = copartitionedNodes.leftPruneWhenEmpty() || copartitionedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(copartitionedNodes.leftRowNumberSymbolsMapping()) + .putAll(copartitionedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithSymbols(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy.build(), joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static JoinedNodes join(NodeWithSymbols left, NodeWithSymbols right, Context context) + { + Expression leftRowNumber = left.rowNumber().toSymbolReference(); + Expression leftPartitionSize = left.partitionSize().toSymbolReference(); + Expression rightRowNumber = right.rowNumber().toSymbolReference(); + Expression rightPartitionSize = right.partitionSize().toSymbolReference(); + + // Align rows from left and right source 
according to row number. Because every partition is row-numbered, this produces cartesian product of partitions. + // If one or both sources are ordered, the row number reflects the ordering. + // The second and third disjunct in the join condition account for the situation when partitions have different sizes. It preserves the outstanding rows + // from the bigger partition, matching them to the first row from the smaller partition. + // + // R1 = R2 + // OR + // (R1 > S2 AND R2 = 1) + // OR + // (R2 > S1 AND R1 = 1) + Expression joinCondition = new LogicalExpression(OR, ImmutableList.of( + new ComparisonExpression(EQUAL, leftRowNumber, rightRowNumber), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, leftRowNumber, rightPartitionSize), + new ComparisonExpression(EQUAL, rightRowNumber, new GenericLiteral("BIGINT", "1")))), + new LogicalExpression(AND, ImmutableList.of( + new ComparisonExpression(GREATER_THAN, rightRowNumber, leftPartitionSize), + new ComparisonExpression(EQUAL, leftRowNumber, new GenericLiteral("BIGINT", "1")))))); + + JoinNode.Type joinType; + if (left.pruneWhenEmpty() && right.pruneWhenEmpty()) { + joinType = INNER; + } + else if (left.pruneWhenEmpty()) { + joinType = LEFT; + } + else if (right.pruneWhenEmpty()) { + joinType = RIGHT; + } + else { + joinType = FULL; + } + + return new JoinedNodes( + new JoinNode( + context.getIdAllocator().getNextId(), + joinType, + left.node(), + right.node(), + ImmutableList.of(), + left.node().getOutputSymbols(), + right.node().getOutputSymbols(), + false, + Optional.of(joinCondition), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()), + left.rowNumber(), + left.partitionSize(), + left.partitionBy(), + left.pruneWhenEmpty(), + left.rowNumberSymbolsMapping(), + right.rowNumber(), + right.partitionSize(), + right.partitionBy(), + right.pruneWhenEmpty(), + right.rowNumberSymbolsMapping()); + } + + private 
static NodeWithSymbols appendHelperSymbolsForJoinedNodes(JoinedNodes joinedNodes, Context context) + { + Expression leftRowNumber = joinedNodes.leftRowNumber().toSymbolReference(); + Expression leftPartitionSize = joinedNodes.leftPartitionSize().toSymbolReference(); + Expression rightRowNumber = joinedNodes.rightRowNumber().toSymbolReference(); + Expression rightPartitionSize = joinedNodes.rightPartitionSize().toSymbolReference(); + + // Derive row number for joined partitions: this is the bigger partition's row number. One of the combined values might be null as a result of outer join. + Symbol joinedRowNumber = context.getSymbolAllocator().newSymbol("combined_row_number", BIGINT); + Expression rowNumberExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftRowNumber, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightRowNumber, new GenericLiteral("BIGINT", "-1"))), + leftRowNumber, + rightRowNumber); + + // Derive partition size for joined partitions: this is the bigger partition's size. One of the combined values might be null as a result of outer join. 
+ Symbol joinedPartitionSize = context.getSymbolAllocator().newSymbol("combined_partition_size", BIGINT); + Expression partitionSizeExpression = new IfExpression( + new ComparisonExpression( + GREATER_THAN, + new CoalesceExpression(leftPartitionSize, new GenericLiteral("BIGINT", "-1")), + new CoalesceExpression(rightPartitionSize, new GenericLiteral("BIGINT", "-1"))), + leftPartitionSize, + rightPartitionSize); + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + joinedNodes.joinedNode(), + Assignments.builder() + .putIdentities(joinedNodes.joinedNode().getOutputSymbols()) + .put(joinedRowNumber, rowNumberExpression) + .put(joinedPartitionSize, partitionSizeExpression) + .build()); + + List joinedPartitionBy = ImmutableList.builder() + .addAll(joinedNodes.leftPartitionBy()) + .addAll(joinedNodes.rightPartitionBy()) + .build(); + + boolean joinedPruneWhenEmpty = joinedNodes.leftPruneWhenEmpty() || joinedNodes.rightPruneWhenEmpty(); + + Map joinedRowNumberSymbolsMapping = ImmutableMap.builder() + .putAll(joinedNodes.leftRowNumberSymbolsMapping()) + .putAll(joinedNodes.rightRowNumberSymbolsMapping()) + .buildOrThrow(); + + return new NodeWithSymbols(project, joinedRowNumber, joinedPartitionSize, joinedPartitionBy, joinedPruneWhenEmpty, joinedRowNumberSymbolsMapping); + } + + private static NodeWithMarkers appendMarkerSymbols(PlanNode node, Set symbols, Symbol referenceSymbol, Context context) + { + Assignments.Builder assignments = Assignments.builder(); + assignments.putIdentities(node.getOutputSymbols()); + + ImmutableMap.Builder symbolsToMarkers = ImmutableMap.builder(); + + for (Symbol symbol : symbols) { + Symbol marker = context.getSymbolAllocator().newSymbol("marker", BIGINT); + symbolsToMarkers.put(symbol, marker); + Expression actual = symbol.toSymbolReference(); + Expression reference = referenceSymbol.toSymbolReference(); + assignments.put(marker, new IfExpression(new ComparisonExpression(EQUAL, actual, reference), actual, new 
Cast(new NullLiteral(), toSqlType(BIGINT)))); + } + + PlanNode project = new ProjectNode( + context.getIdAllocator().getNextId(), + node, + assignments.build()); + + return new NodeWithMarkers(project, symbolsToMarkers.buildOrThrow()); + } + + private record SourceWithProperties(PlanNode source, TableArgumentProperties properties) + { + SourceWithProperties + { + requireNonNull(source, "source is null"); + requireNonNull(properties, "properties is null"); + } + } + + private record NodeWithSymbols(PlanNode node, Symbol rowNumber, Symbol partitionSize, List partitionBy, boolean pruneWhenEmpty, Map rowNumberSymbolsMapping) + { + NodeWithSymbols + { + requireNonNull(node, "node is null"); + requireNonNull(rowNumber, "rowNumber is null"); + requireNonNull(partitionSize, "partitionSize is null"); + partitionBy = ImmutableList.copyOf(partitionBy); + rowNumberSymbolsMapping = ImmutableMap.copyOf(rowNumberSymbolsMapping); + } + } + + private record JoinedNodes( + PlanNode joinedNode, + Symbol leftRowNumber, + Symbol leftPartitionSize, + List leftPartitionBy, + boolean leftPruneWhenEmpty, + Map leftRowNumberSymbolsMapping, + Symbol rightRowNumber, + Symbol rightPartitionSize, + List rightPartitionBy, + boolean rightPruneWhenEmpty, + Map rightRowNumberSymbolsMapping) + { + JoinedNodes + { + requireNonNull(joinedNode, "joinedNode is null"); + requireNonNull(leftRowNumber, "leftRowNumber is null"); + requireNonNull(leftPartitionSize, "leftPartitionSize is null"); + leftPartitionBy = ImmutableList.copyOf(leftPartitionBy); + leftRowNumberSymbolsMapping = ImmutableMap.copyOf(leftRowNumberSymbolsMapping); + requireNonNull(rightRowNumber, "rightRowNumber is null"); + requireNonNull(rightPartitionSize, "rightPartitionSize is null"); + rightPartitionBy = ImmutableList.copyOf(rightPartitionBy); + rightRowNumberSymbolsMapping = ImmutableMap.copyOf(rightRowNumberSymbolsMapping); + } + } + + private record NodeWithMarkers(PlanNode node, Map symbolToMarker) + { + NodeWithMarkers + { + 
requireNonNull(node, "node is null"); + symbolToMarker = ImmutableMap.copyOf(symbolToMarker); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java index 1cf7f027a328..7dbe89b5c813 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/SymbolMapper.java @@ -387,6 +387,7 @@ public TableFunctionProcessorNode map(TableFunctionProcessorNode node, PlanNode node.getName(), map(node.getProperOutputs()), Optional.of(source), + node.isPruneWhenEmpty(), newPassThroughSpecifications, newRequiredSymbols, newMarkerSymbols, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 1a56e85c33e7..7d92beb9e1c4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -380,6 +380,7 @@ public PlanAndMappings visitTableFunctionProcessor(TableFunctionProcessorNode no node.getName(), mapper.map(node.getProperOutputs()), Optional.empty(), + node.isPruneWhenEmpty(), ImmutableList.of(), ImmutableList.of(), Optional.empty(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java index a4a36e2377ce..e04e5450a599 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionProcessorNode.java @@ -46,6 +46,10 @@ public class TableFunctionProcessorNode private final Optional source; // TODO do we need the 
info of which source has row semantics, or is it already included in the joins / join distribution? + // specifies whether the function should be pruned or executed when the input is empty + // pruneWhenEmpty is false if and only if all original input tables are KEEP WHEN EMPTY + private final boolean pruneWhenEmpty; + // all source symbols to be produced on output, ordered as table argument specifications private final List passThroughSpecifications; @@ -72,6 +76,7 @@ public TableFunctionProcessorNode( @JsonProperty("name") String name, @JsonProperty("properOutputs") List properOutputs, @JsonProperty("source") Optional source, + @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughSpecifications") List passThroughSpecifications, @JsonProperty("requiredSymbols") List> requiredSymbols, @JsonProperty("markerSymbols") Optional> markerSymbols, @@ -85,6 +90,7 @@ public TableFunctionProcessorNode( this.name = requireNonNull(name, "name is null"); this.properOutputs = ImmutableList.copyOf(properOutputs); this.source = requireNonNull(source, "source is null"); + this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughSpecifications = ImmutableList.copyOf(passThroughSpecifications); this.requiredSymbols = requiredSymbols.stream() .map(ImmutableList::copyOf) @@ -128,6 +134,12 @@ public Optional getSource() return source; } + @JsonProperty + public boolean isPruneWhenEmpty() + { + return pruneWhenEmpty; + } + @JsonProperty public List getPassThroughSpecifications() { @@ -209,6 +221,6 @@ public R accept(PlanVisitor visitor, C context) public PlanNode replaceChildren(List newSources) { Optional newSource = newSources.isEmpty() ? 
Optional.empty() : Optional.of(getOnlyElement(newSources)); - return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, passThroughSpecifications, requiredSymbols, markerSymbols, specification, prePartitioned, preSorted, hashSymbol, handle); + return new TableFunctionProcessorNode(getId(), name, properOutputs, newSource, pruneWhenEmpty, passThroughSpecifications, requiredSymbols, markerSymbols, specification, prePartitioned, preSorted, hashSymbol, handle); } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 4d52921ed152..a7c2b50384c9 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -128,7 +128,7 @@ public static PlanMatchPattern any(PlanMatchPattern... sources) * Matches to any tree of nodes with children matching to given source matchers. * anyTree(tableScan("nation")) - will match to any plan which all leafs contain * any node containing table scan from nation table. - * + *

* Note: anyTree does not match zero nodes. E.g. output(anyTree(tableScan)) will NOT match TableScan node followed by OutputNode. */ public static PlanMatchPattern anyTree(PlanMatchPattern... sources) @@ -848,6 +848,20 @@ public static PlanMatchPattern tableFunction(Consumer handler, PlanMatchPattern source) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(source); + handler.accept(builder); + return builder.build(); + } + + public static PlanMatchPattern tableFunctionProcessor(Consumer handler) + { + TableFunctionProcessorMatcher.Builder builder = new TableFunctionProcessorMatcher.Builder(); + handler.accept(builder); + return builder.build(); + } + public PlanMatchPattern(List sourcePatterns) { requireNonNull(sourcePatterns, "sourcePatterns are null"); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java new file mode 100644 index 000000000000..e9279d35fbef --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/TableFunctionProcessorMatcher.java @@ -0,0 +1,204 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.assertions; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.cost.StatsProvider; +import io.trino.metadata.Metadata; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionProcessorNode; +import io.trino.sql.tree.SymbolReference; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.MoreObjects.toStringHelper; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.sql.planner.assertions.MatchResult.NO_MATCH; +import static io.trino.sql.planner.assertions.MatchResult.match; +import static io.trino.sql.planner.assertions.PlanMatchPattern.node; +import static java.util.Objects.requireNonNull; + +public class TableFunctionProcessorMatcher + implements Matcher +{ + private final String name; + private final List properOutputs; + private final Set passThroughSymbols; + private final Optional> markerSymbols; + private final Optional> specification; + + private TableFunctionProcessorMatcher( + String name, + List properOutputs, + Set passThroughSymbols, + Optional> markerSymbols, + Optional> specification) + { + this.name = requireNonNull(name, "name is null"); + this.properOutputs = ImmutableList.copyOf(properOutputs); + this.passThroughSymbols = ImmutableSet.copyOf(passThroughSymbols); + this.markerSymbols = 
markerSymbols.map(ImmutableMap::copyOf); + this.specification = requireNonNull(specification, "specification is null"); + } + + @Override + public boolean shapeMatches(PlanNode node) + { + return node instanceof TableFunctionProcessorNode; + } + + @Override + public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session session, Metadata metadata, SymbolAliases symbolAliases) + { + checkState(shapeMatches(node), "Plan testing framework error: shapeMatches returned false in detailMatches in %s", this.getClass().getName()); + + TableFunctionProcessorNode tableFunctionProcessorNode = (TableFunctionProcessorNode) node; + + if (!name.equals(tableFunctionProcessorNode.getName())) { + return NO_MATCH; + } + + if (properOutputs.size() != tableFunctionProcessorNode.getProperOutputs().size()) { + return NO_MATCH; + } + + Set expectedPassThrough = passThroughSymbols.stream() + .map(symbolAliases::get) + .collect(toImmutableSet()); + Set actualPassThrough = tableFunctionProcessorNode.getPassThroughSpecifications().stream() + .map(PassThroughSpecification::columns) + .flatMap(Collection::stream) + .map(PassThroughColumn::symbol) + .map(Symbol::toSymbolReference) + .collect(toImmutableSet()); + if (!expectedPassThrough.equals(actualPassThrough)) { + return NO_MATCH; + } + + if (markerSymbols.isPresent() != tableFunctionProcessorNode.getMarkerSymbols().isPresent()) { + return NO_MATCH; + } + if (markerSymbols.isPresent()) { + Map expectedMapping = markerSymbols.get().entrySet().stream() + .collect(toImmutableMap(entry -> symbolAliases.get(entry.getKey()), entry -> symbolAliases.get(entry.getValue()))); + Map actualMapping = tableFunctionProcessorNode.getMarkerSymbols().orElseThrow().entrySet().stream() + .collect(toImmutableMap(entry -> entry.getKey().toSymbolReference(), entry -> entry.getValue().toSymbolReference())); + if (!expectedMapping.equals(actualMapping)) { + return NO_MATCH; + } + } + + if (specification.isPresent() != 
tableFunctionProcessorNode.getSpecification().isPresent()) { + return NO_MATCH; + } + if (specification.isPresent()) { + if (!specification.get().getExpectedValue(symbolAliases).equals(tableFunctionProcessorNode.getSpecification().orElseThrow())) { + return NO_MATCH; + } + } + + ImmutableMap.Builder properOutputsMapping = ImmutableMap.builder(); + for (int i = 0; i < properOutputs.size(); i++) { + properOutputsMapping.put(properOutputs.get(i), tableFunctionProcessorNode.getProperOutputs().get(i).toSymbolReference()); + } + + return match(SymbolAliases.builder() + .putAll(symbolAliases) + .putAll(properOutputsMapping.buildOrThrow()) + .build()); + } + + @Override + public String toString() + { + return toStringHelper(this) + .omitNullValues() + .add("name", name) + .add("properOutputs", properOutputs) + .add("passThroughSymbols", passThroughSymbols) + .add("markerSymbols", markerSymbols) + .add("specification", specification) + .toString(); + } + + public static class Builder + { + private final Optional source; + private String name; + private List properOutputs = ImmutableList.of(); + private Set passThroughSymbols = ImmutableSet.of(); + private Optional> markerSymbols = Optional.empty(); + private Optional> specification = Optional.empty(); + + public Builder() + { + this.source = Optional.empty(); + } + + public Builder(PlanMatchPattern source) + { + this.source = Optional.of(source); + } + + public Builder name(String name) + { + this.name = name; + return this; + } + + public Builder properOutputs(List properOutputs) + { + this.properOutputs = properOutputs; + return this; + } + + public Builder passThroughSymbols(Set passThroughSymbols) + { + this.passThroughSymbols = passThroughSymbols; + return this; + } + + public Builder markerSymbols(Map markerSymbols) + { + this.markerSymbols = Optional.of(markerSymbols); + return this; + } + + public Builder specification(ExpectedValueProvider specification) + { + this.specification = Optional.of(specification); + 
return this; + } + + public PlanMatchPattern build() + { + PlanMatchPattern[] sources = source.map(sourcePattern -> new PlanMatchPattern[] {sourcePattern}).orElse(new PlanMatchPattern[] {}); + return node(TableFunctionProcessorNode.class, sources) + .with(new TableFunctionProcessorMatcher(name, properOutputs, passThroughSymbols, markerSymbols, specification)); + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java new file mode 100644 index 000000000000..a3314fe90f2f --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestImplementTableFunctionSource.java @@ -0,0 +1,1360 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.sql.planner.OrderingScheme; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.DataOrganizationSpecification; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughColumn; +import io.trino.sql.planner.plan.TableFunctionNode.PassThroughSpecification; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.trino.spi.connector.SortOrder.ASC_NULLS_LAST; +import static io.trino.spi.connector.SortOrder.DESC_NULLS_FIRST; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.TinyintType.TINYINT; +import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; +import static io.trino.sql.planner.assertions.PlanMatchPattern.functionCall; +import static io.trino.sql.planner.assertions.PlanMatchPattern.join; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.specification; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableFunctionProcessor; +import static io.trino.sql.planner.assertions.PlanMatchPattern.values; +import static io.trino.sql.planner.assertions.PlanMatchPattern.window; +import static io.trino.sql.planner.plan.JoinNode.Type.FULL; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; +import static io.trino.sql.planner.plan.JoinNode.Type.LEFT; + +public class TestImplementTableFunctionSource + extends BaseRuleTest +{ + @Test + public void 
testNoSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> p.tableFunction( + "test_function", + ImmutableList.of(p.symbol("a")), + ImmutableList.of(), + ImmutableList.of(), + ImmutableList.of())) + .matches(tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a")))); + } + + @Test + public void testSingleSourceWithRowSemantics() + { + // no pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + true, + true, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")), + values("c"))); + + // pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + true, + true, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, false))), + ImmutableList.of(c), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c")), + values("c"))); + } + + @Test + public void testSingleSourceWithSetSemantics() + { + // no pass-through columns, no partition by + 
tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(d), ImmutableMap.of(d, ASC_NULLS_LAST))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .specification(specification(ImmutableList.of(), ImmutableList.of("d"), ImmutableMap.of("d", ASC_NULLS_LAST))), + values("c", "d"))); + + // no pass-through columns, partitioning column present + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.of(new OrderingScheme(ImmutableList.of(d), ImmutableMap.of(d, ASC_NULLS_LAST))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("d"), ImmutableMap.of("d", 
ASC_NULLS_LAST))), + values("c", "d"))); + + // pass-through columns + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of(p.values(c, d)), + ImmutableList.of(new TableArgumentProperties( + "table_argument", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, false))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())), + values("c", "d"))); + } + + @Test + public void testTwoSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(), Optional.empty())))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + FULL, + joinBuilder -> joinBuilder + .filter(""" + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper 
symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testThreeSourcesWithSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + Symbol g = p.symbol("g"); + Symbol h = p.symbol("h"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f), + p.values(g, h)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of()), + ImmutableList.of(h), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(), Optional.of(new OrderingScheme(ImmutableList.of(h), ImmutableMap.of(h, DESC_NULLS_FIRST))))))), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + 
.markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2", + "g", "marker_3", + "h", "marker_3")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)")), + join(// join nodes using helper symbols + FULL, + joinBuilder -> joinBuilder + .filter(""" + combined_row_number_1_2 = input_3_row_number OR + combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR + input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1' + """) + .left(project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes 
using helper symbols + FULL, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))) + .right(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of("h"), ImmutableMap.of("h", DESC_NULLS_FIRST))) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("g", "h")))))))); + } + + @Test + public void testTwoCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + 
new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(f), ImmutableMap.of(f, DESC_NULLS_FIRST))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, e)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + 
.filter(""" + NOT (c IS DISTINCT FROM e) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testCoPartitionJoinTypes() + { + // both sources are prune when empty, so they are combined using inner join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new 
DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + INNER, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + 
.addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))))); + + // only the left source is prune when empty, so sources are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + 
ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))))); + + // only the right source is prune when empty. the sources are reordered so that the prune when empty source is first. 
they are combined using left join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), input_2_row_number, input_1_row_number)"), + "combined_partition_size", 
expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), input_2_partition_size, input_1_partition_size)"), + "combined_partition_column", expression("COALESCE(d, c)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (d IS DISTINCT FROM c) + AND ( + input_2_row_number = input_1_row_number OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))) + .right(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c")))))))); + + // neither source is prune when empty, so sources are combined using full join + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), 
Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c, d)")), + join(// co-partition nodes + FULL, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + 
builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))))); + } + + @Test + public void testThreeCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2", 
"input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2_3"), ImmutableList.of("combined_row_number_1_2_3"), ImmutableMap.of("combined_row_number_1_2_3", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3, input_3_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2_3", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), combined_row_number_1_2, input_3_row_number)"), + "combined_partition_size_1_2_3", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), combined_partition_size_1_2, input_3_partition_size)"), + "combined_partition_column_1_2_3", expression("COALESCE(combined_partition_column_1_2, e)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (combined_partition_column_1_2 IS DISTINCT FROM e) + AND ( + combined_row_number_1_2 = input_3_row_number OR + combined_row_number_1_2 > input_3_partition_size AND input_3_row_number = BIGINT '1' OR + input_3_row_number > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1') + """) + .left(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + 
"combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + INNER, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))) + .right(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e")))))))); + } + + @Test + public void testTwoCoPartitionLists() + { + tester().assertThat(new 
ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + Symbol g = p.symbol("g"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e), + p.values(f, g)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty()))), + new TableArgumentProperties( + "input_4", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(f, true))), + ImmutableList.of(g), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(f), Optional.of(new OrderingScheme(ImmutableList.of(g), ImmutableMap.of(g, DESC_NULLS_FIRST))))))), + ImmutableList.of( + ImmutableList.of("input_1", "input_2"), + ImmutableList.of("input_3", "input_4"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3", + "f", "marker_4", + "g", 
"marker_4")) + .specification(specification(ImmutableList.of("combined_partition_column_1_2", "combined_partition_column_3_4"), ImmutableList.of("combined_row_number_1_2_3_4"), ImmutableMap.of("combined_row_number_1_2_3_4", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_1_2_3_4, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_1_2_3_4, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_1_2_3_4, input_3_row_number, CAST(null AS bigint))"), + "marker_4", expression("IF(input_4_row_number = combined_row_number_1_2_3_4, input_4_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_1_2_3_4", expression("IF(COALESCE(combined_row_number_1_2, BIGINT '-1') > COALESCE(combined_row_number_3_4, BIGINT '-1'), combined_row_number_1_2, combined_row_number_3_4)"), + "combined_partition_size_1_2_3_4", expression("IF(COALESCE(combined_partition_size_1_2, BIGINT '-1') > COALESCE(combined_partition_size_3_4, BIGINT '-1'), combined_partition_size_1_2, combined_partition_size_3_4)")), + join(// join nodes using helper symbols + LEFT, + joinBuilder -> joinBuilder + .filter(""" + combined_row_number_1_2 = combined_row_number_3_4 OR + combined_row_number_1_2 > combined_partition_size_3_4 AND combined_row_number_3_4 = BIGINT '1' OR + combined_row_number_3_4 > combined_partition_size_1_2 AND combined_row_number_1_2 = BIGINT '1' + """) + .left(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_1_2", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size_1_2", expression("IF(COALESCE(input_1_partition_size, BIGINT 
'-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1_2", expression("COALESCE(c, d)")), + join(// co-partition nodes + INNER, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (c IS DISTINCT FROM d) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d")))))) + .right(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_3_4", expression("IF(COALESCE(input_3_row_number, BIGINT '-1') > COALESCE(input_4_row_number, BIGINT '-1'), input_3_row_number, input_4_row_number)"), + "combined_partition_size_3_4", expression("IF(COALESCE(input_3_partition_size, BIGINT '-1') > COALESCE(input_4_partition_size, BIGINT '-1'), input_3_partition_size, input_4_partition_size)"), + "combined_partition_column_3_4", expression("COALESCE(e, f)")), + join(// co-partition nodes + FULL, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (e IS DISTINCT FROM f) + AND ( + input_3_row_number = input_4_row_number OR + 
input_3_row_number > input_4_partition_size AND input_4_row_number = BIGINT '1' OR + input_4_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e"))) + .right(window(// append helper symbols for source input_4 + builder -> builder + .specification(specification(ImmutableList.of("f"), ImmutableList.of("g"), ImmutableMap.of("g", DESC_NULLS_FIRST))) + .addFunction("input_4_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_4_partition_size", functionCall("count", ImmutableList.of())), + // input_4 + values("f", "g"))))))))))); + } + + @Test + public void testCoPartitionedAndNotCoPartitionedSources() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c), + p.values(d), + p.values(e)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(d, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(d), Optional.empty()))), + new TableArgumentProperties( + "input_3", + 
false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(e, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_2", "input_3"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_2", + "e", "marker_3")) + .specification(specification(ImmutableList.of("combined_partition_column_2_3", "c"), ImmutableList.of("combined_row_number_2_3_1"), ImmutableMap.of("combined_row_number_2_3_1", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number_2_3_1, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number_2_3_1, input_2_row_number, CAST(null AS bigint))"), + "marker_3", expression("IF(input_3_row_number = combined_row_number_2_3_1, input_3_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number_2_3_1", expression("IF(COALESCE(combined_row_number_2_3, BIGINT '-1') > COALESCE(input_1_row_number, BIGINT '-1'), combined_row_number_2_3, input_1_row_number)"), + "combined_partition_size_2_3_1", expression("IF(COALESCE(combined_partition_size_2_3, BIGINT '-1') > COALESCE(input_1_partition_size, BIGINT '-1'), combined_partition_size_2_3, input_1_partition_size)")), + join(// join nodes using helper symbols + INNER, + joinBuilder -> joinBuilder + .filter(""" + combined_row_number_2_3 = input_1_row_number OR + combined_row_number_2_3 > input_1_partition_size AND input_1_row_number = BIGINT '1' OR + input_1_row_number > combined_partition_size_2_3 AND combined_row_number_2_3 = BIGINT '1' + """) + 
.left(project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number_2_3", expression("IF(COALESCE(input_2_row_number, BIGINT '-1') > COALESCE(input_3_row_number, BIGINT '-1'), input_2_row_number, input_3_row_number)"), + "combined_partition_size_2_3", expression("IF(COALESCE(input_2_partition_size, BIGINT '-1') > COALESCE(input_3_partition_size, BIGINT '-1'), input_2_partition_size, input_3_partition_size)"), + "combined_partition_column_2_3", expression("COALESCE(d, e)")), + join(// co-partition nodes + LEFT, + nestedJoinBuilder -> nestedJoinBuilder + .filter(""" + NOT (d IS DISTINCT FROM e) + AND ( + input_2_row_number = input_3_row_number OR + input_2_row_number > input_3_partition_size AND input_3_row_number = BIGINT '1' OR + input_3_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("d"))) + .right(window(// append helper symbols for source input_3 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_3_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_3_partition_size", functionCall("count", ImmutableList.of())), + // input_3 + values("e")))))) + .right(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + 
values("c")))))))); + } + + @Test + public void testCoerceForCopartitioning() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c", TINYINT); + Symbol cCoerced = p.symbol("c_coerced", INTEGER); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e", INTEGER); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + // coerce column c for co-partitioning + p.project( + Assignments.builder() + .put(c, PlanBuilder.expression("c")) + .put(d, PlanBuilder.expression("d")) + .put(cCoerced, PlanBuilder.expression("CAST(c AS INTEGER)")) + .build(), + p.values(c, d)), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(c, d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(cCoerced), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, false))), + ImmutableList.of(f), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e), Optional.of(new OrderingScheme(ImmutableList.of(f), ImmutableMap.of(f, DESC_NULLS_FIRST))))))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "c_coerced", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", 
ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column", expression("COALESCE(c_coerced, e)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (c_coerced IS DISTINCT FROM e) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c_coerced"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + project( + ImmutableMap.of("c_coerced", expression("CAST(c AS INTEGER)")), + values("c", "d")))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e"), ImmutableList.of("f"), ImmutableMap.of("f", DESC_NULLS_FIRST))) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", 
ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testTwoCoPartitioningColumns() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + true, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true), new PassThroughColumn(d, true))), + ImmutableList.of(c), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c, d), Optional.empty()))), + new TableArgumentProperties( + "input_2", + false, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, true), new PassThroughColumn(f, true))), + ImmutableList.of(e), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(e, f), Optional.empty())))), + ImmutableList.of(ImmutableList.of("input_1", "input_2"))); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "d", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("combined_partition_column_1", "combined_partition_column_2"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS 
bigint))")), + project(// append helper and partitioning symbols for co-partitioned nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)"), + "combined_partition_column_1", expression("COALESCE(c, e)"), + "combined_partition_column_2", expression("COALESCE(d, f)")), + join(// co-partition nodes + LEFT, + joinBuilder -> joinBuilder + .filter(""" + NOT (c IS DISTINCT FROM e) + AND NOT (d IS DISTINCT FROM f) + AND ( + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1') + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c", "d"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of("e", "f"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } + + @Test + public void testTwoSourcesWithRowAndSetSemantics() + { + tester().assertThat(new ImplementTableFunctionSource(tester().getMetadata())) + .on(p -> { + Symbol a = p.symbol("a"); + Symbol b = p.symbol("b"); + Symbol c = p.symbol("c"); + Symbol d = 
p.symbol("d"); + Symbol e = p.symbol("e"); + Symbol f = p.symbol("f"); + return p.tableFunction( + "test_function", + ImmutableList.of(a, b), + ImmutableList.of( + p.values(c, d), + p.values(e, f)), + ImmutableList.of( + new TableArgumentProperties( + "input_1", + false, + false, + new PassThroughSpecification(false, ImmutableList.of(new PassThroughColumn(c, true))), + ImmutableList.of(d), + Optional.of(new DataOrganizationSpecification(ImmutableList.of(c), Optional.empty()))), + new TableArgumentProperties( + "input_2", + true, + false, + new PassThroughSpecification(true, ImmutableList.of(new PassThroughColumn(e, false), new PassThroughColumn(f, false))), + ImmutableList.of(e), + Optional.empty())), + ImmutableList.of()); + }) + .matches(PlanMatchPattern.tableFunctionProcessor(builder -> builder + .name("test_function") + .properOutputs(ImmutableList.of("a", "b")) + .passThroughSymbols(ImmutableSet.of("c", "e", "f")) + .markerSymbols(ImmutableMap.of( + "c", "marker_1", + "d", "marker_1", + "e", "marker_2", + "f", "marker_2")) + .specification(specification(ImmutableList.of("c"), ImmutableList.of("combined_row_number"), ImmutableMap.of("combined_row_number", ASC_NULLS_LAST))), + project(// append marker symbols + ImmutableMap.of( + "marker_1", expression("IF(input_1_row_number = combined_row_number, input_1_row_number, CAST(null AS bigint))"), + "marker_2", expression("IF(input_2_row_number = combined_row_number, input_2_row_number, CAST(null AS bigint))")), + project(// append helper symbols for joined nodes + ImmutableMap.of( + "combined_row_number", expression("IF(COALESCE(input_1_row_number, BIGINT '-1') > COALESCE(input_2_row_number, BIGINT '-1'), input_1_row_number, input_2_row_number)"), + "combined_partition_size", expression("IF(COALESCE(input_1_partition_size, BIGINT '-1') > COALESCE(input_2_partition_size, BIGINT '-1'), input_1_partition_size, input_2_partition_size)")), + join(// join nodes using helper symbols + FULL, + joinBuilder -> joinBuilder + 
.filter(""" + input_1_row_number = input_2_row_number OR + input_1_row_number > input_2_partition_size AND input_2_row_number = BIGINT '1' OR + input_2_row_number > input_1_partition_size AND input_1_row_number = BIGINT '1' + """) + .left(window(// append helper symbols for source input_1 + builder -> builder + .specification(specification(ImmutableList.of("c"), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_1_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_1_partition_size", functionCall("count", ImmutableList.of())), + // input_1 + values("c", "d"))) + .right(window(// append helper symbols for source input_2 + builder -> builder + .specification(specification(ImmutableList.of(), ImmutableList.of(), ImmutableMap.of())) + .addFunction("input_2_row_number", functionCall("row_number", ImmutableList.of())) + .addFunction("input_2_partition_size", functionCall("count", ImmutableList.of())), + // input_2 + values("e", "f")))))))); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index 55e7c65e0d0e..718a01134b67 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -26,13 +26,16 @@ import io.trino.metadata.OutputTableHandle; import io.trino.metadata.ResolvedFunction; import io.trino.metadata.TableExecuteHandle; +import io.trino.metadata.TableFunctionHandle; import io.trino.metadata.TableHandle; import io.trino.operator.RetryPolicy; import io.trino.spi.connector.CatalogHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SortOrder; +import io.trino.spi.function.SchemaFunctionName; import io.trino.spi.predicate.TupleDomain; +import 
io.trino.spi.ptf.ConnectorTableFunctionHandle; import io.trino.spi.type.Type; import io.trino.sql.ExpressionUtils; import io.trino.sql.analyzer.TypeSignatureProvider; @@ -87,6 +90,8 @@ import io.trino.sql.planner.plan.StatisticAggregationsDescriptor; import io.trino.sql.planner.plan.TableExecuteNode; import io.trino.sql.planner.plan.TableFinishNode; +import io.trino.sql.planner.plan.TableFunctionNode; +import io.trino.sql.planner.plan.TableFunctionNode.TableArgumentProperties; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.TableWriterNode.CreateTarget; @@ -1174,6 +1179,25 @@ public TableExecuteNode tableExecute( preferredPartitioningScheme); } + public TableFunctionNode tableFunction( + String name, + List<Symbol> properOutputs, + List<PlanNode> sources, + List<TableArgumentProperties> tableArgumentProperties, + List<List<String>> copartitioningLists) + + { + return new TableFunctionNode( + idAllocator.getNextId(), + name, + ImmutableMap.of(), + properOutputs, + sources, + tableArgumentProperties, + copartitioningLists, + new TableFunctionHandle(TEST_CATALOG_HANDLE, new SchemaFunctionName("system", name), new ConnectorTableFunctionHandle() {}, TestingTransactionHandle.create())); + } + public PartitioningScheme partitioningScheme(List<Symbol> outputSymbols, List<Symbol> partitioningSymbols, Symbol hashSymbol) { return new PartitioningScheme(Partitioning.create(