diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java index da2dba6dec2c..ade4453f72ed 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/PlanOptimizers.java @@ -80,6 +80,7 @@ import io.prestosql.sql.planner.iterative.rule.PruneWindowColumns; import io.prestosql.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; import io.prestosql.sql.planner.iterative.rule.PushDeleteIntoConnector; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferences; import io.prestosql.sql.planner.iterative.rule.PushLimitIntoTableScan; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import io.prestosql.sql.planner.iterative.rule.PushLimitThroughOffset; @@ -381,6 +382,11 @@ public PlanOptimizers( new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), new TransformCorrelatedScalarAggregationToJoin(metadata), new TransformCorrelatedJoinToJoin(metadata))), + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + new PushDownDereferences(typeAnalyzer).rules()), new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java index e8afb662dd34..1ae84c6b4262 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/InlineProjections.java @@ -24,6 +24,7 @@ import io.prestosql.sql.planner.plan.Assignments; import io.prestosql.sql.planner.plan.PlanNode; import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.tree.DereferenceExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Literal; import io.prestosql.sql.tree.TryExpression; @@ -172,6 +173,7 @@ private static Set extractInliningTargets(ProjectNode parent, ProjectNod .filter(entry -> entry.getValue() == 1) // reference appears just once across all expressions in parent project node .filter(entry -> !tryArguments.contains(entry.getKey())) // they are not inputs to TRY. Otherwise, inlining might change semantics .filter(entry -> !child.getAssignments().isIdentity(entry.getKey())) // skip identities, otherwise, this rule will keep firing forever + .filter(entry -> !(child.getAssignments().get(entry.getKey()) instanceof DereferenceExpression)) // skip dereferences, otherwise, inlining can cause conflicts with PushdownDereferences .map(Map.Entry::getKey) .collect(toSet()); diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferences.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferences.java new file mode 100644 index 000000000000..6ff784aa1663 --- /dev/null +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/PushDownDereferences.java @@ -0,0 +1,604 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.prestosql.matching.Capture; +import io.prestosql.matching.Captures; +import io.prestosql.matching.Pattern; +import io.prestosql.spi.type.Type; +import io.prestosql.sql.planner.ExpressionExtractor; +import io.prestosql.sql.planner.PlanNodeIdAllocator; +import io.prestosql.sql.planner.Symbol; +import io.prestosql.sql.planner.TypeAnalyzer; +import io.prestosql.sql.planner.iterative.Rule; +import io.prestosql.sql.planner.iterative.Rule.Context; +import io.prestosql.sql.planner.plan.AssignUniqueId; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.FilterNode; +import io.prestosql.sql.planner.plan.JoinNode; +import io.prestosql.sql.planner.plan.LimitNode; +import io.prestosql.sql.planner.plan.PlanNode; +import io.prestosql.sql.planner.plan.ProjectNode; +import io.prestosql.sql.planner.plan.RowNumberNode; +import io.prestosql.sql.planner.plan.SemiJoinNode; +import io.prestosql.sql.planner.plan.SortNode; +import io.prestosql.sql.planner.plan.TopNNode; +import io.prestosql.sql.planner.plan.TopNRowNumberNode; +import io.prestosql.sql.planner.plan.UnnestNode; +import io.prestosql.sql.planner.plan.WindowNode; +import io.prestosql.sql.tree.DefaultExpressionTraversalVisitor; +import io.prestosql.sql.tree.DereferenceExpression; +import io.prestosql.sql.tree.Expression; +import io.prestosql.sql.tree.ExpressionRewriter; +import io.prestosql.sql.tree.ExpressionTreeRewriter; +import io.prestosql.sql.tree.SymbolReference; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static io.prestosql.matching.Capture.newCapture; +import static io.prestosql.sql.planner.SymbolsExtractor.extractAll; +import static io.prestosql.sql.planner.plan.Patterns.filter; +import static io.prestosql.sql.planner.plan.Patterns.join; +import static io.prestosql.sql.planner.plan.Patterns.project; +import static io.prestosql.sql.planner.plan.Patterns.semiJoin; +import static io.prestosql.sql.planner.plan.Patterns.source; +import static io.prestosql.sql.planner.plan.Patterns.unnest; +import static java.util.Objects.requireNonNull; + +/** + * Push down dereferences as follows: + *

+ * Extract dereferences from PlanNode which has expressions + * and push them down to a new ProjectNode right below the PlanNode. + * After this step, All dereferences will be in ProjectNode. + *

+ * Pushdown dereferences in ProjectNode down through other types of PlanNode, + * e.g, Filter, Join etc. + */ +public class PushDownDereferences +{ + private final TypeAnalyzer typeAnalyzer; + + public PushDownDereferences(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + public Set> rules() + { + return ImmutableSet.of( + new ExtractFromFilter(typeAnalyzer), + new ExtractFromJoin(typeAnalyzer), + new PushDownDereferenceThrough<>(AssignUniqueId.class, typeAnalyzer), + new PushDownDereferenceThrough<>(WindowNode.class, typeAnalyzer), + new PushDownDereferenceThrough<>(TopNNode.class, typeAnalyzer), + new PushDownDereferenceThrough<>(RowNumberNode.class, typeAnalyzer), + new PushDownDereferenceThrough<>(TopNRowNumberNode.class, typeAnalyzer), + new PushDownDereferenceThrough<>(SortNode.class, typeAnalyzer), + new PushDownDereferenceThrough<>(FilterNode.class, typeAnalyzer), + new PushDownDereferenceThrough<>(LimitNode.class, typeAnalyzer), + new PushDownDereferenceThroughProject(typeAnalyzer), + new PushDownDereferenceThroughUnnest(typeAnalyzer), + new PushDownDereferenceThroughSemiJoin(typeAnalyzer), + new PushDownDereferenceThroughJoin(typeAnalyzer)); + } + + /** + * ExtractFromFilter extracts dereferences and push them down to new ProjectNode below + * Transforms: + *

+     *  FilterNode(expression(a.x))
+     *  
+ * to: + *
+     *   ProjectNode(original symbols)
+     *    FilterNode(expression(symbol))
+     *      Project(symbol := a.x)
+     * 
+ */ + static class ExtractFromFilter + implements Rule + { + private final TypeAnalyzer typeAnalyzer; + + ExtractFromFilter(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return filter(); + } + + @Override + public Result apply(FilterNode node, Captures captures, Context context) + { + BiMap expressions = + HashBiMap.create(getDereferenceSymbolMap(ExpressionExtractor.extractExpressionsNonRecursive(node), context, typeAnalyzer)); + + if (expressions.isEmpty()) { + return Result.empty(); + } + + PlanNode source = node.getSource(); + Assignments assignments = Assignments.builder().putIdentities(source.getOutputSymbols()).putAll(expressions.inverse()).build(); + ProjectNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), source, assignments); + + FilterNode filterNode = new FilterNode( + context.getIdAllocator().getNextId(), + projectNode, + ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressions), node.getPredicate())); + + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), filterNode, Assignments.builder().putIdentities(node.getOutputSymbols()).build())); + } + } + + /** + * ExtractFromJoin extracts dereferences in filter expression and push them down + * Transforms: + *
+     *  JoinNode(filter: a.x < 5)
+     *    Source(a)
+     *    Source(b)
+     *  
+ * to: + *
+     *  JoinNode(filter: a_x < 5)
+     *    Project(a_x := a.x)
+     *      Source(a)
+     *    Source(b)
+     * 
+ */ + static class ExtractFromJoin + implements Rule + { + private final TypeAnalyzer typeAnalyzer; + + ExtractFromJoin(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return join(); + } + + @Override + public Result apply(JoinNode joinNode, Captures captures, Context context) + { + BiMap expressions = + HashBiMap.create(getDereferenceSymbolMap(ExpressionExtractor.extractExpressionsNonRecursive(joinNode), context, typeAnalyzer)); + + if (expressions.isEmpty()) { + return Result.empty(); + } + Assignments.Builder leftSideDereferences = Assignments.builder(); + Assignments.Builder rightSideDereferences = Assignments.builder(); + + for (Map.Entry entry : expressions.inverse().entrySet()) { + Symbol baseSymbol = getBase(entry.getValue()); + if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) { + leftSideDereferences.put(entry.getKey(), entry.getValue()); + } + else { + rightSideDereferences.put(entry.getKey(), entry.getValue()); + } + } + PlanNode leftNode = createProjectBelow(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator()); + PlanNode rightNode = createProjectBelow(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator()); + + PlanNode newJoinNode = new JoinNode( + context.getIdAllocator().getNextId(), + joinNode.getType(), + leftNode, + rightNode, + joinNode.getCriteria(), + joinNode.getOutputSymbols(), + joinNode.getFilter().map(expression -> ExpressionTreeRewriter.rewriteWith(new PushDownDereferences.DereferenceReplacer(expressions), expression)), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType(), + joinNode.isSpillable(), + joinNode.getDynamicFilters()); + + return Result.ofPlanNode(newJoinNode); + } + } + + /** + * Transforms: + *
+     *  Project(a_x := a.x)
+     *    TargetNode(a)
+     *  
+ * to: + *
+     *  Project(a_x := symbol)
+     *    TargetNode(symbol)
+     *      Project(symbol := a.x)
+     * 
+ */ + static class PushDownDereferenceThrough + implements Rule + { + private final Capture targetCapture = newCapture(); + private final Pattern targetPattern; + + private final TypeAnalyzer typeAnalyzer; + + PushDownDereferenceThrough(Class aClass, TypeAnalyzer typeAnalyzer) + { + targetPattern = Pattern.typeOf(requireNonNull(aClass, "aClass is null")); + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project().with(source().matching(targetPattern.capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + N child = captures.get(targetCapture); + Map pushdownDereferences = getPushdownDereferences(context, node, child, typeAnalyzer); + + if (pushdownDereferences.isEmpty()) { + return Result.empty(); + } + + PlanNode source = getOnlyElement(child.getSources()); + + ProjectNode projectNode = new ProjectNode( + context.getIdAllocator().getNextId(), + source, + Assignments.builder().putIdentities(source.getOutputSymbols()).putAll(HashBiMap.create(pushdownDereferences).inverse()).build()); + + PlanNode newChildNode = child.replaceChildren(ImmutableList.of(projectNode)); + Assignments assignments = node.getAssignments().rewrite(new DereferenceReplacer(pushdownDereferences)); + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newChildNode, assignments)); + } + } + + /** + * Transforms: + *
+     *  Project(a_x := a.msg.x)
+     *    Join(a_y = b_y) => [a]
+     *      Project(a_y := a.msg.y)
+     *          Source(a)
+     *      Project(b_y := b.msg.y)
+     *          Source(b)
+     *  
+ * to: + *
+     *  Project(a_x := symbol)
+     *    Join(a_y = b_y) => [symbol]
+     *      Project(symbol := a.msg.x, a_y := a.msg.y)
+     *        Source(a)
+     *      Project(b_y := b.msg.y)
+     *        Source(b)
+     * 
+ */ + static class PushDownDereferenceThroughJoin + implements Rule + { + private final Capture targetCapture = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + PushDownDereferenceThroughJoin(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project().with(source().matching(join().capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + JoinNode joinNode = captures.get(targetCapture); + Map pushdownDereferences = getPushdownDereferences(context, node, captures.get(targetCapture), typeAnalyzer); + + if (pushdownDereferences.isEmpty()) { + return Result.empty(); + } + + Assignments.Builder leftSideDereferences = Assignments.builder(); + Assignments.Builder rightSideDereferences = Assignments.builder(); + + for (Map.Entry entry : HashBiMap.create(pushdownDereferences).inverse().entrySet()) { + Symbol baseSymbol = getBase(entry.getValue()); + if (joinNode.getLeft().getOutputSymbols().contains(baseSymbol)) { + leftSideDereferences.put(entry.getKey(), entry.getValue()); + } + else { + rightSideDereferences.put(entry.getKey(), entry.getValue()); + } + } + PlanNode leftNode = createProjectBelow(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator()); + PlanNode rightNode = createProjectBelow(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator()); + + JoinNode newJoinNode = new JoinNode(context.getIdAllocator().getNextId(), + joinNode.getType(), + leftNode, + rightNode, + joinNode.getCriteria(), + ImmutableList.builder().addAll(leftNode.getOutputSymbols()).addAll(rightNode.getOutputSymbols()).build(), + joinNode.getFilter(), + joinNode.getLeftHashSymbol(), + joinNode.getRightHashSymbol(), + joinNode.getDistributionType(), + joinNode.isSpillable(), + joinNode.getDynamicFilters()); + + Assignments assignments = node.getAssignments().rewrite(new DereferenceReplacer(pushdownDereferences)); + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newJoinNode, assignments)); + } + } + + static class PushDownDereferenceThroughSemiJoin + implements Rule + { + private final Capture targetCapture = newCapture(); + private final TypeAnalyzer typeAnalyzer; + + PushDownDereferenceThroughSemiJoin(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project().with(source().matching(semiJoin().capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + SemiJoinNode semiJoinNode = captures.get(targetCapture); + Map pushdownDereferences = getPushdownDereferences(context, node, captures.get(targetCapture), typeAnalyzer); + + if (pushdownDereferences.isEmpty()) { + return Result.empty(); + } + + Assignments.Builder filteringSourceDereferences = Assignments.builder(); + Assignments.Builder sourceDereferences = Assignments.builder(); + + for (Map.Entry entry : HashBiMap.create(pushdownDereferences).inverse().entrySet()) { + Symbol baseSymbol = getBase(entry.getValue()); + if (semiJoinNode.getFilteringSource().getOutputSymbols().contains(baseSymbol)) { + filteringSourceDereferences.put(entry.getKey(), entry.getValue()); + } + else { + sourceDereferences.put(entry.getKey(), entry.getValue()); + } + } + PlanNode filteringSource = createProjectBelow(semiJoinNode.getFilteringSource(), filteringSourceDereferences.build(), context.getIdAllocator()); + PlanNode source = createProjectBelow(semiJoinNode.getSource(), sourceDereferences.build(), context.getIdAllocator()); + + PlanNode newSemiJoin = semiJoinNode.replaceChildren(ImmutableList.of(source, filteringSource)); + + Assignments assignments = node.getAssignments().rewrite(new DereferenceReplacer(pushdownDereferences)); + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newSemiJoin, assignments)); + } + } + + static class PushDownDereferenceThroughProject + implements Rule + { + private final Capture targetCapture = newCapture(); + + PushDownDereferenceThroughProject(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + private final TypeAnalyzer typeAnalyzer; + + @Override + public Pattern getPattern() + { + return project().with(source().matching(project().capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + ProjectNode child = captures.get(targetCapture); + Map pushdownDereferences = getPushdownDereferences(context, node, captures.get(targetCapture), typeAnalyzer); + + if (pushdownDereferences.isEmpty()) { + return Result.empty(); + } + + ProjectNode newChild = new ProjectNode(context.getIdAllocator().getNextId(), + child.getSource(), + Assignments.builder().putAll(child.getAssignments()).putAll(HashBiMap.create(pushdownDereferences).inverse()).build()); + + Assignments assignments = node.getAssignments().rewrite(new DereferenceReplacer(pushdownDereferences)); + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), newChild, assignments)); + } + } + + static class PushDownDereferenceThroughUnnest + implements Rule + { + private final Capture targetCapture = newCapture(); + + private final TypeAnalyzer typeAnalyzer; + + PushDownDereferenceThroughUnnest(TypeAnalyzer typeAnalyzer) + { + this.typeAnalyzer = requireNonNull(typeAnalyzer, "typeAnalyzer is null"); + } + + @Override + public Pattern getPattern() + { + return project().with(source().matching(unnest().capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + UnnestNode unnestNode = captures.get(targetCapture); + Map pushdownDereferences = getPushdownDereferences(context, node, captures.get(targetCapture), typeAnalyzer); + + if (pushdownDereferences.isEmpty()) { + return Result.empty(); + } + + // Create new Project contains all pushdown symbols above original source + Assignments assignments = Assignments.builder().putIdentities(unnestNode.getSource().getOutputSymbols()).putAll(HashBiMap.create(pushdownDereferences).inverse()).build(); + ProjectNode source = new ProjectNode(context.getIdAllocator().getNextId(), unnestNode.getSource(), assignments); + + // Create new UnnestNode + UnnestNode newUnnest = new UnnestNode(context.getIdAllocator().getNextId(), + source, + ImmutableList.builder().addAll(unnestNode.getReplicateSymbols()).addAll(pushdownDereferences.values()).build(), + unnestNode.getUnnestSymbols(), + unnestNode.getOrdinalitySymbol(), + unnestNode.getJoinType(), + unnestNode.getFilter()); + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), + newUnnest, + node.getAssignments().rewrite(new DereferenceReplacer(pushdownDereferences)))); + } + } + + private static Map getPushdownDereferences(Context context, ProjectNode parent, PlanNode child, TypeAnalyzer typeAnalyzer) + { + Map allDereferencesInProject = getDereferenceSymbolMap(parent.getAssignments().getExpressions(), context, typeAnalyzer); + Set childSourceSymbols = child.getSources().stream().map(PlanNode::getOutputSymbols).flatMap(Collection::stream).collect(toImmutableSet()); + + return allDereferencesInProject.entrySet().stream() + .filter(entry -> childSourceSymbols.contains(getBase(entry.getKey()))) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + } + + private static PlanNode createProjectBelow(PlanNode planNode, Assignments dereferences, PlanNodeIdAllocator idAllocator) + { + if (dereferences.isEmpty()) { + return planNode; + } + return new ProjectNode(idAllocator.getNextId(), planNode, Assignments.builder().putIdentities(planNode.getOutputSymbols()).putAll(dereferences).build()); + } + + private static class DereferenceReplacer + extends ExpressionRewriter + { + private final Map expressions; + + DereferenceReplacer(Map expressions) + { + this.expressions = requireNonNull(expressions, "expressions is null"); + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (expressions.containsKey(node)) { + return expressions.get(node).toSymbolReference(); + } + return treeRewriter.defaultRewrite(node, context); + } + } + + private static List extractDereferenceExpressions(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder context) + { + context.add(node); + return null; + } + }.process(expression, builder); + return builder.build(); + } + + private static Map getDereferenceSymbolMap(Collection expressions, Context context, TypeAnalyzer typeAnalyzer) + { + Set dereferences = expressions.stream() + .flatMap(expression -> extractDereferenceExpressions(expression).stream()) + .filter(PushDownDereferences::validPushDown) + .collect(toImmutableSet()); + + // When nested child and parent dereferences both exist, Pushdown rule will be trigger one more time + // and lead to runtime error. E.g. [msg.foo, msg.foo.bar] => [exp, exp.bar] (should stop here but + // since there are still dereferences, pushdown rule will trigger again) + if (dereferences.stream().anyMatch(exp -> baseExists(exp, dereferences))) { + return ImmutableMap.of(); + } + + return dereferences.stream() + .collect(toImmutableMap(Function.identity(), expression -> newSymbol(expression, context, typeAnalyzer))); + } + + private static Symbol newSymbol(Expression expression, Context context, TypeAnalyzer typeAnalyzer) + { + Type type = typeAnalyzer.getType(context.getSession(), context.getSymbolAllocator().getTypes(), expression); + verify(type != null); + return context.getSymbolAllocator().newSymbol(expression, type); + } + + private static boolean baseExists(DereferenceExpression expression, Set dereferences) + { + Expression base = expression.getBase(); + while (base instanceof DereferenceExpression) { + if (dereferences.contains(base)) { + return true; + } + base = ((DereferenceExpression) base).getBase(); + } + return false; + } + + private static boolean validPushDown(DereferenceExpression dereference) + { + Expression base = dereference.getBase(); + return (base instanceof SymbolReference) || (base instanceof DereferenceExpression); + } + + private static Symbol getBase(DereferenceExpression expression) + { + return getOnlyElement(extractAll(expression)); + } +} diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java index 7ea37ff5468b..1a1211dc4215 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Assignments.java @@ -217,9 +217,9 @@ public Builder putAll(Assignments assignments) return putAll(assignments.getMap()); } - public Builder putAll(Map assignments) + public Builder putAll(Map assignments) { - for (Entry assignment : assignments.entrySet()) { + for (Entry assignment : assignments.entrySet()) { put(assignment.getKey(), assignment.getValue()); } return this; diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java index aeaec22d5cbf..eebc5b815f5a 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/plan/Patterns.java @@ -182,6 +182,11 @@ public static Pattern except() return typeOf(ExceptNode.class); } + public static Pattern unnest() + { + return typeOf(UnnestNode.class); + } + public static Property source() { return optionalProperty( diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java new file mode 100644 index 000000000000..794d638bae77 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDereferencePushDown.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.sql.planner.assertions.BasePlanTest; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.output; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.unnest; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; + +public class TestDereferencePushDown + extends BasePlanTest +{ + @Test + public void testDereferencePushdownJoin() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a, t b " + + "WHERE a.msg.y = b.msg.y", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + values("msg"))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT a.msg.y " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x > bigint '5'", + output(ImmutableList.of("a_y"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg")))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x + b.msg.x < BIGINT '10'", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg"))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownFilter() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT a.msg.y, b.msg.x " + + "FROM t a CROSS JOIN t b " + + "WHERE a.msg.x = 7 OR IS_FINITE(b.msg.y)", + anyTree( + join(INNER, ImmutableList.of(), + project(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg")), + project(ImmutableMap.of("b_x", expression("msg.x"), "b_y", expression("msg.y")), + values("msg"))))); + } + + @Test + public void testDereferencePushdownWindow() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT msg.x AS x, ROW_NUMBER() OVER (PARTITION BY msg.y ORDER BY msg.y) AS rn " + + "FROM t ", + anyTree( + project(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg")))); + } + + @Test + public void testDereferencePushdownSemiJoin() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0, 3) AS ROW(x BIGINT, y DOUBLE, z BIGINT)))) " + + "SELECT msg.y " + + "FROM t " + + "WHERE " + + "msg.x IN (SELECT msg.z FROM t)", + anyTree( + semiJoin("a_x", "b_z", "semi_join_symbol", + anyTree( + project(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg"))), + anyTree( + project(ImmutableMap.of("b_z", expression("msg.z")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownLimit() + { + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a, t b " + + "WHERE a.msg.y = b.msg.y " + + "LIMIT 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + values("msg"))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT a.msg.y " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x > BIGINT '5' " + + "LIMIT 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg")))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))" + + "SELECT b.msg.x " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "WHERE a.msg.x + b.msg.x < BIGINT '10' " + + "LIMIT 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg"))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownUnnest() + { + assertPlan("WITH t(msg, array) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)), ARRAY[1, 2, 3]))) " + + "SELECT a.msg.x " + + "FROM t a JOIN t b ON a.msg.y = b.msg.y " + + "CROSS JOIN UNNEST (a.array) " + + "WHERE a.msg.x + b.msg.x < BIGINT '10'", + output(ImmutableList.of("expr"), + project(ImmutableMap.of("expr", expression("a_x")), + unnest( + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x"), "a_z", expression("array")), + values("msg", "array"))), + anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java index e8b2730f9927..10c765304143 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/assertions/ExpressionVerifier.java @@ -42,6 +42,7 @@ import io.prestosql.sql.tree.SearchedCaseExpression; import io.prestosql.sql.tree.SimpleCaseExpression; import io.prestosql.sql.tree.StringLiteral; +import io.prestosql.sql.tree.SubscriptExpression; import io.prestosql.sql.tree.SymbolReference; import io.prestosql.sql.tree.TryExpression; import io.prestosql.sql.tree.WhenClause; @@ -528,6 +529,17 @@ && process(actual.getPattern(), expected.getPattern()) && process(actual.getEscape(), expected.getEscape()); } + @Override + protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SubscriptExpression)) { + return false; + } + + SubscriptExpression expected = (SubscriptExpression) expectedExpression; + return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + } + private boolean process(List actuals, List expecteds) { if (actuals.size() != expecteds.size()) { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java index b185e81ea8ed..c84e9d36165d 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestInlineProjections.java @@ -13,13 +13,18 @@ */ package io.prestosql.sql.planner.iterative.rule; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.type.RowType; import io.prestosql.sql.planner.assertions.ExpressionMatcher; import io.prestosql.sql.planner.assertions.PlanMatchPattern; import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; import io.prestosql.sql.planner.plan.Assignments; import org.testng.annotations.Test; +import java.util.Optional; + +import static io.prestosql.spi.type.VarcharType.VARCHAR; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression; @@ -27,6 +32,8 @@ public class TestInlineProjections extends BaseRuleTest { + private static final RowType MSG_TYPE = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), VARCHAR), new RowType.Field(Optional.of("y"), VARCHAR))); + @Test public void test() { @@ -41,14 +48,16 @@ public void test() .put(p.symbol("multi_literal_2"), expression("literal + 2")) // literal referenced multiple times .put(p.symbol("single_complex"), expression("complex_2 + 2")) // complex expression reference only once .put(p.symbol("try"), expression("try(complex / literal)")) + .put(p.symbol("msg_xx"), expression("z + 1")) .build(), p.project(Assignments.builder() .put(p.symbol("symbol"), expression("x")) .put(p.symbol("complex"), expression("x * 2")) .put(p.symbol("literal"), expression("1")) .put(p.symbol("complex_2"), expression("x - 1")) + .put(p.symbol("z"), expression("msg.x")) .build(), - p.values(p.symbol("x"))))) + p.values(p.symbol("x"), p.symbol("msg", MSG_TYPE))))) .matches( project( ImmutableMap.builder() @@ -59,12 +68,14 @@ public void test() .put("out5", PlanMatchPattern.expression("1 + 2")) .put("out6", PlanMatchPattern.expression("x - 1 + 2")) .put("out7", PlanMatchPattern.expression("try(y / 1)")) + .put("out8", PlanMatchPattern.expression("z + 1")) .build(), project( ImmutableMap.of( "x", PlanMatchPattern.expression("x"), - "y", PlanMatchPattern.expression("x * 2")), - values(ImmutableMap.of("x", 0))))); + "y", PlanMatchPattern.expression("x * 2"), + "z", PlanMatchPattern.expression("msg.x")), + values(ImmutableMap.of("x", 0, "msg", 1))))); } @Test diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushDownDereferencesRules.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushDownDereferencesRules.java new file mode 100644 index 000000000000..a16d68cecc42 --- /dev/null +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestPushDownDereferencesRules.java @@ -0,0 +1,240 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.prestosql.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.prestosql.spi.type.ArrayType; +import io.prestosql.spi.type.RowType; +import io.prestosql.sql.planner.assertions.ExpressionMatcher; +import io.prestosql.sql.planner.assertions.PlanMatchPattern; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferences.ExtractFromFilter; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferences.ExtractFromJoin; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferences.PushDownDereferenceThrough; +import io.prestosql.sql.planner.iterative.rule.PushDownDereferences.PushDownDereferenceThroughJoin; +import io.prestosql.sql.planner.iterative.rule.test.BaseRuleTest; +import io.prestosql.sql.planner.plan.Assignments; +import io.prestosql.sql.planner.plan.LimitNode; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.prestosql.spi.type.BigintType.BIGINT; +import static io.prestosql.spi.type.VarcharType.VARCHAR; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.limit; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.project; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.unnest; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.values; +import static io.prestosql.sql.planner.iterative.rule.test.PlanBuilder.expression; +import static io.prestosql.sql.planner.plan.JoinNode.Type.INNER; + +public class TestPushDownDereferencesRules + extends BaseRuleTest +{ + private static final RowType MSG_TYPE = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("x"), VARCHAR), new RowType.Field(Optional.of("y"), VARCHAR))); + + @Test + public void testDoesNotFire() + { + tester().assertThat(new ExtractFromFilter(tester().getTypeAnalyzer())) + .on(p -> + p.filter(expression("x > BIGINT '5'"), + p.values(p.symbol("x")))) + .doesNotFire(); + + RowType nestedMsgType = RowType.from(ImmutableList.of(new RowType.Field(Optional.of("nested"), MSG_TYPE))); + tester().assertThat(new ExtractFromFilter(tester().getTypeAnalyzer())) + .on(p -> + p.filter(expression("msg.nested.x != 'foo' and CAST(msg.nested as JSON) is not null"), + p.values(p.symbol("msg", nestedMsgType)))) + .doesNotFire(); + } + + @Test + public void testExtractFromFilter() + { + tester().assertThat(new ExtractFromFilter(tester().getTypeAnalyzer())) + .on(p -> + p.filter(expression("msg.x <> 'foo'"), + p.values(p.symbol("msg", MSG_TYPE)))) + .matches( + project(ImmutableMap.of("msg", PlanMatchPattern.expression("msg")), + filter("msg_x <> 'foo'", + project(ImmutableMap.of("msg_x", PlanMatchPattern.expression("msg.x")), + values("msg"))))); + } + + @Test + public void testExtractFromJoin() + { + tester().assertThat(new ExtractFromJoin(tester().getTypeAnalyzer())) + .on(p -> + p.join(INNER, + p.values(p.symbol("msg1", MSG_TYPE)), + p.values(p.symbol("msg2", MSG_TYPE)), + p.expression("msg1.x + msg2.y > BIGINT '10'"))) + .matches( + join(INNER, ImmutableList.of(), Optional.of("msg1_x + msg2_y > BIGINT '10'"), + project( + ImmutableMap.of("msg1_x", PlanMatchPattern.expression("msg1.x")), + values("msg1")), + project( + ImmutableMap.of("msg2_y", PlanMatchPattern.expression("msg2.y")), + values("msg2")))); + } + + @Test + public void testPushDownDereferenceThrough() + { + tester().assertThat(new PushDownDereferenceThrough<>(LimitNode.class, tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("msg_x"), expression("msg.x")) + .put(p.symbol("msg_y"), expression("msg.y")) + .put(p.symbol("z"), expression("z")) + .build(), + p.limit(10, + p.values(p.symbol("msg", MSG_TYPE), p.symbol("z"))))) + .matches( + project( + ImmutableMap.builder() + .put("msg_x", PlanMatchPattern.expression("x")) + .put("msg_y", PlanMatchPattern.expression("y")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + limit(10, + project( + ImmutableMap.builder() + .put("x", PlanMatchPattern.expression("msg.x")) + .put("y", PlanMatchPattern.expression("msg.y")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + values("msg", "z"))))); + } + + @Test + public void testPushdownDereferenceThroughProject() + { + tester().assertThat(new PushDownDereferences.PushDownDereferenceThroughProject(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of(p.symbol("x"), expression("msg.x")), + p.project( + Assignments.of(p.symbol("y"), expression("y")), + p.values(p.symbol("msg", MSG_TYPE), p.symbol("y"))))) + .matches( + project( + ImmutableMap.of("x", PlanMatchPattern.expression("msg_x")), + project( + ImmutableMap.builder() + .put("msg_x", PlanMatchPattern.expression("msg.x")) + .put("y", PlanMatchPattern.expression("y")) + .build(), + values("msg", "y")))); + } + + @Test + public void testPushDownDereferenceThroughJoin() + { + tester().assertThat(new PushDownDereferenceThroughJoin(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("left_x"), expression("msg1.x")) + .put(p.symbol("right_y"), expression("msg2.y")) + .put(p.symbol("z"), expression("z")) + .build(), + p.join(INNER, + p.values(p.symbol("msg1", MSG_TYPE)), + p.values(p.symbol("msg2", MSG_TYPE), p.symbol("z"))))) + .matches( + project( + ImmutableMap.builder() + .put("left_x", PlanMatchPattern.expression("x")) + .put("right_y", PlanMatchPattern.expression("y")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + join(INNER, ImmutableList.of(), + project( + ImmutableMap.of("x", PlanMatchPattern.expression("msg1.x")), + values("msg1")), + project( + ImmutableMap.builder() + .put("y", PlanMatchPattern.expression("msg2.y")) + .put("z", PlanMatchPattern.expression("z")) + .build(), + values("msg2", "z"))))); + } + + @Test + public void testPushdownDereferecesThroughSemiJoin() + { + tester().assertThat(new PushDownDereferences.PushDownDereferenceThroughSemiJoin(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.builder() + .put(p.symbol("left_x"), expression("msg1.x")) + .put(p.symbol("right_y"), expression("msg2.y")) + .build(), + p.semiJoin(p.symbol("left"), + p.symbol("right"), + p.symbol("match"), + Optional.empty(), + Optional.empty(), + p.values(p.symbol("msg1", MSG_TYPE), p.symbol("left")), + p.values(p.symbol("msg2", MSG_TYPE), p.symbol("right"))))) + .matches( + project( + ImmutableMap.builder() + .put("left_x", PlanMatchPattern.expression("msg1_x")) + .put("right_y", PlanMatchPattern.expression("msg2_y")) + .build(), + semiJoin("left", + "right", + "match", + project( + ImmutableMap.of("msg1_x", PlanMatchPattern.expression("msg1.x")), + values("msg1", "left")), + project( + ImmutableMap.of("msg2_y", PlanMatchPattern.expression("msg2.y")), + values("msg2", "right"))))); + } + + @Test + public void testPushdownDereferencesThroughUnnest() + { + ArrayType arrayType = new ArrayType(BIGINT); + tester().assertThat(new PushDownDereferences.PushDownDereferenceThroughUnnest(tester().getTypeAnalyzer())) + .on(p -> + p.project( + Assignments.of(p.symbol("x"), expression("msg.x")), + p.unnest(ImmutableList.of(p.symbol("msg", MSG_TYPE)), + ImmutableMap.of(p.symbol("field"), ImmutableList.of(p.symbol("arr", arrayType))), + Optional.empty(), + INNER, + Optional.empty(), + p.values(p.symbol("msg", MSG_TYPE), p.symbol("arr", arrayType))))) + .matches( + project( + ImmutableMap.of("x", PlanMatchPattern.expression("msg_x")), + unnest( + project( + ImmutableMap.of("msg_x", PlanMatchPattern.expression("msg.x")), + values("msg", "arr"))))); + } +} diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java index e9d5739bc791..f6997277abe0 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/test/PlanBuilder.java @@ -77,6 +77,7 @@ import io.prestosql.sql.planner.plan.TableWriterNode.DeleteTarget; import io.prestosql.sql.planner.plan.TopNNode; import io.prestosql.sql.planner.plan.UnionNode; +import io.prestosql.sql.planner.plan.UnnestNode; import io.prestosql.sql.planner.plan.ValuesNode; import io.prestosql.sql.planner.plan.WindowNode; import io.prestosql.sql.tree.Expression; @@ -494,6 +495,25 @@ public ExchangeNode gatheringExchange(ExchangeNode.Scope scope, PlanNode child) .addInputsSet(child.getOutputSymbols())); } + public UnnestNode unnest( + List replicateSymbols, + Map> unnestSymbols, + Optional ordinalitySymbol, + JoinNode.Type joinType, + Optional filter, + PlanNode source) + + { + return new UnnestNode( + idAllocator.getNextId(), + source, + replicateSymbols, + unnestSymbols, + ordinalitySymbol, + joinType, + filter); + } + public SemiJoinNode semiJoin( Symbol sourceJoinSymbol, Symbol filteringSourceJoinSymbol,