diff --git a/presto-hive/src/test/java/io/prestosql/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java b/presto-hive/src/test/java/io/prestosql/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java index 6f1636e2ff2f..e4fb395238f3 100644 --- a/presto-hive/src/test/java/io/prestosql/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java +++ b/presto-hive/src/test/java/io/prestosql/plugin/hive/optimizer/TestHiveProjectionPushdownIntoTableScan.java @@ -26,6 +26,7 @@ import io.prestosql.plugin.hive.HdfsEnvironment; import io.prestosql.plugin.hive.HiveColumnHandle; import io.prestosql.plugin.hive.HiveHdfsConfiguration; +import io.prestosql.plugin.hive.HiveTableHandle; import io.prestosql.plugin.hive.authentication.HiveIdentity; import io.prestosql.plugin.hive.authentication.NoHdfsAuthentication; import io.prestosql.plugin.hive.metastore.Database; @@ -33,6 +34,7 @@ import io.prestosql.plugin.hive.metastore.file.FileHiveMetastore; import io.prestosql.plugin.hive.testing.TestingHiveConnectorFactory; import io.prestosql.spi.connector.ColumnHandle; +import io.prestosql.spi.predicate.Domain; import io.prestosql.spi.predicate.TupleDomain; import io.prestosql.spi.security.PrincipalType; import io.prestosql.sql.planner.assertions.BasePushdownPlanTest; @@ -48,10 +50,16 @@ import static com.google.common.io.MoreFiles.deleteRecursively; import static com.google.common.io.RecursiveDeleteOption.ALLOW_INSECURE; import static io.prestosql.plugin.hive.TestHiveReaderProjectionsUtil.createProjectedColumnHandle; +import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.any; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; 
+import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; import static io.prestosql.testing.TestingSession.testSessionBuilder; import static java.lang.String.format; import static org.testng.Assert.assertTrue; @@ -116,14 +124,14 @@ public void testPushdownDisabled() } @Test - public void testProjectionPushdown() + public void testDereferencePushdown() { String testTable = "test_simple_projection_pushdown"; QualifiedObjectName completeTableName = new QualifiedObjectName(HIVE_CATALOG_NAME, SCHEMA_NAME, testTable); getQueryRunner().execute(format( - "CREATE TABLE %s (col0) AS" + - " SELECT cast(row(5, 6) as row(a bigint, b bigint)) AS col0 WHERE false", + "CREATE TABLE %s (col0, col1) AS" + + " SELECT cast(row(5, 6) as row(x bigint, y bigint)) AS col0, 5 AS col1 WHERE false", testTable)); Session session = getQueryRunner().getDefaultSession(); @@ -132,18 +140,70 @@ public void testProjectionPushdown() assertTrue(tableHandle.isPresent(), "expected the table handle to be present"); Map columns = getColumnHandles(session, completeTableName); - assertTrue(columns.containsKey("col0"), "expected column not found"); - HiveColumnHandle baseColumnHandle = (HiveColumnHandle) columns.get("col0"); + HiveColumnHandle column0Handle = (HiveColumnHandle) columns.get("col0"); + HiveColumnHandle column1Handle = (HiveColumnHandle) columns.get("col1"); + HiveColumnHandle columnX = createProjectedColumnHandle(column0Handle, ImmutableList.of(0)); + HiveColumnHandle columnY = createProjectedColumnHandle(column0Handle, ImmutableList.of(1)); + + // Simple Projection pushdown assertPlan( - "SELECT col0.a expr_a, col0.b expr_b FROM " + testTable, + "SELECT col0.x expr_x, col0.y expr_y FROM " + testTable, any(tableScan( - 
equalTo(tableHandle.get().getConnectorHandle()), - TupleDomain.all(), - ImmutableMap.of( - "col0#a", equalTo(createProjectedColumnHandle(baseColumnHandle, ImmutableList.of(0))), - "col0#b", equalTo(createProjectedColumnHandle(baseColumnHandle, ImmutableList.of(1))))))); + equalTo(tableHandle.get().getConnectorHandle()), + TupleDomain.all(), + ImmutableMap.of("col0#x", equalTo(columnX), "col0#y", equalTo(columnY))))); + + // Projection and predicate pushdown + assertPlan( + format("SELECT col0.x FROM %s WHERE col0.x = col1 + 3 and col0.y = 2", testTable), + anyTree( + filter( + "col0_y = bigint '2' AND (col0_x = cast((col1 + 3) as bigint))", + tableScan( + table -> ((HiveTableHandle) table).getCompactEffectivePredicate().getDomains().get() + .equals(ImmutableMap.of(columnY, Domain.singleValue(BIGINT, 2L))), + TupleDomain.all(), + ImmutableMap.of("col0_y", equalTo(columnY), "col0_x", equalTo(columnX), "col1", equalTo(column1Handle)))))); + + // Projection and predicate pushdown with overlapping columns + assertPlan( + format("SELECT col0, col0.y expr_y FROM %s WHERE col0.x = 5", testTable), + anyTree( + filter( + "col0_x = bigint '5'", + tableScan( + table -> ((HiveTableHandle) table).getCompactEffectivePredicate().getDomains().get() + .equals(ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 5L))), + TupleDomain.all(), + ImmutableMap.of("col0", equalTo(column0Handle), "col0_x", equalTo(columnX)))))); + + // Projection and predicate pushdown with joins + assertPlan( + format("SELECT T.col0.x, T.col0, T.col0.y FROM %s T join %s S on T.col1 = S.col1 WHERE (T.col0.x = 2)", testTable, testTable), + anyTree( + project( + ImmutableMap.of( + "expr_0_x", expression("expr_0.x"), + "expr_0", expression("expr_0"), + "expr_0_y", expression("expr_0.y")), + join( + INNER, + ImmutableList.of(equiJoinClause("t_expr_1", "s_expr_1")), + anyTree( + filter( + "expr_0_x = BIGINT '2'", + tableScan( + table -> ((HiveTableHandle) table).getCompactEffectivePredicate().getDomains().get() + 
.equals(ImmutableMap.of(columnX, Domain.singleValue(BIGINT, 2L))), + TupleDomain.all(), + ImmutableMap.of("expr_0_x", equalTo(columnX), "expr_0", equalTo(column0Handle), "t_expr_1", equalTo(column1Handle))))), + anyTree( + tableScan( + equalTo(tableHandle.get().getConnectorHandle()), + TupleDomain.all(), + ImmutableMap.of("s_expr_1", equalTo(column1Handle)))))))); } @AfterClass(alwaysRun = true) diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index 2c6b1d3bd44a..c26626039942 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -42,6 +42,7 @@ import io.prestosql.sql.planner.iterative.rule.DetermineSemiJoinDistributionType; import io.prestosql.sql.planner.iterative.rule.EliminateCrossJoins; import io.prestosql.sql.planner.iterative.rule.EvaluateZeroSample; +import io.prestosql.sql.planner.iterative.rule.ExtractDereferencesFromFilterAboveScan; import io.prestosql.sql.planner.iterative.rule.ExtractSpatialJoins; import io.prestosql.sql.planner.iterative.rule.GatherAndMergeWindows; import io.prestosql.sql.planner.iterative.rule.ImplementBernoulliSampleAsFilter; @@ -105,6 +106,18 @@ import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns; import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushDeleteIntoConnector; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughFilter; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughJoin; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughProject; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughSemiJoin; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferenceThroughUnnest; +import 
io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughAssignUniqueId; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughLimit; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughRowNumber; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughSort; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughTopN; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughTopNRowNumber; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferencesThroughWindow; import io.prestosql.sql.planner.iterative.rule.PushLimitIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOffset; @@ -299,7 +312,21 @@ public PlanOptimizers( Set> projectionPushdownRules = ImmutableSet.of( new PushProjectionIntoTableScan(metadata, typeAnalyzer), new PushProjectionThroughUnion(), - new PushProjectionThroughExchange()); + new PushProjectionThroughExchange(), + // Dereference pushdown rules + new PushDownDereferenceThroughProject(typeAnalyzer), + new PushDownDereferenceThroughUnnest(typeAnalyzer), + new PushDownDereferenceThroughSemiJoin(typeAnalyzer), + new PushDownDereferenceThroughJoin(typeAnalyzer), + new PushDownDereferenceThroughFilter(typeAnalyzer), + new ExtractDereferencesFromFilterAboveScan(typeAnalyzer), + new PushDownDereferencesThroughLimit(typeAnalyzer), + new PushDownDereferencesThroughSort(typeAnalyzer), + new PushDownDereferencesThroughAssignUniqueId(typeAnalyzer), + new PushDownDereferencesThroughWindow(typeAnalyzer), + new PushDownDereferencesThroughTopN(typeAnalyzer), + new PushDownDereferencesThroughRowNumber(typeAnalyzer), + new PushDownDereferencesThroughTopNRowNumber(typeAnalyzer)); IterativeOptimizer inlineProjections = new IterativeOptimizer( ruleStats, @@ -497,6 +524,16 @@ public PlanOptimizers( inlineProjections, simplifyOptimizer, // Re-run the 
SimplifyExpressions to simplify any recomposed expressions from other optimizations projectionPushDown, + // Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential + // pushdown into the connectors. We invoke PredicatePushdown and PushPredicateIntoTableScan after this + // to leverage predicate pushdown on projected columns. + new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, typeAnalyzer, true, false)), + simplifyOptimizer, // Should be always run after PredicatePushDown + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), new UnaliasSymbolReferences(metadata), // Run again because predicate pushdown and projection pushdown might add more projections new PruneUnreferencedOutputs(metadata, typeAnalyzer), // Make sure to run this before index join. Filtered projections may not have all the columns. new IndexJoinOptimizer(metadata), // Run this after projections and filters have been fully simplified and pushed down @@ -539,6 +576,16 @@ public PlanOptimizers( estimatedExchangesCostCalculator, ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), projectionPushDown, + // Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential + // pushdown into the connectors. Invoke PredicatePushdown and PushPredicateIntoTableScan after this + // to leverage predicate pushdown on projected columns. 
+ new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, typeAnalyzer, true, false)), + simplifyOptimizer, // Should be always run after PredicatePushDown + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer))), new PruneUnreferencedOutputs(metadata, typeAnalyzer), new IterativeOptimizer( ruleStats, @@ -627,6 +674,17 @@ public PlanOptimizers( costCalculator, ImmutableSet.of(new RemoveRedundantTableScanPredicate(metadata)))); builder.add(projectionPushDown); + // Projection pushdown rules may push reducing projections (e.g. dereferences) below filters for potential + // pushdown into the connectors. Invoke PredicatePushdown and PushPredicateIntoTableScan after this + // to leverage predicate pushdown on projected columns. + builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, typeAnalyzer, true, true))); + builder.add(new RemoveUnsupportedDynamicFilters(metadata)); // Remove unsupported dynamic filters introduced by PredicatePushdown + builder.add(simplifyOptimizer); // Should always run after PredicatePushdown + builder.add(new IterativeOptimizer( + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of(new PushPredicateIntoTableScan(metadata, typeAnalyzer)))); builder.add(inlineProjections); builder.add(new UnaliasSymbolReferences(metadata)); // Run unalias after merging projections to simplify projections more efficiently builder.add(new PruneUnreferencedOutputs(metadata, typeAnalyzer)); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DereferencePushdown.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DereferencePushdown.java new file mode 100644 index 000000000000..f0bddcb928b1 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/DereferencePushdown.java @@ -0,0 +1,130 @@ +/* + * Licensed under the 
Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.LambdaExpression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Collection; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.prestosql.sql.planner.SymbolsExtractor.extractAll; + +/** + * Provides helper methods to push down dereferences in the query plan. 
+ */ +class DereferencePushdown +{ + private DereferencePushdown() {} + + public static Set extractDereferences(Collection expressions, boolean allowOverlap) + { + Set symbolReferencesAndDereferences = expressions.stream() + .flatMap(expression -> getSymbolReferencesAndDereferences(expression).stream()) + .collect(Collectors.toSet()); + + // Remove overlap if required + Set candidateExpressions = symbolReferencesAndDereferences; + if (!allowOverlap) { + candidateExpressions = symbolReferencesAndDereferences.stream() + .filter(expression -> !prefixExists(expression, symbolReferencesAndDereferences)) + .collect(Collectors.toSet()); + } + + // Retain dereference expressions + return candidateExpressions.stream() + .filter(DereferenceExpression.class::isInstance) + .map(DereferenceExpression.class::cast) + .collect(Collectors.toSet()); + } + + public static boolean exclusiveDereferences(Set projections) + { + return projections.stream() + .allMatch(expression -> expression instanceof SymbolReference || + (expression instanceof DereferenceExpression && + isDereferenceChain((DereferenceExpression) expression) && + !prefixExists(expression, projections))); + } + + public static Symbol getBase(DereferenceExpression expression) + { + return getOnlyElement(extractAll(expression)); + } + + /** + * Extract the sub-expressions of type {@link DereferenceExpression} or {@link SymbolReference} from the {@param expression} + * in a top-down manner. The expressions within the base of a valid {@link DereferenceExpression} sequence are not extracted. 
+ */ + private static List getSymbolReferencesAndDereferences(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + + new DefaultExpressionTraversalVisitor>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder context) + { + if (isDereferenceChain(node)) { + context.add(node); + } + return null; + } + + @Override + protected Void visitSymbolReference(SymbolReference node, ImmutableList.Builder context) + { + context.add(node); + return null; + } + + @Override + protected Void visitLambdaExpression(LambdaExpression node, ImmutableList.Builder context) + { + return null; + } + }.process(expression, builder); + + return builder.build(); + } + + private static boolean isDereferenceChain(DereferenceExpression expression) + { + return (expression.getBase() instanceof SymbolReference) || + ((expression.getBase() instanceof DereferenceExpression) && isDereferenceChain((DereferenceExpression) (expression.getBase()))); + } + + private static boolean prefixExists(Expression expression, Set expressions) + { + Expression current = expression; + while (current instanceof DereferenceExpression) { + current = ((DereferenceExpression) current).getBase(); + if (expressions.contains(current)) { + return true; + } + } + + verify(current instanceof SymbolReference); + return false; + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java new file mode 100644 index 000000000000..abe1af134de2 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/ExtractDereferencesFromFilterAboveScan.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.TableScanNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.plan.Patterns.filter; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.tableScan; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *      Filter(f1(A.x.y) = 1 AND f2(B.m) = 2 AND f3(A.x) = 6)
+ *          Source(A, B, C)
+ *  
+ * to: + *
+ *  Project(A, B, C)
+ *      Filter(f1(D) = 1 AND f2(E) = 2 AND f3(G) = 6)
+ *          Project(A, B, C, D := A.x.y, E := B.m, G := A.x)
+ *              Source(A, B, C)
+ * 
+ * + * This optimizer extracts all dereference expressions from a filter node located above a table scan into a ProjectNode. + * + * Extracting dereferences from a filter (e.g. FilterNode(a.x = 5)) can be suboptimal if full columns are being accessed up the + * plan tree (e.g. a), because it can result in replicated shuffling of fields (e.g. a.x). So it is safer to push down dereferences from + * Filter only when there's an explicit projection on top of the filter node (see PushDownDereferenceThroughFilter). + * + * In case of a FilterNode on top of TableScanNode, we want to push all dereferences into a new ProjectNode below, so that + * PushProjectionIntoTableScan optimizer can push those columns in the connector, and provide new column handles for the + * projected subcolumns. PushPredicateIntoTableScan optimizer can then push predicates on these subcolumns into the connector. + */ +public class ExtractDereferencesFromFilterAboveScan + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public ExtractDereferencesFromFilterAboveScan(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return filter() + .with(source().matching(tableScan().capturedAs(CHILD))); + } + + @Override + public Result apply(FilterNode node, Captures captures, Context context) + { + Set dereferences = extractDereferences(ImmutableList.of(node.getPredicate()), true); + if (dereferences.isEmpty()) { + return Result.empty(); + } + + Assignments assignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + Map mappings = HashBiMap.create(assignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + + PlanNode source = node.getSource(); + return Result.ofPlanNode(new ProjectNode( + 
context.getIdAllocator().getNextId(), + new FilterNode( + context.getIdAllocator().getNextId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + source, + Assignments.builder() + .putIdentities(source.getOutputSymbols()) + .putAll(assignments) + .build()), + replaceExpression(node.getPredicate(), mappings)), + Assignments.identity(node.getOutputSymbols()))); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index e8afb662dd34..1ae84c6b4262 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -24,6 +24,7 @@ import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Literal; import io.prestosql.sql.tree.TryExpression; @@ -172,6 +173,7 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. 
Otherwise, inlining might change semantics .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever + .filter(entry -> !(child.getAssignments().get(entry.getKey()) instanceof DereferenceExpression)) // skip dereferences, otherwise, inlining can cause conflicts with PushdownDereferences .map(Map.Entry::getKey) .collect(toSet()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java new file mode 100644 index 000000000000..c48ee522084d --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughFilter.java @@ -0,0 +1,124 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.plan.Patterns.filter; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *  Project(D := f1(A.x), E := f2(B), G := f3(C))
+ *      Filter(A.x.y = 5 AND B.m = 3)
+ *          Source(A, B, C)
+ *  
+ * to: + *
+ *  Project(D := f1(expr), E := f2(B), G := f3(C))
+ *      Filter(expr.y = 5 AND B.m = 3)
+ *          Project(A, B, C, expr := A.x)
+ *              Source(A, B, C)
+ * 
+ * + * Pushes down dereference projections in project node assignments and filter node predicate. + */ +public class PushDownDereferenceThroughFilter + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferenceThroughFilter(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(filter().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Rule.Context context) + { + FilterNode filterNode = captures.get(CHILD); + + // Consider the union of expressions from the project assignments and the filter predicate + List expressions = ImmutableList.builder() + .addAll(node.getAssignments().getExpressions()) + .add(filterNode.getPredicate()) + .build(); + + // Extract non-overlapping dereferences from the project assignments and the filter predicate for pushdown + Set dereferences = extractDereferences(expressions, false); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments assignments = node.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + PlanNode source = filterNode.getSource(); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new FilterNode( + context.getIdAllocator().getNextId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + source, + Assignments.builder() + 
.putIdentities(source.getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()), + replaceExpression(filterNode.getPredicate(), mappings)), + assignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java new file mode 100644 index 000000000000..dc8c08492d59 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughJoin.java @@ -0,0 +1,201 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.SymbolsExtractor.extractAll; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.join; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toList; + +/** + * Transforms: + *
+ *  Project(A_X := f1(A.x), G := f2(A_Y.z), E := f3(B))
+ *    Join(A_Y = C_Y) => [A, B]
+ *      Project(A_Y := A.y, A, B)
+ *          Source(A, B)
+ *      Project(C_Y := C.y)
+ *          Source(C, D)
+ *  
+ * to: + *
+ *  Project(A_X := f1(symbol), G := f2(A_Y.z), E := f3(B))
+ *    Join(A_Y = C_Y) => [symbol, A_Y, B]
+ *      Project(symbol := A.x, A_Y := A.y, A, B)
+ *        Source(A, B)
+ *      Project(C_Y := C.y)
+ *        Source(C, D)
+ * 
+ * + * Pushes down dereference projections through JoinNode. Excludes dereferences on symbols being used in join criteria to avoid + * data replication, since these symbols cannot be pruned. + */ +public class PushDownDereferenceThroughJoin + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferenceThroughJoin(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(join().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + JoinNode joinNode = captures.get(CHILD); + + // Consider dereferences in projections and join filter for pushdown + ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); + expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); + joinNode.getFilter().ifPresent(expressionsBuilder::add); + Set dereferences = extractDereferences(expressionsBuilder.build(), false); + + // Exclude criteria symbols + ImmutableSet.Builder criteriaSymbolsBuilder = ImmutableSet.builder(); + joinNode.getCriteria().forEach(criteria -> { + criteriaSymbolsBuilder.add(criteria.getLeft()); + criteriaSymbolsBuilder.add(criteria.getRight()); + }); + Set excludeSymbols = criteriaSymbolsBuilder.build(); + + dereferences = dereferences.stream() + .filter(expression -> !excludeSymbols.contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + 
.inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + Assignments.Builder leftAssignmentsBuilder = Assignments.builder(); + Assignments.Builder rightAssignmentsBuilder = Assignments.builder(); + + // Separate dereferences coming from left and right nodes + dereferenceAssignments.entrySet().stream() + .forEach(entry -> { + Symbol baseSymbol = getOnlyElement(extractAll(entry.getValue())); + if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) { + leftAssignmentsBuilder.put(entry.getKey(), entry.getValue()); + } + else if (joinNode.getRight().getOutputSymbols().contains(baseSymbol)) { + rightAssignmentsBuilder.put(entry.getKey(), entry.getValue()); + } + else { + throw new IllegalArgumentException(format("Unexpected symbol %s in projectNode", baseSymbol)); + } + }); + + Assignments leftAssignments = leftAssignmentsBuilder.build(); + Assignments rightAssignments = rightAssignmentsBuilder.build(); + + PlanNode leftNode = createProjectNodeIfRequired(joinNode.getLeft(), leftAssignments, context.getIdAllocator()); + PlanNode rightNode = createProjectNodeIfRequired(joinNode.getRight(), rightAssignments, context.getIdAllocator()); + + // Prepare new output symbols for join node + List referredSymbolsInAssignments = newAssignments.getExpressions().stream() + .flatMap(expression -> extractAll(expression).stream()) + .collect(toList()); + + List newLeftOutputSymbols = referredSymbolsInAssignments.stream() + .filter(symbol -> leftNode.getOutputSymbols().contains(symbol)) + .collect(toList()); + + List newRightOutputSymbols = referredSymbolsInAssignments.stream() + .filter(symbol -> rightNode.getOutputSymbols().contains(symbol)) + .collect(toList()); + + JoinNode newJoinNode = new JoinNode( + context.getIdAllocator().getNextId(), + joinNode.getType(), + leftNode, + 
rightNode, + joinNode.getCriteria(), + newLeftOutputSymbols, + newRightOutputSymbols, + // Use newly created symbols in filter + joinNode.getFilter().map(expression -> replaceExpression(expression, mappings)), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType(), + joinNode.isSpillable(), + joinNode.getDynamicFilters(), + joinNode.getReorderJoinStatsAndCost()); + + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, newAssignments)); + } + + private static PlanNode createProjectNodeIfRequired(PlanNode planNode, Assignments dereferences, PlanNodeIdAllocator idAllocator) + { + if (dereferences.isEmpty()) { + return planNode; + } + return new ProjectNode( + idAllocator.getNextId(), + planNode, + Assignments.builder() + .putIdentities(planNode.getOutputSymbols()) + .putAll(dereferences) + .build()); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java new file mode 100644 index 000000000000..ad4e6bf302da --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughProject.java @@ -0,0 +1,110 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *  Project(c := f(a.x), d := g(b))
+ *    Project(a, b)
+ *  
+ * to: + *
+ *  Project(c := f(symbol), d := g(b))
+ *    Project(a, b, symbol := a.x)
+ * 
+ */ +public class PushDownDereferenceThroughProject + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferenceThroughProject(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(project().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + ProjectNode child = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set dereferences = extractDereferences(node.getAssignments().getExpressions(), false); + + // Exclude dereferences on symbols being synthesized within child + dereferences = dereferences.stream() + .filter(expression -> child.getSource().getOutputSymbols().contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments assignments = node.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + child.getSource(), + Assignments.builder() + .putAll(child.getAssignments()) + .putAll(dereferenceAssignments) + .build()), + assignments)); + } +} diff --git 
a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java new file mode 100644 index 000000000000..58a4182507e2 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughSemiJoin.java @@ -0,0 +1,123 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SemiJoinNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static 
io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.semiJoin; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *  Project(D := f1(A.x), E := f2(B.x), G := f3(C))
+ *      SemiJoin(sourceJoinSymbol = B, filteringSourceJoinSymbol = B_filtering)
+ *          Source(A, B, C)
+ *          FilteringSource(B_filtering)
+ *  
+ * to: + *
+ *  Project(D := f1(symbol), E := f2(B.x), G := f3(C))
+ *      SemiJoin(sourceJoinSymbol = B, filteringSourceJoinSymbol = B_filtering)
+ *          Project(A, B, C, symbol := A.x)
+ *              Source(A, B, C)
+ *          FilteringSource(B_filtering)
+ * 
+ * + * Pushes down dereference projections through SemiJoinNode. Excludes dereferences on sourceJoinSymbol to avoid + * data replication, since this symbol cannot be pruned. + */ +public class PushDownDereferenceThroughSemiJoin + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferenceThroughSemiJoin(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(semiJoin().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + SemiJoinNode semiJoinNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // All dereferences can be assumed on the symbols coming from source, since filteringSource output is not propagated, + // and semiJoinOutput is of type boolean. We exclude pushdown of dereferences on sourceJoinSymbol. 
+ dereferences = dereferences.stream() + .filter(expression -> !getBase(expression).equals(semiJoinNode.getSourceJoinSymbol())) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments assignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + PlanNode newSource = new ProjectNode( + context.getIdAllocator().getNextId(), + semiJoinNode.getSource(), + Assignments.builder() + .putIdentities(semiJoinNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()); + + PlanNode newSemiJoin = semiJoinNode.replaceChildren(ImmutableList.of(newSource, semiJoinNode.getFilteringSource())); + + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newSemiJoin, assignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java new file mode 100644 index 000000000000..d7c10c4a2a90 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferenceThroughUnnest.java @@ -0,0 +1,138 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.UnnestNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.unnest; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *  Project(D := f1(A.x), E := f2(C), B_BIGINT)
+ *      Unnest(replicate = [A, C], unnest = (B_ARRAY -> [B_BIGINT]))
+ *          Source(A, B_ARRAY, C)
+ *  
+ * to: + *
+ *  Project(D := f1(symbol), E := f2(C), B_BIGINT)
+ *      Unnest(replicate = [A, C, symbol], unnest = (B_ARRAY -> [B_BIGINT]))
+ *          Project(A, B_ARRAY, C, symbol := A.x)
+ *              Source(A, B_ARRAY, C)
+ * 
+ * + * Pushes down dereference projections through Unnest. Currently, the pushdown is only supported for dereferences on replicate symbols. + */ +public class PushDownDereferenceThroughUnnest + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferenceThroughUnnest(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(unnest().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + UnnestNode unnestNode = captures.get(CHILD); + + // Extract dereferences from project node's assignments and unnest node's filter + ImmutableList.Builder expressionsBuilder = ImmutableList.builder(); + expressionsBuilder.addAll(projectNode.getAssignments().getExpressions()); + unnestNode.getFilter().ifPresent(expressionsBuilder::add); + + // Extract dereferences for pushdown + Set dereferences = extractDereferences(expressionsBuilder.build(), false); + + // Only retain dereferences on replicate symbols + dereferences = dereferences.stream() + .filter(expression -> unnestNode.getReplicateSymbols().contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> 
replaceExpression(expression, mappings)); + + // Create a new ProjectNode (above the original source) adding dereference projections on replicated symbols + ProjectNode source = new ProjectNode( + context.getIdAllocator().getNextId(), + unnestNode.getSource(), + Assignments.builder() + .putIdentities(unnestNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()); + + // Create projectNode with the new unnest node and assignments with replaced dereferences + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new UnnestNode( + context.getIdAllocator().getNextId(), + source, + ImmutableList.builder() + .addAll(unnestNode.getReplicateSymbols()) + .addAll(dereferenceAssignments.getSymbols()) + .build(), + unnestNode.getMappings(), + unnestNode.getOrdinalitySymbol(), + unnestNode.getJoinType(), + unnestNode.getFilter().map(filter -> replaceExpression(filter, mappings))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java new file mode 100644 index 000000000000..f2537ad607cc --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughAssignUniqueId.java @@ -0,0 +1,112 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.AssignUniqueId; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.plan.Patterns.assignUniqueId; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *  Project(C := f1(A.x), D := f2(B))
+ *      AssignUniqueId
+ *          Source(A, B)
+ *  
+ * to: + *
+ *  Project(C := f1(symbol), D := f2(B))
+ *      AssignUniqueId
+ *          Project(A, B, symbol := A.x)
+ *              Source(A, B)
+ * 
+ */ +public class PushDownDereferencesThroughAssignUniqueId + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughAssignUniqueId(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(assignUniqueId().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + AssignUniqueId assignUniqueId = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // We do not need to filter dereferences on idColumn symbol since it is supposed to be of BIGINT type. + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + assignUniqueId.replaceChildren(ImmutableList.of( + new ProjectNode( + context.getIdAllocator().getNextId(), + assignUniqueId.getSource(), + Assignments.builder() + .putIdentities(assignUniqueId.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git 
a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java new file mode 100644 index 000000000000..c96bb5a10e26 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughLimit.java @@ -0,0 +1,119 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.LimitNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static 
io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.limit; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + *
+ *  Project(D := f1(A.x), E := f2(B.x), G := f3(C))
+ *      Limit(5, tiesResolvingScheme = [B])
+ *          Source(A, B, C)
+ *  
+ * to: + *
+ *  Project(D := f1(symbol), E := f2(B.x), G := f3(C))
+ *      Limit(5, tiesResolvingScheme = [B])
+ *          Project(symbol := A.x, A, B, C)
+ *              Source(A, B, C)
+ * 
+ */ +public class PushDownDereferencesThroughLimit + implements Rule +{ + private static final Capture CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughLimit(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project() + .with(source().matching(limit().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + LimitNode limitNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // Exclude dereferences on symbols being used in tiesResolvingScheme + if (limitNode.getTiesResolvingScheme().isPresent()) { + dereferences = dereferences.stream() + .filter(expression -> !limitNode.getTiesResolvingScheme().get().getOrderBy().contains(getBase(expression))) + .collect(toImmutableSet()); + } + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + limitNode.replaceChildren(ImmutableList.of( + new ProjectNode( + context.getIdAllocator().getNextId(), + limitNode.getSource(), + Assignments.builder() + 
.putIdentities(limitNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java new file mode 100644 index 000000000000..397ac951a6e5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughMarkDistinct.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.MarkDistinctNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.markDistinct; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + * <pre>
+ *  Project(D := f1(A.x), E := f2(B.x), G := f3(C))
+ *      MarkDistinct(distinctSymbols = [B])
+ *          Source(A, B, C)
+ *  </pre>
+ * to: + * <pre>
+ *  Project(D := f1(symbol), E := f2(B.x), G := f3(C))
+ *      MarkDistinct(distinctSymbols = [B])
+ *          Project(A, B, C, symbol := A.x)
+ *              Source(A, B, C)
+ * </pre>
+ * + * Pushes down dereference projections through MarkDistinct. Excludes dereferences on "distinct symbols" to avoid data + * replication, since these symbols cannot be pruned. + */ +public class PushDownDereferencesThroughMarkDistinct + implements Rule<ProjectNode> +{ + private static final Capture<MarkDistinctNode> CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughMarkDistinct(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern<ProjectNode> getPattern() + { + return project() + .with(source().matching(markDistinct().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + MarkDistinctNode markDistinctNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set<DereferenceExpression> dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // Exclude dereferences on distinct symbols being used in markDistinctNode. We do not need to filter + // dereferences on markerSymbol since it is supposed to be of boolean type. 
+ dereferences = dereferences.stream() + .filter(expression -> !markDistinctNode.getDistinctSymbols().contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + markDistinctNode.replaceChildren(ImmutableList.of( + new ProjectNode( + context.getIdAllocator().getNextId(), + markDistinctNode.getSource(), + Assignments.builder() + .putIdentities(markDistinctNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java new file mode 100644 index 000000000000..ed68ad9279a5 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughRowNumber.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.rowNumber; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + * <pre>
+ *  Project(D := f1(A.x), E := f2(B.x), G := f3(C))
+ *      RowNumber(partitionBy = [B])
+ *          Source(A, B, C)
+ *  </pre>
+ * to: + * <pre>
+ *  Project(D := f1(symbol), E := f2(B.x), G := f3(C))
+ *      RowNumber(partitionBy = [B])
+ *          Project(A, B, C, symbol := A.x)
+ *              Source(A, B, C)
+ * </pre>
+ * + * Pushes down dereference projections through RowNumber. Excludes dereferences on symbols in partitionBy to avoid data + * replication, since these symbols cannot be pruned. + */ +public class PushDownDereferencesThroughRowNumber + implements Rule<ProjectNode> +{ + private static final Capture<RowNumberNode> CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughRowNumber(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern<ProjectNode> getPattern() + { + return project() + .with(source().matching(rowNumber().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + RowNumberNode rowNumberNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set<DereferenceExpression> dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // Exclude dereferences on symbols being used in partitionBy + dereferences = dereferences.stream() + .filter(expression -> !rowNumberNode.getPartitionBy().contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + rowNumberNode.replaceChildren(ImmutableList.of( + new
ProjectNode( + context.getIdAllocator().getNextId(), + rowNumberNode.getSource(), + Assignments.builder() + .putIdentities(rowNumberNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java new file mode 100644 index 000000000000..891151043d6d --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughSort.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.sort; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + * <pre>
+ *  Project(D := f1(A.x), E := f2(B.x), G := f3(C))
+ *      Sort(orderBy = [B])
+ *          Source(A, B, C)
+ *  </pre>
+ * to: + * <pre>
+ *  Project(D := f1(symbol), E := f2(B.x), G := f3(C))
+ *      Sort(orderBy = [B])
+ *          Project(A, B, C, symbol := A.x)
+ *              Source(A, B, C)
+ * </pre>
+ * + * Pushes down dereference projections through Sort. Excludes dereferences on symbols in ordering scheme to avoid data + * replication, since these symbols cannot be pruned. + */ +public class PushDownDereferencesThroughSort + implements Rule<ProjectNode> +{ + private static final Capture<SortNode> CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughSort(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern<ProjectNode> getPattern() + { + return project() + .with(source().matching(sort().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + SortNode sortNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set<DereferenceExpression> dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // Exclude dereferences on symbols used in ordering scheme to avoid replication of data + dereferences = dereferences.stream() + .filter(expression -> !sortNode.getOrderingScheme().getOrderBy().contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + sortNode.replaceChildren(ImmutableList.of( + new
ProjectNode( + context.getIdAllocator().getNextId(), + sortNode.getSource(), + Assignments.builder() + .putIdentities(sortNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java new file mode 100644 index 000000000000..635be37ecd87 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughTopN.java @@ -0,0 +1,120 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.TopNNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.topN; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + * <pre>
+ *  Project(D := f1(A.x), E := f2(B.x), G := f3(C))
+ *      TopN(orderBy = [B])
+ *          Source(A, B, C)
+ *  </pre>
+ * to: + * <pre>
+ *  Project(D := f1(symbol), E := f2(B.x), G := f3(C))
+ *      TopN(orderBy = [B])
+ *          Project(A, B, C, symbol := A.x)
+ *              Source(A, B, C)
+ * </pre>
+ * + * Pushes down dereference projections through TopN. Excludes dereferences on symbols in ordering scheme to avoid data + * replication, since these symbols cannot be pruned. + */ +public class PushDownDereferencesThroughTopN + implements Rule<ProjectNode> +{ + private static final Capture<TopNNode> CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughTopN(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern<ProjectNode> getPattern() + { + return project() + .with(source().matching(topN().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + TopNNode topNNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set<DereferenceExpression> dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // Exclude dereferences on symbols being used in orderBy + dereferences = dereferences.stream() + .filter(expression -> !topNNode.getOrderingScheme().getOrderBy().contains(getBase(expression))) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + topNNode.replaceChildren(ImmutableList.of( + new ProjectNode( +
context.getIdAllocator().getNextId(), + topNNode.getSource(), + Assignments.builder() + .putIdentities(topNNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRowNumber.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRowNumber.java new file mode 100644 index 000000000000..6caec59fdd64 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughTopNRowNumber.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.TopNRowNumberNode; +import io.prestosql.sql.planner.plan.WindowNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.topNRowNumber; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + * <pre>
+ *  Project(E := f1(A.x), G := f2(B.x), H := f3(C.x), J := f4(D))
+ *      TopNRowNumber(partitionBy = [B], orderBy = [C])
+ *          Source(A, B, C, D)
+ *  </pre>
+ * to: + * <pre>
+ *  Project(E := f1(symbol), G := f2(B.x), H := f3(C.x), J := f4(D))
+ *      TopNRowNumber(partitionBy = [B], orderBy = [C])
+ *          Project(A, B, C, D, symbol := A.x)
+ *              Source(A, B, C, D)
+ * </pre>
+ * + * Pushes down dereference projections through TopNRowNumber. Excludes dereferences on symbols in partitionBy and ordering scheme + * to avoid data replication, since these symbols cannot be pruned. + */ +public class PushDownDereferencesThroughTopNRowNumber + implements Rule<ProjectNode> +{ + private static final Capture<TopNRowNumberNode> CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughTopNRowNumber(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern<ProjectNode> getPattern() + { + return project() + .with(source().matching(topNRowNumber().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + TopNRowNumberNode topNRowNumberNode = captures.get(CHILD); + + // Extract dereferences from project node assignments for pushdown + Set<DereferenceExpression> dereferences = extractDereferences(projectNode.getAssignments().getExpressions(), false); + + // Exclude dereferences on symbols being used in partitionBy and orderBy + WindowNode.Specification specification = topNRowNumberNode.getSpecification(); + dereferences = dereferences.stream() + .filter(expression -> { + Symbol symbol = getBase(expression); + return !specification.getPartitionBy().contains(symbol) + && !specification.getOrderingScheme().map(OrderingScheme::getOrderBy).orElse(ImmutableList.of()).contains(symbol); + }) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments using new symbols for dereference expressions + Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry ->
entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + topNRowNumberNode.replaceChildren(ImmutableList.of( + new ProjectNode( + context.getIdAllocator().getNextId(), + topNRowNumberNode.getSource(), + Assignments.builder() + .putIdentities(topNRowNumberNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()))), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java new file mode 100644 index 000000000000..2c162fad703a --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferencesThroughWindow.java @@ -0,0 +1,156 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.WindowNode; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Map; +import java.util.Set; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.ExpressionNodeInliner.replaceExpression; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.getBase; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.window; +import static java.util.Objects.requireNonNull; + +/** + * Transforms: + * <pre>
+ *  Project(G := f1(A.x), H := f2(B.x), J := f3(C.x), K := f4(D.x), L := f5(F))
+ *      Window(orderBy = [B], partitionBy = [C], min_D := min(D))
+ *          Source(A, B, C, D, E, F)
+ *  </pre>
+ * to: + * <pre>
+ *  Project(G := f1(symbol), H := f2(B.x), J := f3(C.x), K := f4(D.x), L := f5(F))
+ *      Window(orderBy = [B], partitionBy = [C], min_D := min(D))
+ *          Project(A, B, C, D, E, F, symbol := A.x)
+ *              Source(A, B, C, D, E, F)
+ * </pre>
+ * + * Pushes down dereference projections through Window. Excludes dereferences on symbols in ordering scheme and partitionBy + * to avoid data replication, since these symbols cannot be pruned. + */ +public class PushDownDereferencesThroughWindow + implements Rule<ProjectNode> +{ + private static final Capture<WindowNode> CHILD = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferencesThroughWindow(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern<ProjectNode> getPattern() + { + return project() + .with(source().matching(window().capturedAs(CHILD))); + } + + @Override + public Result apply(ProjectNode projectNode, Captures captures, Context context) + { + WindowNode windowNode = captures.get(CHILD); + + // Extract dereferences for pushdown + Set<DereferenceExpression> dereferences = extractDereferences( + ImmutableList.<Expression>builder() + .addAll(projectNode.getAssignments().getExpressions()) + // also include dereference projections used in window functions + .addAll(windowNode.getWindowFunctions().values().stream() + .flatMap(function -> function.getArguments().stream()) + .collect(toImmutableList())) + .build(), + false); + + WindowNode.Specification specification = windowNode.getSpecification(); + dereferences = dereferences.stream() + .filter(expression -> { + Symbol symbol = getBase(expression); + // Exclude partitionBy, orderBy and synthesized symbols + return !specification.getPartitionBy().contains(symbol) && + !specification.getOrderingScheme().map(OrderingScheme::getOrderBy).orElse(ImmutableList.of()).contains(symbol) && + !windowNode.getCreatedSymbols().contains(symbol); + }) + .collect(toImmutableSet()); + + if (dereferences.isEmpty()) { + return Result.empty(); + } + + // Create new symbols for dereference expressions + Assignments dereferenceAssignments = Assignments.of(dereferences, context.getSession(), context.getSymbolAllocator(), typeAnalyzer); + + // Rewrite project node assignments
using new symbols for dereference expressions + Map<Expression, SymbolReference> mappings = HashBiMap.create(dereferenceAssignments.getMap()) + .inverse() + .entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> entry.getValue().toSymbolReference())); + Assignments newAssignments = projectNode.getAssignments().rewrite(expression -> replaceExpression(expression, mappings)); + + return Result.ofPlanNode( + new ProjectNode( + context.getIdAllocator().getNextId(), + new WindowNode( + windowNode.getId(), + new ProjectNode( + context.getIdAllocator().getNextId(), + windowNode.getSource(), + Assignments.builder() + .putIdentities(windowNode.getSource().getOutputSymbols()) + .putAll(dereferenceAssignments) + .build()), + windowNode.getSpecification(), + // Replace dereference expressions in functions + windowNode.getWindowFunctions().entrySet().stream() + .collect(toImmutableMap( + Map.Entry::getKey, + entry -> { + WindowNode.Function oldFunction = entry.getValue(); + return new WindowNode.Function( + oldFunction.getResolvedFunction(), + oldFunction.getArguments().stream() + .map(expression -> replaceExpression(expression, mappings)) + .collect(toImmutableList()), + oldFunction.getFrame(), + oldFunction.isIgnoreNulls()); + })), + windowNode.getHashSymbol(), + windowNode.getPrePartitionedInputs(), + windowNode.getPreSortedOrderPrefix()), + newAssignments)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughProject.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughProject.java index aabbae431d8b..177676b8b23b 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughProject.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushLimitThroughProject.java @@ -14,6 +14,7 @@ package io.prestosql.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import
io.prestosql.matching.Capture; import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; @@ -25,7 +26,11 @@ import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.SymbolReference; +import java.util.Set; + import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.exclusiveDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; import static io.prestosql.sql.planner.iterative.rule.Util.transpose; import static io.prestosql.sql.planner.plan.Patterns.limit; import static io.prestosql.sql.planner.plan.Patterns.project; @@ -54,6 +59,14 @@ public Result apply(LimitNode parent, Captures captures, Context context) { ProjectNode projectNode = captures.get(CHILD); + // Do not push down if the projection is made up of symbol references and exclusive dereferences. This prevents + // undoing of PushDownDereferencesThroughLimit. We still push limit in the case of overlapping dereferences since + // it enables PushDownDereferencesThroughLimit rule to push optimal dereferences. 
+ Set projections = ImmutableSet.copyOf(projectNode.getAssignments().getExpressions()); + if (!extractDereferences(projections, false).isEmpty() && exclusiveDereferences(projections)) { + return Result.empty(); + } + // for a LimitNode without ties, simply reorder the nodes if (!parent.isWithTies()) { return Result.ofPlanNode(transpose(parent, projectNode)); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNThroughProject.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNThroughProject.java index 6a545b707197..b57ea850f821 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNThroughProject.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushTopNThroughProject.java @@ -14,6 +14,7 @@ package io.prestosql.sql.planner.iterative.rule; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import io.prestosql.matching.Capture; import io.prestosql.matching.Captures; import io.prestosql.matching.Pattern; @@ -31,8 +32,11 @@ import java.util.List; import java.util.Optional; +import java.util.Set; import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.exclusiveDereferences; +import static io.prestosql.sql.planner.iterative.rule.DereferencePushdown.extractDereferences; import static io.prestosql.sql.planner.plan.Patterns.project; import static io.prestosql.sql.planner.plan.Patterns.source; import static io.prestosql.sql.planner.plan.Patterns.topN; @@ -77,6 +81,14 @@ public Result apply(TopNNode parent, Captures captures, Context context) { ProjectNode projectNode = captures.get(PROJECT_CHILD); + // Do not push down if the projection is made up of symbol references and exclusive dereferences. This prevents + // undoing of PushDownDereferencesThroughTopN. 
We still push topN in the case of overlapping dereferences since + // it enables PushDownDereferencesThroughTopN rule to push optimal dereferences. + Set projections = ImmutableSet.copyOf(projectNode.getAssignments().getExpressions()); + if (!extractDereferences(projections, false).isEmpty() && exclusiveDereferences(projections)) { + return Result.empty(); + } + // do not push topN between projection and filter(table scan) so that they can be merged into a PageProcessor PlanNode projectSource = context.getLookup().resolve(projectNode.getSource()); if (projectSource instanceof FilterNode) { diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java index 2e6b407fc258..726ed56deae0 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java @@ -19,7 +19,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; +import io.prestosql.Session; +import io.prestosql.spi.type.Type; import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.SymbolAllocator; +import io.prestosql.sql.planner.TypeAnalyzer; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; @@ -78,6 +82,18 @@ public static Assignments of(Symbol symbol1, Expression expression1, Symbol symb return builder().put(symbol1, expression1).put(symbol2, expression2).build(); } + public static Assignments of(Collection expressions, Session session, SymbolAllocator symbolAllocator, TypeAnalyzer typeAnalyzer) + { + Assignments.Builder assignments = Assignments.builder(); + + for (Expression expression : expressions) { + Type type = typeAnalyzer.getType(session, symbolAllocator.getTypes(), expression); + 
assignments.put(symbolAllocator.newSymbol(expression, type), expression); + } + + return assignments.build(); + } + private final Map assignments; @JsonCreator @@ -217,9 +233,9 @@ public Builder putAll(Assignments assignments) return putAll(assignments.getMap()); } - public Builder putAll(Map assignments) + public Builder putAll(Map assignments) { - for (Entry assignment : assignments.entrySet()) { + for (Entry assignment : assignments.entrySet()) { put(assignment.getKey(), assignment.getValue()); } return this; diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/WindowNode.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/WindowNode.java index e5c382243526..b3df880989ef 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/WindowNode.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/WindowNode.java @@ -336,9 +336,9 @@ public Function( @JsonProperty("frame") Frame frame, @JsonProperty("ignoreNulls") boolean ignoreNulls) { - this.resolvedFunction = requireNonNull(resolvedFunction, "Signature is null"); + this.resolvedFunction = requireNonNull(resolvedFunction, "resolvedFunction is null"); this.arguments = requireNonNull(arguments, "arguments is null"); - this.frame = requireNonNull(frame, "Frame is null"); + this.frame = requireNonNull(frame, "frame is null"); this.ignoreNulls = ignoreNulls; } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java new file mode 100644 index 000000000000..7edfd7a40446 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java @@ -0,0 +1,253 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.unnest; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; + +public class TestDereferencePushDown + extends BasePlanTest +{ + @Test + public void testDereferencePushdownMultiLevel() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)))) " + + "SELECT a.msg.x, a.msg, 
b.msg.y FROM t a CROSS JOIN t b", + output(ImmutableList.of("a_msg_x", "a_msg", "b_msg_y"), + strictProject( + ImmutableMap.of( + "a_msg_x", PlanMatchPattern.expression("a_msg.x"), + "a_msg", PlanMatchPattern.expression("a_msg"), + "b_msg_y", PlanMatchPattern.expression("b_msg_y")), + join(INNER, ImmutableList.of(), + values("a_msg"), + strictProject( + ImmutableMap.of("b_msg_y", PlanMatchPattern.expression("b_msg.y")), + values("b_msg")))))); + } + + @Test + public void testDereferencePushdownJoin() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a, t b " + + "WHERE a.msg.y = b.msg.y", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y")), + values("msg"))), + anyTree( + strictProject(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT a.msg.y " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x > bigint '5'", + output(ImmutableList.of("a_y"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg")))), + anyTree( + strictProject(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x + b.msg.x < BIGINT '10'", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), Optional.of("a_x + b_x < bigint '10'"), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg"))), + anyTree( + strictProject(ImmutableMap.of("b_y", 
expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownFilter() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT a.msg.y, b.msg.x " + + "FROM t a CROSS JOIN t b " + + "WHERE a.msg.x = 7 OR IS_FINITE(b.msg.y)", + anyTree( + join(INNER, ImmutableList.of(), + strictProject(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg")), + strictProject(ImmutableMap.of("b_x", expression("msg.x"), "b_y", expression("msg.y")), + values("msg"))))); + } + + @Test + public void testDereferencePushdownWindow() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT msg.x AS x, ROW_NUMBER() OVER (PARTITION BY msg.y) AS rn " + + "FROM t ", + anyTree( + strictProject(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg")))); + + assertPlan( + "WITH t(msg1, msg2, msg3, msg4, msg5) AS (VALUES " + + // Use two rows to avoid any optimizations around short-circuiting operations + "ROW(" + + " CAST(ROW(1, 0.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(2, 0.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(3, 0.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(4, 0.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(5, 0.0) AS ROW(x BIGINT, y DOUBLE)))," + + "ROW(" + + " CAST(ROW(1, 1.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(2, 2.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(3, 3.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(4, 4.0) AS ROW(x BIGINT, y DOUBLE))," + + " CAST(ROW(5, 5.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT " + + " msg1.x AS x1, " + + " msg2.x AS x2, " + + " msg3.x AS x3, " + + " msg4.x AS x4, " + + " msg5.x AS x5, " + + " MIN(msg3) OVER (PARTITION BY msg1 ORDER BY msg2) AS msg6," + + " MIN(msg4.x) OVER (PARTITION BY msg1 ORDER BY msg2) AS bigint_msg4 " + + "FROM t", + anyTree( + project( + ImmutableMap.of( + "msg1",
expression("msg1"), // not pushed down because used in partition by + "msg2", expression("msg2"), // not pushed down because used in order by + "msg3", expression("msg3"), // not pushed down because used in window function + "msg4_x", expression("msg4.x"), // pushed down because msg4.x used in window function + "msg5_x", expression("msg5.x")), // pushed down because window node does not refer to it + values("msg1", "msg2", "msg3", "msg4", "msg5")))); + } + + @Test + public void testDereferencePushdownSemiJoin() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0, 3) AS ROW(x BIGINT, y DOUBLE, z BIGINT)))) " + + "SELECT msg.y " + + "FROM t " + + "WHERE " + + "msg.x IN (SELECT msg.z FROM t)", + anyTree( + semiJoin("a_x", "b_z", "semi_join_symbol", + anyTree( + strictProject(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg"))), + anyTree( + strictProject(ImmutableMap.of("b_z", expression("msg.z")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownLimit() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))), ROW(CAST(ROW(3, 4.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT msg.x * 3 FROM t limit 1", + anyTree( + strictProject(ImmutableMap.of("x_into_3", expression("msg_x * BIGINT '3'")), + limit(1, + strictProject(ImmutableMap.of("msg_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a, t b " + + "WHERE a.msg.y = b.msg.y " + + "LIMIT 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y")), + values("msg"))), + anyTree( + strictProject(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT a.msg.y " + + "FROM t a JOIN
t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x > BIGINT '5' " + + "LIMIT 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg")))), + anyTree( + strictProject(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x + b.msg.x < BIGINT '10' " + + "LIMIT 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), Optional.of("a_x + b_x < bigint '10'"), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg"))), + anyTree( + strictProject(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownUnnest() + { + assertPlan("WITH t(msg, array) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)), ARRAY[1, 2, 3])) " + + "SELECT a.msg.x " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "CROSS JOIN UNNEST (a.array) " + + "WHERE a.msg.x + b.msg.x < BIGINT '10'", + output(ImmutableList.of("expr"), + strictProject(ImmutableMap.of("expr", expression("a_x")), + unnest( + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + strictProject(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x"), "a_z", expression("array")), + values("msg", "array"))), + anyTree( + strictProject(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java index 39463fbe61f6..d9c8fc65f149 100644 --- 
a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java @@ -43,6 +43,7 @@ import io.prestosql.sql.tree.SearchedCaseExpression; import io.prestosql.sql.tree.SimpleCaseExpression; import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SubscriptExpression; import io.prestosql.sql.tree.SymbolReference; import io.prestosql.sql.tree.TryExpression; import io.prestosql.sql.tree.WhenClause; @@ -545,6 +546,17 @@ && process(actual.getPattern(), expected.getPattern()) && process(actual.getEscape(), expected.getEscape()); } + @Override + protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SubscriptExpression)) { + return false; + } + + SubscriptExpression expected = (SubscriptExpression) expectedExpression; + return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + } + private boolean process(List actuals, List expecteds) { if (actuals.size() != expecteds.size()) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java index 22735c8ca66e..3bfa53c47b5a 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/PlanMatchPattern.java @@ -497,14 +497,14 @@ public static PlanMatchPattern unnest(PlanMatchPattern source) return node(UnnestNode.class, source); } - public static PlanMatchPattern unnest(List replicateSymbols, Map mappings, PlanMatchPattern source) + public static PlanMatchPattern unnest(List replicateSymbols, List mappings, PlanMatchPattern source) { return unnest(replicateSymbols, mappings, Optional.empty(), INNER, Optional.empty(), source); } public static 
PlanMatchPattern unnest( List replicateSymbols, - Map mappings, + List mappings, Optional ordinalitySymbol, Type type, Optional filter, @@ -513,15 +513,17 @@ public static PlanMatchPattern unnest( PlanMatchPattern result = node(UnnestNode.class, source) .with(new UnnestMatcher( replicateSymbols, - mappings.values().stream() - .map(UnnestedSymbolMatcher::getSymbol) - .collect(toImmutableList()), + mappings, ordinalitySymbol, type, filter.map(predicate -> rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(predicate, new ParsingOptions()))))); - for (Map.Entry mapping : mappings.entrySet()) { - result.withAlias(mapping.getKey(), mapping.getValue()); - } + + mappings.forEach(mapping -> { + for (int i = 0; i < mapping.getOutputs().size(); i++) { + result.withAlias(mapping.getOutputs().get(i), new UnnestedSymbolMatcher(mapping.getInput(), i)); + } + }); + ordinalitySymbol.ifPresent(symbol -> result.withAlias(symbol, new OrdinalitySymbolMatcher())); return result; @@ -1046,6 +1048,33 @@ public String toString() } } + public static class UnnestMapping + { + private final String input; + private final List outputs; + + private UnnestMapping(String input, List outputs) + { + this.input = requireNonNull(input, "input is null"); + this.outputs = requireNonNull(outputs, "outputs is null"); + } + + public static UnnestMapping unnestMapping(String input, List outputs) + { + return new UnnestMapping(input, outputs); + } + + public String getInput() + { + return input; + } + + public List getOutputs() + { + return outputs; + } + } + public static class Ordering { private final String field; diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java index a6eb4f114a04..3a2021a04d98 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java +++ 
b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestMatcher.java @@ -25,6 +25,7 @@ import java.util.List; import java.util.Optional; +import java.util.stream.IntStream; import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkState; @@ -36,15 +37,15 @@ final class UnnestMatcher implements Matcher { private final List replicateSymbols; - private final List unnestSymbols; + private final List unnestMappings; private final Optional ordinalitySymbol; private final JoinNode.Type type; private final Optional filter; - public UnnestMatcher(List replicateSymbols, List unnestSymbols, Optional ordinalitySymbol, JoinNode.Type type, Optional filter) + public UnnestMatcher(List replicateSymbols, List unnestMappings, Optional ordinalitySymbol, JoinNode.Type type, Optional filter) { this.replicateSymbols = requireNonNull(replicateSymbols, "replicateSymbols is null"); - this.unnestSymbols = requireNonNull(unnestSymbols, "mappings is null"); + this.unnestMappings = requireNonNull(unnestMappings, "unnestMappings is null"); this.ordinalitySymbol = requireNonNull(ordinalitySymbol, "ordinalitySymbol is null"); this.type = requireNonNull(type, "type is null"); this.filter = requireNonNull(filter, "filter is null"); @@ -78,16 +79,16 @@ public MatchResult detailMatches(PlanNode node, StatsProvider stats, Session ses return NO_MATCH; } - if (unnestNode.getMappings().size() != unnestSymbols.size()) { + if (unnestNode.getMappings().size() != unnestMappings.size()) { return NO_MATCH; } - if (!unnestSymbols.stream() - .map(symbolAliases::get) - .map(Symbol::from) - .collect(toImmutableList()) - .equals(unnestNode.getMappings().stream() - .map(Mapping::getInput) - .collect(toImmutableList()))) { + + if (!IntStream.range(0, unnestMappings.size()).boxed().allMatch(index -> { + Mapping nodeMapping = unnestNode.getMappings().get(index); + PlanMatchPattern.UnnestMapping patternMapping = unnestMappings.get(index); + 
return nodeMapping.getInput().toSymbolReference().equals(symbolAliases.get(patternMapping.getInput())) && + patternMapping.getOutputs().size() == nodeMapping.getOutputs().size(); + })) { return NO_MATCH; } @@ -119,7 +120,7 @@ public String toString() .omitNullValues() .add("type", type) .add("replicateSymbols", replicateSymbols) - .add("unnestSymbols", unnestSymbols) + .add("unnestMappings", unnestMappings) .add("ordinalitySymbol", ordinalitySymbol.orElse(null)) .add("filter", filter.orElse(null)) .toString(); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestedSymbolMatcher.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestedSymbolMatcher.java index 14e6c6b25d6d..2aa3b6a66184 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestedSymbolMatcher.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/UnnestedSymbolMatcher.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Optional; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.Iterables.getOnlyElement; @@ -32,10 +33,13 @@ public class UnnestedSymbolMatcher implements RvalueMatcher { private final String symbol; + private final int index; - public UnnestedSymbolMatcher(String symbol) + public UnnestedSymbolMatcher(String symbol, int index) { this.symbol = requireNonNull(symbol, "symbol is null"); + checkArgument(index >= 0, "index cannot be negative"); + this.index = index; } @Override @@ -57,9 +61,12 @@ public Optional getAssignedSymbol(PlanNode node, Session session, Metada } Mapping mapping = getOnlyElement(matches); - checkState(mapping.getOutputs().size() == 1, "alias matching not supported for multiple output symbols"); - return Optional.of(getOnlyElement(mapping.getOutputs())); + if (index >= 
mapping.getOutputs().size()) { + return Optional.empty(); + } + + return Optional.of(mapping.getOutputs().get(index)); } public String getSymbol() diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java index b185e81ea8ed..c84e9d36165d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java @@ -13,13 +13,18 @@ */ package io.prestosql.sql.planner.iterative.rule; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.type.RowType; import io.prestosql.sql.planner.assertions.ExpressionMatcher; import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.Assignments; import org.testng.annotations.Test; +import java.util.Optional; + +import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression; @@ -27,6 +32,8 @@ public class TestInlineProjections extends BaseRuleTest { + private static final RowType MSG_TYPE = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), VARCHAR), new RowType.Field(Optional.of("y"), VARCHAR))); + @Test public void test() { @@ -41,14 +48,16 @@ public void test() .put(p.symbol("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times .put(p.symbol("single_complex"), expression("complex_2 + 2")) // complex expression reference only once .put(p.symbol("try"), expression("try(complex / literal)")) + .put(p.symbol("msg_xx"), 
expression("z + 1")) .build(), p.project(Assignments.builder() .put(p.symbol("symbol"), expression("x")) .put(p.symbol("complex"), expression("x * 2")) .put(p.symbol("literal"), expression("1")) .put(p.symbol("complex_2"), expression("x - 1")) + .put(p.symbol("z"), expression("msg.x")) .build(), - p.values(p.symbol("x"))))) + p.values(p.symbol("x"), p.symbol("msg", MSG_TYPE))))) .matches( project( ImmutableMap.builder() @@ -59,12 +68,14 @@ public void test() .put("out5", PlanMatchPattern.expression("1 + 2")) .put("out6", PlanMatchPattern.expression("x - 1 + 2")) .put("out7", PlanMatchPattern.expression("try(y / 1)")) + .put("out8", PlanMatchPattern.expression("z + 1")) .build(), project( ImmutableMap.of( "x", PlanMatchPattern.expression("x"), - "y", PlanMatchPattern.expression("x * 2")), - values(ImmutableMap.of("x", 0))))); + "y", PlanMatchPattern.expression("x * 2"), + "z", PlanMatchPattern.expression("msg.x")), + values(ImmutableMap.of("x", 0, "msg", 1))))); } @Test diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestColumns.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestColumns.java index 6bb6adca07d2..072d950d1076 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestColumns.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestColumns.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.assertions.UnnestedSymbolMatcher; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; import io.prestosql.sql.planner.plan.Assignments; @@ -25,6 +24,7 @@ import java.util.Optional; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static 
io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.unnest; @@ -58,7 +58,7 @@ public void testPruneOrdinalitySymbol() ImmutableMap.of("replicate_symbol", expression("replicate_symbol"), "unnested_symbol", expression("unnested_symbol")), unnest( ImmutableList.of("replicate_symbol"), - ImmutableMap.of("unnested_symbol", new UnnestedSymbolMatcher("unnest_symbol")), + ImmutableList.of(unnestMapping("unnest_symbol", ImmutableList.of("unnested_symbol"))), Optional.empty(), INNER, Optional.empty(), @@ -89,7 +89,7 @@ public void testPruneReplicateSymbol() ImmutableMap.of("unnested_symbol", expression("unnested_symbol"), "ordinality_symbol", expression("ordinality_symbol")), unnest( ImmutableList.of(), - ImmutableMap.of("unnested_symbol", new UnnestedSymbolMatcher("unnest_symbol")), + ImmutableList.of(unnestMapping("unnest_symbol", ImmutableList.of("unnested_symbol"))), Optional.of("ordinality_symbol"), INNER, Optional.empty(), diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java index 59a085263f73..e9d4741b32bc 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPruneUnnestSourceColumns.java @@ -16,11 +16,11 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.prestosql.sql.planner.Symbol; -import io.prestosql.sql.planner.assertions.UnnestedSymbolMatcher; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.UnnestNode.Mapping; import org.testng.annotations.Test; +import static 
io.prestosql.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.unnest; @@ -46,7 +46,7 @@ public void testNotAllInputsReferenced() .matches( unnest( ImmutableList.of("replicate_symbol"), - ImmutableMap.of("unnested_symbol", new UnnestedSymbolMatcher("unnest_symbol")), + ImmutableList.of(unnestMapping("unnest_symbol", ImmutableList.of("unnested_symbol"))), strictProject( ImmutableMap.of("replicate_symbol", expression("replicate_symbol"), "unnest_symbol", expression("unnest_symbol")), values("replicate_symbol", "unnest_symbol", "unused_symbol")))); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushDownDereferencesRules.java new file mode 100644 index 000000000000..ad07565a5e0e --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -0,0 +1,686 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.connector.CatalogName; +import io.prestosql.metadata.TableHandle; +import io.prestosql.plugin.tpch.TpchColumnHandle; +import io.prestosql.plugin.tpch.TpchTableHandle; +import io.prestosql.spi.block.SortOrder; +import io.prestosql.spi.predicate.TupleDomain; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.RowType; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.planner.OrderingScheme; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.UnnestNode; +import io.prestosql.sql.planner.plan.WindowNode; +import io.prestosql.sql.tree.FrameBound; +import io.prestosql.sql.tree.QualifiedName; +import io.prestosql.sql.tree.SortItem; +import io.prestosql.sql.tree.WindowFrame; +import io.prestosql.testing.TestingTransactionHandle; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.google.common.base.Predicates.equalTo; +import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; +import static io.prestosql.spi.block.SortOrder.ASC_NULLS_FIRST; +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.BooleanType.BOOLEAN; +import static io.prestosql.spi.type.RowType.field; +import static io.prestosql.spi.type.RowType.rowType; +import static io.prestosql.sql.analyzer.TypeSignatureProvider.fromTypes; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.UnnestMapping.unnestMapping; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.assignUniqueId; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import 
static io.prestosql.sql.planner.assertions.PlanMatchPattern.functionCall; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.markDistinct; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.rowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.strictProject; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topNRowNumber; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.unnest; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.window; +import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static io.prestosql.sql.planner.iterative.rule.test.RuleTester.CATALOG_ID; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; +import static io.prestosql.sql.tree.SortItem.NullOrdering.FIRST; +import static io.prestosql.sql.tree.SortItem.Ordering.ASCENDING; +import static java.util.Collections.singletonList; + +public class TestPushDownDereferencesRules + extends BaseRuleTest +{ + private static final RowType ROW_TYPE = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), BIGINT), new RowType.Field(Optional.of("y"), BIGINT))); + + @Test + public void testDoesNotFire() + { + // rule does not fire for symbols + tester().assertThat(new PushDownDereferenceThroughFilter(tester().getTypeAnalyzer())) + .on(p -> + 
p.filter(expression("x > BIGINT '5'"), + p.values(p.symbol("x")))) + .doesNotFire(); + + // Pushdown is not enabled if dereferences come from an expression that is not a simple dereference chain + tester().assertThat(new PushDownDereferenceThroughProject(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of( + p.symbol("expr_1"), expression("cast(row(a, b) as row(f1 row(x bigint, y bigint), f2 bigint)).f1"), + p.symbol("expr_2"), expression("cast(row(a, b) as row(f1 row(x bigint, y bigint), f2 bigint)).f1.y")), + p.project( + Assignments.of( + p.symbol("a", ROW_TYPE), expression("a"), + p.symbol("b"), expression("b")), + p.values(p.symbol("a", ROW_TYPE), p.symbol("b"))))) + .doesNotFire(); + + // Does not fire when base symbols are referenced along with the dereferences + tester().assertThat(new PushDownDereferenceThroughProject(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of(p.symbol("expr", ROW_TYPE), expression("a"), p.symbol("a_x"), expression("a.x")), + p.project( + Assignments.of(p.symbol("a", ROW_TYPE), expression("a")), + p.values(p.symbol("a", ROW_TYPE))))) + .doesNotFire(); + } + + @Test + public void testPushdownDereferenceThroughProject() + { + tester().assertThat(new PushDownDereferenceThroughProject(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of(p.symbol("x"), expression("msg.x")), + p.project( + Assignments.of( + p.symbol("y"), expression("y"), + p.symbol("msg", ROW_TYPE), expression("msg")), + p.values(p.symbol("msg", ROW_TYPE), p.symbol("y"))))) + .matches( + strictProject( + ImmutableMap.of("x", PlanMatchPattern.expression("msg_x")), + strictProject( + ImmutableMap.of( + "msg_x", PlanMatchPattern.expression("msg.x"), + "y", PlanMatchPattern.expression("y"), + "msg", PlanMatchPattern.expression("msg")), + values("msg", "y")))); + } + + @Test + public void testPushDownDereferenceThroughJoin() + { + tester().assertThat(new PushDownDereferenceThroughJoin(tester().getTypeAnalyzer())) + .on(p 
-> + p.project( + Assignments.builder() + .put(p.symbol("left_x"), expression("msg1.x")) + .put(p.symbol("right_y"), expression("msg2.y")) + .put(p.symbol("z"), expression("z")) + .build(), + p.join(INNER, + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("unreferenced_symbol")), + p.values(p.symbol("msg2", ROW_TYPE), p.symbol("z"))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("left_x", PlanMatchPattern.expression("x")) + .put("right_y", PlanMatchPattern.expression("y")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + join(INNER, ImmutableList.of(), + strictProject( + ImmutableMap.of( + "x", PlanMatchPattern.expression("msg1.x"), + "msg1", PlanMatchPattern.expression("msg1"), + "unreferenced_symbol", PlanMatchPattern.expression("unreferenced_symbol")), + values("msg1", "unreferenced_symbol")), + strictProject( + ImmutableMap.builder() + .put("y", PlanMatchPattern.expression("msg2.y")) + .put("z", PlanMatchPattern.expression("z")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .build(), + values("msg2", "z"))))); + + // Verify pushdown for filters + tester().assertThat(new PushDownDereferenceThroughJoin(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of( + p.symbol("expr"), expression("msg1.x"), + p.symbol("expr_2"), expression("msg2")), + p.join(INNER, + p.values(p.symbol("msg1", ROW_TYPE)), + p.values(p.symbol("msg2", ROW_TYPE)), + p.expression("msg1.x + msg2.y > BIGINT '10'")))) + .matches( + project( + ImmutableMap.of( + "expr", PlanMatchPattern.expression("msg1_x"), + "expr_2", PlanMatchPattern.expression("msg2")), + join(INNER, ImmutableList.of(), Optional.of("msg1_x + msg2.y > BIGINT '10'"), + strictProject( + ImmutableMap.of( + "msg1_x", PlanMatchPattern.expression("msg1.x"), + "msg1", PlanMatchPattern.expression("msg1")), + values("msg1")), + values("msg2")))); + } + + @Test + public void testPushdownDereferencesThroughSemiJoin() + { + tester().assertThat(new 
PushDownDereferenceThroughSemiJoin(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_x"), expression("msg2.x")) + .build(), + p.semiJoin( + p.symbol("msg2", ROW_TYPE), + p.symbol("filtering_msg", ROW_TYPE), + p.symbol("match"), + Optional.empty(), + Optional.empty(), + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE)), + p.values(p.symbol("filtering_msg", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg1_x", PlanMatchPattern.expression("expr")) + .put("msg2_x", PlanMatchPattern.expression("msg2.x")) // Not pushed down because msg2 is sourceJoinSymbol + .build(), + semiJoin( + "msg2", + "filtering_msg", + "match", + strictProject( + ImmutableMap.of( + "expr", PlanMatchPattern.expression("msg1.x"), + "msg1", PlanMatchPattern.expression("msg1"), + "msg2", PlanMatchPattern.expression("msg2")), + values("msg1", "msg2")), + values("filtering_msg")))); + } + + @Test + public void testPushdownDereferencesThroughUnnest() + { + ArrayType arrayType = new ArrayType(BIGINT); + tester().assertThat(new PushDownDereferenceThroughUnnest(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of(p.symbol("x"), expression("msg.x")), + p.unnest( + ImmutableList.of(p.symbol("msg", ROW_TYPE)), + ImmutableList.of(new UnnestNode.Mapping(p.symbol("arr", arrayType), ImmutableList.of(p.symbol("field")))), + Optional.empty(), + INNER, + Optional.empty(), + p.values(p.symbol("msg", ROW_TYPE), p.symbol("arr", arrayType))))) + .matches( + strictProject( + ImmutableMap.of("x", PlanMatchPattern.expression("msg_x")), + unnest( + strictProject( + ImmutableMap.of( + "msg_x", PlanMatchPattern.expression("msg.x"), + "msg", PlanMatchPattern.expression("msg"), + "arr", PlanMatchPattern.expression("arr")), + values("msg", "arr"))))); + + // Test with dereferences on unnested column + RowType rowType = rowType(field("f1", BIGINT), field("f2", BIGINT)); + 
ArrayType nestedColumnType = new ArrayType(rowType(field("f1", BIGINT), field("f2", rowType))); + + tester().assertThat(new PushDownDereferenceThroughUnnest(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of( + p.symbol("deref_replicate", BIGINT), expression("replicate.f2"), + p.symbol("deref_unnest", BIGINT), expression("unnested_row.f2")), + p.unnest( + ImmutableList.of(p.symbol("replicate", rowType)), + ImmutableList.of( + new UnnestNode.Mapping( + p.symbol("nested", nestedColumnType), + ImmutableList.of(p.symbol("unnested_bigint", BIGINT), p.symbol("unnested_row", rowType)))), + p.values(p.symbol("replicate", rowType), p.symbol("nested", nestedColumnType))))) + .matches( + strictProject( + ImmutableMap.of( + "deref_replicate", PlanMatchPattern.expression("symbol"), + "deref_unnest", PlanMatchPattern.expression("unnested_row.f2")), // not pushed down: unnested_row is an output of the unnest mapping, so it is not available below the unnest + unnest( + ImmutableList.of("replicate", "symbol"), + ImmutableList.of(unnestMapping("nested", ImmutableList.of("unnested_bigint", "unnested_row"))), + strictProject( + ImmutableMap.of( + "symbol", PlanMatchPattern.expression("replicate.f2"), + "replicate", PlanMatchPattern.expression("replicate"), + "nested", PlanMatchPattern.expression("nested")), + values("replicate", "nested"))))); + } + + @Test + public void testExtractDereferencesFromFilterAboveScan() + { + TableHandle testTable = new TableHandle( + new CatalogName(CATALOG_ID), + new TpchTableHandle("orders", 1.0), + TestingTransactionHandle.create(), + Optional.empty()); + + RowType nestedRowType = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("nested"), ROW_TYPE))); + tester().assertThat(new ExtractDereferencesFromFilterAboveScan(tester().getTypeAnalyzer())) + .on(p -> + p.filter(expression("a.nested.x != 5 AND b.y = 2 AND CAST(a.nested as JSON) is not null"), + p.tableScan( + testTable, + ImmutableList.of(p.symbol("a", nestedRowType), p.symbol("b", ROW_TYPE)), + ImmutableMap.of( + p.symbol("a", nestedRowType), new
TpchColumnHandle("a", nestedRowType), + p.symbol("b", ROW_TYPE), new TpchColumnHandle("b", ROW_TYPE))))) + .matches(project( + filter("expr != 5 AND expr_0 = 2 AND CAST(expr_1 as JSON) is not null", + strictProject( + ImmutableMap.of( + "expr", PlanMatchPattern.expression("a.nested.x"), + "expr_0", PlanMatchPattern.expression("b.y"), + "expr_1", PlanMatchPattern.expression("a.nested"), + "a", PlanMatchPattern.expression("a"), + "b", PlanMatchPattern.expression("b")), + tableScan( + equalTo(testTable.getConnectorHandle()), + TupleDomain.all(), + ImmutableMap.of( + "a", equalTo(new TpchColumnHandle("a", nestedRowType)), + "b", equalTo(new TpchColumnHandle("b", ROW_TYPE)))))))); + } + + @Test + public void testPushdownDereferenceThroughFilter() + { + tester().assertThat(new PushDownDereferenceThroughFilter(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of( + p.symbol("expr", BIGINT), expression("msg.x"), + p.symbol("expr_2", BIGINT), expression("msg2.x")), + p.filter( + expression("msg.x <> 'foo' AND msg2 is NOT NULL"), + p.values(p.symbol("msg", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.of( + "expr", PlanMatchPattern.expression("msg_x"), + "expr_2", PlanMatchPattern.expression("msg2.x")), // not pushed down since predicate contains msg2 reference + filter( + "msg_x <> 'foo' AND msg2 is NOT NULL", + strictProject( + ImmutableMap.of( + "msg_x", PlanMatchPattern.expression("msg.x"), + "msg", PlanMatchPattern.expression("msg"), + "msg2", PlanMatchPattern.expression("msg2")), + values("msg", "msg2"))))); + } + + @Test + public void testPushDownDereferenceThroughLimit() + { + tester().assertThat(new PushDownDereferencesThroughLimit(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_y"), expression("msg2.y")) + .put(p.symbol("z"), expression("z")) + .build(), + p.limit(10, + ImmutableList.of(p.symbol("msg2", 
ROW_TYPE)), + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE), p.symbol("z"))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg1_x", PlanMatchPattern.expression("x")) + .put("msg2_y", PlanMatchPattern.expression("msg2.y")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + limit( + 10, + ImmutableList.of(sort("msg2", ASCENDING, FIRST)), + strictProject( + ImmutableMap.builder() + .put("x", PlanMatchPattern.expression("msg1.x")) + .put("z", PlanMatchPattern.expression("z")) + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .build(), + values("msg1", "msg2", "z"))))); + } + + @Test + public void testPushDownDereferenceThroughSort() + { + // Does not fire if symbols are used in the ordering scheme + tester().assertThat(new PushDownDereferencesThroughSort(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg_x"), expression("msg.x")) + .put(p.symbol("msg_y"), expression("msg.y")) + .put(p.symbol("z"), expression("z")) + .build(), + p.sort( + ImmutableList.of(p.symbol("z"), p.symbol("msg", ROW_TYPE)), + p.values(p.symbol("msg", ROW_TYPE), p.symbol("z"))))) + .doesNotFire(); + + tester().assertThat(new PushDownDereferencesThroughSort(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg_x"), expression("msg.x")) + .put(p.symbol("z"), expression("z")) + .build(), + p.sort( + ImmutableList.of(p.symbol("z")), + p.values(p.symbol("msg", ROW_TYPE), p.symbol("z"))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg_x", PlanMatchPattern.expression("x")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + sort(ImmutableList.of(sort("z", ASCENDING, SortItem.NullOrdering.FIRST)), + strictProject( + ImmutableMap.builder() + .put("x", PlanMatchPattern.expression("msg.x")) + .put("z", PlanMatchPattern.expression("z")) + .put("msg", PlanMatchPattern.expression("msg")) + 
.build(), + values("msg", "z"))))); + } + + @Test + public void testPushdownDereferenceThroughRowNumber() + { + tester().assertThat(new PushDownDereferencesThroughRowNumber(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_x"), expression("msg2.x")) + .build(), + p.rowNumber( + ImmutableList.of(p.symbol("msg1", ROW_TYPE)), + Optional.empty(), + p.symbol("row_number"), + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg1_x", PlanMatchPattern.expression("msg1.x")) + .put("msg2_x", PlanMatchPattern.expression("expr")) + .build(), + rowNumber( + pattern -> pattern + .partitionBy(ImmutableList.of("msg1")), + strictProject( + ImmutableMap.builder() + .put("expr", PlanMatchPattern.expression("msg2.x")) + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .build(), + values("msg1", "msg2"))))); + } + + @Test + public void testPushdownDereferenceThroughTopNRowNumber() + { + tester().assertThat(new PushDownDereferencesThroughTopNRowNumber(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_x"), expression("msg2.x")) + .put(p.symbol("msg3_x"), expression("msg3.x")) + .build(), + p.topNRowNumber( + new WindowNode.Specification( + ImmutableList.of(p.symbol("msg1", ROW_TYPE)), + Optional.of(new OrderingScheme( + ImmutableList.of(p.symbol("msg2", ROW_TYPE)), + ImmutableMap.of(p.symbol("msg2", ROW_TYPE), ASC_NULLS_FIRST)))), + 5, + p.symbol("row_number"), + Optional.empty(), + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE), p.symbol("msg3", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg1_x", PlanMatchPattern.expression("msg1.x")) + .put("msg2_x", PlanMatchPattern.expression("msg2.x")) + .put("msg3_x", 
PlanMatchPattern.expression("expr")) + .build(), + topNRowNumber( + pattern -> pattern.specification(singletonList("msg1"), singletonList("msg2"), ImmutableMap.of("msg2", ASC_NULLS_FIRST)), + strictProject( + ImmutableMap.builder() + .put("expr", PlanMatchPattern.expression("msg3.x")) + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .put("msg3", PlanMatchPattern.expression("msg3")) + .build(), + values("msg1", "msg2", "msg3"))))); + } + + @Test + public void testPushdownDereferenceThroughTopN() + { + tester().assertThat(new PushDownDereferencesThroughTopN(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_x"), expression("msg2.x")) + .build(), + p.topN(5, ImmutableList.of(p.symbol("msg1", ROW_TYPE)), + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg1_x", PlanMatchPattern.expression("msg1.x")) + .put("msg2_x", PlanMatchPattern.expression("expr")) + .build(), + topN(5, ImmutableList.of(sort("msg1", ASCENDING, FIRST)), + strictProject( + ImmutableMap.builder() + .put("expr", PlanMatchPattern.expression("msg2.x")) + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .build(), + values("msg1", "msg2"))))); + } + + @Test + public void testPushdownDereferenceThroughWindow() + { + tester().assertThat(new PushDownDereferencesThroughWindow(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_x"), expression("msg2.x")) + .put(p.symbol("msg3_x"), expression("msg3.x")) + .put(p.symbol("msg4_x"), expression("msg4.x")) + .put(p.symbol("msg5_x"), expression("msg5.x")) + .build(), + p.window( + new WindowNode.Specification( + ImmutableList.of(p.symbol("msg1", ROW_TYPE)), + Optional.of(new 
OrderingScheme( + ImmutableList.of(p.symbol("msg2", ROW_TYPE)), + ImmutableMap.of(p.symbol("msg2", ROW_TYPE), ASC_NULLS_FIRST)))), + ImmutableMap.of( + p.symbol("msg6", ROW_TYPE), + // min function on ROW_TYPE + new WindowNode.Function( + createTestMetadataManager().resolveFunction(QualifiedName.of("min"), fromTypes(ROW_TYPE)), + ImmutableList.of(p.symbol("msg3", ROW_TYPE).toSymbolReference()), + new WindowNode.Frame( + WindowFrame.Type.RANGE, + FrameBound.Type.UNBOUNDED_PRECEDING, + Optional.empty(), + FrameBound.Type.UNBOUNDED_FOLLOWING, + Optional.empty(), + Optional.empty(), + Optional.empty()), + true)), + p.values( + p.symbol("msg1", ROW_TYPE), + p.symbol("msg2", ROW_TYPE), + p.symbol("msg3", ROW_TYPE), + p.symbol("msg4", ROW_TYPE), + p.symbol("msg5", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.builder() + .put("msg1_x", PlanMatchPattern.expression("msg1.x")) // not pushed down because used in partitionBy + .put("msg2_x", PlanMatchPattern.expression("msg2.x")) // not pushed down because used in orderBy + .put("msg3_x", PlanMatchPattern.expression("msg3.x")) // not pushed down because the whole column is used in windowNode function + .put("msg4_x", PlanMatchPattern.expression("expr")) // pushed down because msg4 is not referenced in windowNode (the function only takes msg3) + .put("msg5_x", PlanMatchPattern.expression("expr2")) // pushed down because not referenced in windowNode + .build(), + window( + windowMatcherBuilder -> windowMatcherBuilder + .specification(singletonList("msg1"), singletonList("msg2"), ImmutableMap.of("msg2", SortOrder.ASC_NULLS_FIRST)) + .addFunction(functionCall("min", singletonList("msg3"))), + strictProject( + ImmutableMap.builder() + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .put("msg3", PlanMatchPattern.expression("msg3")) + .put("msg4", PlanMatchPattern.expression("msg4")) + .put("msg5", PlanMatchPattern.expression("msg5")) + .put("expr", PlanMatchPattern.expression("msg4.x")) +
.put("expr2", PlanMatchPattern.expression("msg5.x")) + .build(), + values("msg1", "msg2", "msg3", "msg4", "msg5"))))); + } + + @Test + public void testPushdownDereferenceThroughAssignUniqueId() + { + tester().assertThat(new PushDownDereferencesThroughAssignUniqueId(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("expr"), expression("msg1.x")) + .build(), + p.assignUniqueId( + p.symbol("unique"), + p.values(p.symbol("msg1", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.of("expr", PlanMatchPattern.expression("msg1_x")), + assignUniqueId( + "unique", + strictProject( + ImmutableMap.builder() + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg1_x", PlanMatchPattern.expression("msg1.x")) + .build(), + values("msg1"))))); + } + + @Test + public void testPushdownDereferenceThroughMarkDistinct() + { + tester().assertThat(new PushDownDereferencesThroughMarkDistinct(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg1_x"), expression("msg1.x")) + .put(p.symbol("msg2_x"), expression("msg2.x")) + .build(), + p.markDistinct( + p.symbol("is_distinct", BOOLEAN), + singletonList(p.symbol("msg2", ROW_TYPE)), + p.values(p.symbol("msg1", ROW_TYPE), p.symbol("msg2", ROW_TYPE))))) + .matches( + strictProject( + ImmutableMap.of( + "msg1_x", PlanMatchPattern.expression("expr"), // pushed down + "msg2_x", PlanMatchPattern.expression("msg2.x")), // not pushed down because used in markDistinct + markDistinct( + "is_distinct", + singletonList("msg2"), + strictProject( + ImmutableMap.builder() + .put("msg1", PlanMatchPattern.expression("msg1")) + .put("msg2", PlanMatchPattern.expression("msg2")) + .put("expr", PlanMatchPattern.expression("msg1.x")) + .build(), + values("msg1", "msg2"))))); + } + + @Test + public void testMultiLevelPushdown() + { + Type complexType = rowType(field("f1", rowType(field("f1", BIGINT), field("f2", BIGINT))), field("f2", BIGINT)); + 
tester().assertThat(new PushDownDereferenceThroughProject(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of( + p.symbol("expr_1"), expression("a.f1"), + p.symbol("expr_2"), expression("a.f1.f1 + 2 + b.f1.f1 + b.f1.f2")), + p.project( + Assignments.identity(ImmutableList.of(p.symbol("a", complexType), p.symbol("b", complexType))), + p.values(p.symbol("a", complexType), p.symbol("b", complexType))))) + .matches( + strictProject( + ImmutableMap.of( + "expr_1", PlanMatchPattern.expression("a_f1"), + "expr_2", PlanMatchPattern.expression("a_f1.f1 + 2 + b_f1_f1 + b_f1_f2")), + strictProject( + ImmutableMap.of( + "a", PlanMatchPattern.expression("a"), + "b", PlanMatchPattern.expression("b"), + "a_f1", PlanMatchPattern.expression("a.f1"), + "b_f1_f1", PlanMatchPattern.expression("b.f1.f1"), + "b_f1_f2", PlanMatchPattern.expression("b.f1.f2")), + values("a", "b")))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughProject.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughProject.java index 45e0735063bf..05f4c1daadc3 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughProject.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushLimitThroughProject.java @@ -15,14 +15,20 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.type.RowType; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.assertions.ExpressionMatcher; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.tree.ArithmeticBinaryExpression; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Identifier; import io.prestosql.sql.tree.SymbolReference; import org.testng.annotations.Test; +import java.util.Optional; + 
+import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; @@ -133,4 +139,44 @@ public void testDoesntPushdownLimitThroughIdentityProjection() p.values(a))); }).doesNotFire(); } + + @Test + public void testDoesntPushDownLimitThroughExclusiveDereferences() + { + RowType rowType = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), BIGINT), new RowType.Field(Optional.of("y"), BIGINT))); + + tester().assertThat(new PushLimitThroughProject()) + .on(p -> { + Symbol a = p.symbol("a", rowType); + return p.limit(1, + p.project( + Assignments.of( + p.symbol("b"), new DereferenceExpression(a.toSymbolReference(), new Identifier("x")), + p.symbol("c"), new DereferenceExpression(a.toSymbolReference(), new Identifier("y"))), + p.values(a))); + }) + .doesNotFire(); + } + + @Test + public void testPushDownLimitThroughOverlappingDereferences() + { + RowType rowType = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), BIGINT), new RowType.Field(Optional.of("y"), BIGINT))); + + tester().assertThat(new PushLimitThroughProject()) + .on(p -> { + Symbol a = p.symbol("a", rowType); + return p.limit(1, + p.project( + Assignments.of( + p.symbol("b"), new DereferenceExpression(a.toSymbolReference(), new Identifier("x")), + p.symbol("c", rowType), a.toSymbolReference()), + p.values(a))); + }) + .matches( + project( + ImmutableMap.of("b", expression("a.x"), "c", expression("a")), + limit(1, + values("a")))); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java index b712a704c912..2e985d0de425 100644 --- 
a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushProjectionThroughUnion.java @@ -16,13 +16,18 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.type.RowType; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import io.prestosql.sql.planner.iterative.rule.test.PlanBuilder; import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.LongLiteral; import org.testng.annotations.Test; +import java.util.Optional; + +import static io.prestosql.spi.type.BigintType.BIGINT; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.union; @@ -31,6 +36,8 @@ public class TestPushProjectionThroughUnion extends BaseRuleTest { + private static final RowType ROW_TYPE = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), BIGINT), new RowType.Field(Optional.of("y"), BIGINT))); + @Test public void testDoesNotFire() { @@ -73,29 +80,38 @@ public void test() Symbol a = p.symbol("a"); Symbol b = p.symbol("b"); Symbol c = p.symbol("c"); + Symbol d = p.symbol("d", ROW_TYPE); Symbol cTimes3 = p.symbol("c_times_3"); + Symbol dX = p.symbol("d_x"); + Symbol z = p.symbol("z", ROW_TYPE); + Symbol w = p.symbol("w", ROW_TYPE); return p.project( - Assignments.of(cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, c.toSymbolReference(), new LongLiteral("3"))), + Assignments.of( + cTimes3, new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.MULTIPLY, c.toSymbolReference(), new 
LongLiteral("3")), + dX, PlanBuilder.expression("d.x")), p.union( ImmutableListMultimap.builder() .put(c, a) .put(c, b) + .put(d, z) + .put(d, w) .build(), ImmutableList.of( - p.values(a), - p.values(b)))); + p.values(a, z), + p.values(b, w)))); }) .matches( union( project( - ImmutableMap.of("a_times_3", expression("a * 3")), - values(ImmutableList.of("a"))), + ImmutableMap.of("a_times_3", expression("a * 3"), "z_x", expression("z.x")), + values(ImmutableList.of("a", "z"))), project( - ImmutableMap.of("b_times_3", expression("b * 3")), - values(ImmutableList.of("b")))) - // verify that data originally on symbols aliased as x1 and x2 is part of exchange output - .withNumberOfOutputColumns(1) + ImmutableMap.of("b_times_3", expression("b * 3"), "w_x", expression("w.x")), + values(ImmutableList.of("b", "w")))) + .withNumberOfOutputColumns(2) .withAlias("a_times_3") - .withAlias("b_times_3")); + .withAlias("b_times_3") + .withAlias("z_x") + .withAlias("w_x")); } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNThroughProject.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNThroughProject.java index cfa44a09d09e..60501d8570a4 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNThroughProject.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushTopNThroughProject.java @@ -15,16 +15,23 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.type.RowType; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.assertions.ExpressionMatcher; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.tree.ArithmeticBinaryExpression; import io.prestosql.sql.tree.BooleanLiteral; +import io.prestosql.sql.tree.DereferenceExpression; +import 
io.prestosql.sql.tree.Identifier; import io.prestosql.sql.tree.SymbolReference; import io.prestosql.testing.TestingMetadata; import org.testng.annotations.Test; +import java.util.Optional; + +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.sort; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.topN; @@ -36,6 +43,10 @@ public class TestPushTopNThroughProject extends BaseRuleTest { + private static final RowType rowType = RowType.from(ImmutableList.of( + new RowType.Field(Optional.of("x"), BIGINT), + new RowType.Field(Optional.of("y"), BIGINT))); + @Test public void testPushdownTopNNonIdentityProjection() { @@ -131,4 +142,49 @@ public void testDoNotPushdownTopNThroughProjectionOverTableScan() ImmutableMap.of(a, new TestingMetadata.TestingColumnHandle("a"))))); }).doesNotFire(); } + + @Test + public void testDoesntPushDownTopNThroughExclusiveDereferences() + { + tester().assertThat(new PushTopNThroughProject()) + .on(p -> { + Symbol a = p.symbol("a", rowType); + return p.topN( + 1, + ImmutableList.of(p.symbol("c")), + p.project( + Assignments.builder() + .put(p.symbol("b"), new DereferenceExpression(a.toSymbolReference(), new Identifier("x"))) + .put(p.symbol("c"), new DereferenceExpression(a.toSymbolReference(), new Identifier("y"))) + .build(), + p.values(a))); + }).doesNotFire(); + } + + @Test + public void testPushTopNThroughOverlappingDereferences() + { + tester().assertThat(new PushTopNThroughProject()) + .on(p -> { + Symbol a = p.symbol("a", rowType); + Symbol d = p.symbol("d"); + return p.topN( + 1, + ImmutableList.of(d), + p.project( + Assignments.builder() + .put(p.symbol("b"), new DereferenceExpression(a.toSymbolReference(), new Identifier("x"))) + .put(p.symbol("c", rowType), a.toSymbolReference()) + .put(d, 
d.toSymbolReference()) + .build(), + p.values(a, d))); + }) + .matches( + project( + ImmutableMap.of("b", expression("a.x"), "c", expression("a"), "d", expression("d")), + topN( + 1, + ImmutableList.of(sort("d", ASCENDING, FIRST)), + values("a", "d")))); + } } diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index 2bb8c84078b3..be4754ac2fdc 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -64,11 +64,9 @@ import io.prestosql.sql.planner.plan.MarkDistinctNode; import io.prestosql.sql.planner.plan.OffsetNode; import io.prestosql.sql.planner.plan.OutputNode; -import io.prestosql.sql.planner.plan.PlanFragmentId; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.PlanNodeId; import io.prestosql.sql.planner.plan.ProjectNode; -import io.prestosql.sql.planner.plan.RemoteSourceNode; import io.prestosql.sql.planner.plan.RowNumberNode; import io.prestosql.sql.planner.plan.SampleNode; import io.prestosql.sql.planner.plan.SemiJoinNode; @@ -944,11 +942,6 @@ public TopNRowNumberNode topNRowNumber(Specification specification, int maxRowCo hashSymbol); } - public RemoteSourceNode remoteSourceNode(List fragmentIds, List symbols, ExchangeNode.Type exchangeType) - { - return new RemoteSourceNode(idAllocator.getNextId(), fragmentIds, symbols, Optional.empty(), exchangeType); - } - public static Expression expression(String sql) { return ExpressionUtils.rewriteIdentifiersToSymbolReferences(new SqlParser().createExpression(sql, new ParsingOptions()));