diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index af6750435fd41..4b820df73bc20 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -69,6 +69,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PruneValuesColumns; import com.facebook.presto.sql.planner.iterative.rule.PruneWindowColumns; import com.facebook.presto.sql.planner.iterative.rule.PushAggregationThroughOuterJoin; +import com.facebook.presto.sql.planner.iterative.rule.PushDownDereferences; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughMarkDistinct; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughOuterJoin; import com.facebook.presto.sql.planner.iterative.rule.PushLimitThroughProject; @@ -341,6 +342,13 @@ public PlanOptimizers( new TransformUncorrelatedInPredicateSubqueryToSemiJoin(), new TransformCorrelatedScalarAggregationToJoin(metadata.getFunctionManager()), new TransformCorrelatedLateralJoinToJoin())), + new IterativeOptimizer( + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.>builder() + .addAll(new PushDownDereferences(metadata, sqlParser).rules()) + .build()), new IterativeOptimizer( ruleStats, statsCalculator, diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java new file mode 100644 index 0000000000000..4734d790e109b --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushDownDereferences.java @@ -0,0 +1,560 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.LimitNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TopNNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.Type; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.ExpressionExtractor; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.iterative.Rule.Context; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.RowNumberNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.planner.plan.SortNode; +import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.relational.OriginalExpressionUtils; +import com.facebook.presto.sql.tree.DefaultExpressionTraversalVisitor; +import com.facebook.presto.sql.tree.DereferenceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.facebook.presto.sql.tree.NodeRef; +import com.facebook.presto.sql.tree.SymbolReference; +import com.google.common.collect.BiMap; +import com.google.common.collect.HashBiMap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; + +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.Function; + +import static com.facebook.presto.matching.Capture.newCapture; +import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.VariablesExtractor.extractAll; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignmentsAsSymbolReferences; +import static com.facebook.presto.sql.planner.plan.Patterns.join; +import static com.facebook.presto.sql.planner.plan.Patterns.project; +import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin; +import static com.facebook.presto.sql.planner.plan.Patterns.source; +import static com.facebook.presto.sql.planner.plan.Patterns.unnest; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToExpression; +import static com.facebook.presto.sql.relational.OriginalExpressionUtils.castToRowExpression; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.util.Collections.emptyList; +import static java.util.Objects.requireNonNull; + +/** + * Push down dereferences as follows: + *

+ * Extract dereferences from PlanNode which has expressions + * and push them down to a new ProjectNode right below the PlanNode. + * After this step, All dereferences will be in ProjectNode. + *

+ * Pushdown dereferences in ProjectNode down through other types of PlanNode, + * e.g, Filter, Join etc. + */ +public class PushDownDereferences +{ + private final Metadata metadata; + private final SqlParser sqlParser; + + public PushDownDereferences(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + public Set> rules() + { + return ImmutableSet.of( + new ExtractFromFilter(), + new ExtractFromJoin(), + new PushDownDereferenceThrough<>(AssignUniqueId.class), + new PushDownDereferenceThrough<>(WindowNode.class), + new PushDownDereferenceThrough<>(TopNNode.class), + new PushDownDereferenceThrough<>(RowNumberNode.class), + new PushDownDereferenceThrough<>(SortNode.class), + new PushDownDereferenceThrough<>(FilterNode.class), + new PushDownDereferenceThrough<>(LimitNode.class), + new PushDownDereferenceThroughProject(), + new PushDownDereferenceThroughUnnest(), + new PushDownDereferenceThroughSemiJoin(), + new PushDownDereferenceThroughJoin()); + } + + /** + * Extract dereferences and push them down to new ProjectNode below + * Transforms: + *

+     *  TargetNode(expression(a.x))
+     *  
+ * to: + *
+     *   ProjectNode(original symbols)
+     *    TargetNode(expression(symbol))
+     *      Project(symbol := a.x)
+     * 
+ */ + abstract class ExtractProjectDereferences + implements Rule + { + private final Class aClass; + + ExtractProjectDereferences(Class aClass) + { + this.aClass = aClass; + } + + @Override + public Pattern getPattern() + { + return Pattern.typeOf(aClass); + } + + @Override + public Result apply(N node, Captures captures, Context context) + { + Map expressions = + getDereferenceSymbolMap(ExpressionExtractor.extractExpressionsNonRecursive(node), context, metadata, sqlParser); + + if (expressions.isEmpty()) { + return Result.empty(); + } + + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), rewrite(context, node, HashBiMap.create(expressions)), identityAssignmentsAsSymbolReferences(node.getOutputVariables()))); + } + + protected abstract N rewrite(Context context, N node, BiMap expressions); + } + + class ExtractFromFilter + extends ExtractProjectDereferences + { + ExtractFromFilter() + { + super(FilterNode.class); + } + + @Override + protected FilterNode rewrite(Context context, FilterNode node, BiMap expressions) + { + PlanNode source = node.getSource(); + + ImmutableMap + dereferencesMap = + expressions.inverse().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, + entry -> castToRowExpression(entry.getValue()))); + Assignments assignments = Assignments.builder().putAll(identityAssignmentsAsSymbolReferences(source.getOutputVariables())).putAll(dereferencesMap).build(); + ProjectNode projectNode = new ProjectNode(context.getIdAllocator().getNextId(), source, assignments); + + RowExpression filter = castToRowExpression(ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(expressions), castToExpression(node.getPredicate()))); + + return new FilterNode( + context.getIdAllocator().getNextId(), + projectNode, + filter); + } + } + + class ExtractFromJoin + extends ExtractProjectDereferences + { + ExtractFromJoin() + { + super(JoinNode.class); + } + + @Override + protected JoinNode rewrite(Context context, JoinNode joinNode, BiMap expressions) + { + Assignments.Builder leftSideDereferences = Assignments.builder(); + Assignments.Builder rightSideDereferences = Assignments.builder(); + + for (Map.Entry entry : expressions.inverse().entrySet()) { + VariableReferenceExpression baseSymbol = getBase(entry.getValue(), context.getVariableAllocator().getTypes()); + if (joinNode.getLeft().getOutputVariables().contains(baseSymbol)) { + leftSideDereferences.put(entry.getKey(), castToRowExpression(entry.getValue())); + } + else { + rightSideDereferences.put(entry.getKey(), castToRowExpression(entry.getValue())); + } + } + PlanNode leftNode = createProjectBelow(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator()); + PlanNode rightNode = createProjectBelow(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator()); + + return new JoinNode( + context.getIdAllocator().getNextId(), + joinNode.getType(), + leftNode, + rightNode, + joinNode.getCriteria(), + ImmutableList.builder().addAll(leftNode.getOutputVariables()).addAll(rightNode.getOutputVariables()).build(), + joinNode.getFilter().map(expression -> castToRowExpression(ExpressionTreeRewriter.rewriteWith(new PushDownDereferences.DereferenceReplacer(expressions), castToExpression(expression)))), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), + joinNode.getDistributionType()); + } + } + + /** + * Push down dereferences from ProjectNode to child nodes if possible + */ + private abstract class PushdownDereferencesInProject + implements Rule + { + private final Capture targetCapture = newCapture(); + private final Pattern targetPattern; + + protected PushdownDereferencesInProject(Pattern targetPattern) + { + this.targetPattern = requireNonNull(targetPattern, "targetPattern is null"); + } + + @Override + public Pattern getPattern() + { + return project().with(source().matching(targetPattern.capturedAs(targetCapture))); + } + + @Override + public Result apply(ProjectNode node, Captures captures, Context context) + { + N child = captures.get(targetCapture); + Map allDereferencesInProject = getDereferenceSymbolMap(node.getAssignments().getExpressions(), context, metadata, sqlParser); + + Set childSourceSymbols = child.getSources().stream().map(PlanNode::getOutputVariables).flatMap(Collection::stream).collect(toImmutableSet()); + + Map pushdownDereferences = allDereferencesInProject.entrySet().stream() + .filter(entry -> childSourceSymbols.contains(getBase(entry.getKey(), context.getVariableAllocator().getTypes()))) + .collect(toImmutableMap(Map.Entry::getKey, Map.Entry::getValue)); + + if (pushdownDereferences.isEmpty()) { + return Result.empty(); + } + + Result result = pushDownDereferences(context, child, HashBiMap.create(pushdownDereferences)); + if (result.isEmpty()) { + return Result.empty(); + } + + Assignments.Builder builder = Assignments.builder(); + for (Map.Entry entry : node.getAssignments().entrySet()) { + if (OriginalExpressionUtils.isExpression(entry.getValue())) { + builder.put(entry.getKey(), castToRowExpression(ExpressionTreeRewriter.rewriteWith(new DereferenceReplacer(pushdownDereferences), castToExpression(entry.getValue())))); + } + else { + builder.put(entry.getKey(), entry.getValue()); + } + } + return Result.ofPlanNode(new ProjectNode(context.getIdAllocator().getNextId(), result.getTransformedPlan().get(), builder.build())); + } + + protected abstract Result pushDownDereferences(Context context, N targetNode, BiMap expressions); + } + + /** + * Transforms: + *
+     *  Project(a_x := a.x)
+     *    TargetNode(a)
+     *  
+ * to: + *
+     *  Project(a_x := symbol)
+     *    TargetNode(symbol)
+     *      Project(symbol := a.x)
+     * 
+ */ + public class PushDownDereferenceThrough + extends PushdownDereferencesInProject + { + public PushDownDereferenceThrough(Class aClass) + { + super(Pattern.typeOf(aClass)); + } + + @Override + protected Result pushDownDereferences(Context context, N targetNode, BiMap expressions) + { + PlanNode source = getOnlyElement(targetNode.getSources()); + + ImmutableMap + dereferencesMap = + expressions.inverse().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, + entry -> castToRowExpression(entry.getValue()))); + ProjectNode projectNode = new ProjectNode( + context.getIdAllocator().getNextId(), + source, + Assignments.builder().putAll(identityAssignmentsAsSymbolReferences(source.getOutputVariables())).putAll(dereferencesMap).build()); + return Result.ofPlanNode(targetNode.replaceChildren(ImmutableList.of(projectNode))); + } + } + + /** + * Transforms: + *
+     *  Project(a_x := a.msg.x)
+     *    Join(a_y = b_y) => [a]
+     *      Project(a_y := a.msg.y)
+     *          Source(a)
+     *      Project(b_y := b.msg.y)
+     *          Source(b)
+     *  
+ * to: + *
+     *  Project(a_x := symbol)
+     *    Join(a_y = b_y) => [symbol]
+     *      Project(symbol := a.msg.x, a_y := a.msg.y)
+     *        Source(a)
+     *      Project(b_y := b.msg.y)
+     *        Source(b)
+     * 
+ */ + public class PushDownDereferenceThroughJoin + extends PushdownDereferencesInProject + { + PushDownDereferenceThroughJoin() + { + super(join()); + } + + @Override + protected Result pushDownDereferences(Context context, JoinNode joinNode, BiMap expressions) + { + Assignments.Builder leftSideDereferences = Assignments.builder(); + Assignments.Builder rightSideDereferences = Assignments.builder(); + + for (Map.Entry entry : expressions.inverse().entrySet()) { + VariableReferenceExpression baseSymbol = getBase(entry.getValue(), context.getVariableAllocator().getTypes()); + if (joinNode.getLeft().getOutputVariables().contains(baseSymbol)) { + leftSideDereferences.put(entry.getKey(), castToRowExpression(entry.getValue())); + } + else { + rightSideDereferences.put(entry.getKey(), castToRowExpression(entry.getValue())); + } + } + PlanNode leftNode = createProjectBelow(joinNode.getLeft(), leftSideDereferences.build(), context.getIdAllocator()); + PlanNode rightNode = createProjectBelow(joinNode.getRight(), rightSideDereferences.build(), context.getIdAllocator()); + + return Result.ofPlanNode(new JoinNode( + context.getIdAllocator().getNextId(), + joinNode.getType(), + leftNode, + rightNode, + joinNode.getCriteria(), + ImmutableList.builder().addAll(leftNode.getOutputVariables()).addAll(rightNode.getOutputVariables()).build(), + joinNode.getFilter(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), + joinNode.getDistributionType())); + } + } + + public class PushDownDereferenceThroughSemiJoin + extends PushdownDereferencesInProject + { + PushDownDereferenceThroughSemiJoin() + { + super(semiJoin()); + } + + @Override + protected Result pushDownDereferences(Context context, SemiJoinNode semiJoinNode, BiMap expressions) + { + Assignments.Builder filteringSourceDereferences = Assignments.builder(); + Assignments.Builder sourceDereferences = Assignments.builder(); + + for (Map.Entry entry : expressions.inverse().entrySet()) { + VariableReferenceExpression baseSymbol = getBase(entry.getValue(), context.getVariableAllocator().getTypes()); + if (semiJoinNode.getFilteringSource().getOutputVariables().contains(baseSymbol)) { + filteringSourceDereferences.put(entry.getKey(), castToRowExpression(entry.getValue())); + } + else { + sourceDereferences.put(entry.getKey(), castToRowExpression(entry.getValue())); + } + } + PlanNode filteringSource = createProjectBelow(semiJoinNode.getFilteringSource(), filteringSourceDereferences.build(), context.getIdAllocator()); + PlanNode source = createProjectBelow(semiJoinNode.getSource(), sourceDereferences.build(), context.getIdAllocator()); + return Result.ofPlanNode(semiJoinNode.replaceChildren(ImmutableList.of(source, filteringSource))); + } + } + + public class PushDownDereferenceThroughProject + extends PushdownDereferencesInProject + { + PushDownDereferenceThroughProject() + { + super(project()); + } + + @Override + protected Result pushDownDereferences(Context context, ProjectNode projectNode, BiMap expressions) + { + ImmutableMap + dereferencesMap = + expressions.inverse().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, + entry -> castToRowExpression(entry.getValue()))); + + return Result.ofPlanNode( + new ProjectNode(context.getIdAllocator().getNextId(), + projectNode.getSource(), + Assignments.builder().putAll(projectNode.getAssignments()).putAll(dereferencesMap).build())); + } + } + + public class PushDownDereferenceThroughUnnest + extends PushdownDereferencesInProject + { + PushDownDereferenceThroughUnnest() + { + super(unnest()); + } + + @Override + protected Result pushDownDereferences(Context context, UnnestNode unnestNode, BiMap expressions) + { + // Create new Project contains all pushdown symbols above original source + ImmutableMap + dereferencesMap = + expressions.inverse().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, + entry -> castToRowExpression(entry.getValue()))); + Assignments assignments = Assignments.builder().putAll(identityAssignmentsAsSymbolReferences(unnestNode.getSource().getOutputVariables())).putAll(dereferencesMap).build(); + ProjectNode source = new ProjectNode(context.getIdAllocator().getNextId(), unnestNode.getSource(), assignments); + + // Create new UnnestNode + UnnestNode newUnnest = new UnnestNode(context.getIdAllocator().getNextId(), + source, + ImmutableList.builder().addAll(unnestNode.getReplicateVariables()).addAll(expressions.values()).build(), + unnestNode.getUnnestVariables(), + unnestNode.getOrdinalityVariable()); + return Result.ofPlanNode(newUnnest); + } + } + + private static PlanNode createProjectBelow(PlanNode planNode, Assignments dereferences, PlanNodeIdAllocator idAllocator) + { + if (dereferences.isEmpty()) { + return planNode; + } + return new ProjectNode(idAllocator.getNextId(), planNode, Assignments.builder().putAll(identityAssignmentsAsSymbolReferences(planNode.getOutputVariables())).putAll(dereferences).build()); + } + + private static class DereferenceReplacer + extends ExpressionRewriter + { + private final Map expressions; + + DereferenceReplacer(Map expressions) + { + this.expressions = requireNonNull(expressions, "expressions is null"); + } + + @Override + public Expression rewriteDereferenceExpression(DereferenceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + if (expressions.containsKey(node)) { + return new SymbolReference(expressions.get(node).getName()); + } + return treeRewriter.defaultRewrite(node, context); + } + } + + private static List extractDereferenceExpressions(Expression expression) + { + ImmutableList.Builder builder = ImmutableList.builder(); + new DefaultExpressionTraversalVisitor>() + { + @Override + protected Void visitDereferenceExpression(DereferenceExpression node, ImmutableList.Builder context) + { + context.add(node); + return null; + } + }.process(expression, builder); + return builder.build(); + } + + private static Map getDereferenceSymbolMap(Collection expressions, Context context, Metadata metadata, SqlParser sqlParser) + { + Set dereferences = expressions.stream() + .filter(OriginalExpressionUtils::isExpression) + .map(OriginalExpressionUtils::castToExpression) + .flatMap(expression -> extractDereferenceExpressions(expression).stream()) + .filter(PushDownDereferences::validPushDown) + .collect(toImmutableSet()); + + // TODO DereferenceExpression Base will cause unnecessary rewritten of same expression. + // E.g. [msg.foo, msg.foo.bar] => [exp_1, exp_1.bar] => ... + if (dereferences.stream().anyMatch(exp -> baseExists(exp, dereferences))) { + return ImmutableMap.of(); + } + + return dereferences.stream() + .collect(toImmutableMap(Function.identity(), expression -> newSymbol(expression, context, metadata, sqlParser))); + } + + private static VariableReferenceExpression newSymbol(Expression expression, Context context, Metadata metadata, SqlParser sqlParser) + { + Type type = getExpressionTypes(context.getSession(), metadata, sqlParser, context.getVariableAllocator().getTypes(), expression, emptyList(), WarningCollector.NOOP).get(NodeRef.of(expression)); + verify(type != null); + return context.getVariableAllocator().newVariable(expression, type); + } + + private static boolean baseExists(DereferenceExpression expression, Set dereferences) + { + Expression base = expression.getBase(); + while (base instanceof DereferenceExpression) { + if (dereferences.contains(base)) { + return true; + } + base = ((DereferenceExpression) base).getBase(); + } + return false; + } + + private static boolean validPushDown(DereferenceExpression dereference) + { + Expression base = dereference.getBase(); + return (base instanceof SymbolReference) || (base instanceof DereferenceExpression); + } + + private static VariableReferenceExpression getBase(Expression expression, TypeProvider types) + { + return getOnlyElement(extractAll(expression, types)); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java new file mode 100644 index 0000000000000..0d0c2e35e584e --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestDereferencePushDown.java @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.Ordering; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.equiJoinClause; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.exchange; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.filter; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.output; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.sort; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.unnest; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Scope.LOCAL; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.GATHER; +import static com.facebook.presto.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.tree.SortItem.NullOrdering.LAST; +import static com.facebook.presto.sql.tree.SortItem.Ordering.ASCENDING; + +public class TestDereferencePushDown + extends BasePlanTest +{ + @Test + public void testDereferencePushdownJoin() + { + assertPlan("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT b.msg.x FROM t a, t b WHERE a.msg.y = b.msg.y", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + values("msg")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT a.msg.y FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x > bigint '5'", + output(ImmutableList.of("a_y"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg"))) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT b.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x + b.msg.x < bigint '10'", + output(ImmutableList.of("b_x"), + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownFilter() + { + assertPlan("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT a.msg.y, b.msg.x from t a cross join t b where a.msg.x = 7 or is_finite(b.msg.y)", + anyTree( + join(INNER, ImmutableList.of(), + project(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg")), + project(ImmutableMap.of("b_x", expression("msg.x"), "b_y", expression("msg.y")), + values("msg"))))); + } + + @Test + public void testDereferencePushdownWindow() + { + assertPlan("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT * from (select msg.x as x, ROW_NUMBER() over (partition by msg.y order by msg.y) as rn from t) where rn = 1", + anyTree( + project(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg")))); + } + + @Test + public void testDereferencePushdownSemiJoin() + { + assertPlan("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0, 3) AS ROW(x BIGINT, y DOUBLE, z BIGINT))))) " + + "SELECT msg.y FROM t WHERE msg.x IN (SELECT msg.z FROM t)", + anyTree( + semiJoin("a_x", "b_z", "SEMI_JOIN_RESULT", + anyTree( + project(ImmutableMap.of("a_x", expression("msg.x"), "a_y", expression("msg.y")), + values("msg"))), + anyTree( + project(ImmutableMap.of("b_z", expression("msg.z")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownAggregation() + { + Session session = Session.builder(getQueryRunner().getDefaultSession()).setSystemProperty("legacy_row_field_ordinal_access", "true").build(); + // java.lang.IllegalStateException: Node ROW ("field".x, "field".y) is not supported +// assertPlanWithSession("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0, 3) AS ROW(x BIGINT, y DOUBLE, z BIGINT))))) " + +// "SELECT r.field0 from (select max_by(ROW(msg.x, msg.y), msg.z) as r from t)", session, false, +// anyTree( +// project(ImmutableMap.of("field0", expression("xxx.field0")), +// aggregation(ImmutableMap.of("xxx", functionCall("MAX_BY", ImmutableList.of("a", "b"))), +// project(ImmutableMap.of("a", expression("ROW(msg.x, msg.y)"), "b", expression("msg.z")), +// values("msg")))))); + } + + @Test + public void testDereferencePushdownLimit() + { + assertPlan("WITH t(msg) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT b.msg.x FROM t a, t b WHERE a.msg.y = b.msg.y limit 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + values("msg")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT a.msg.y FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x > bigint '5' limit 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y")), + filter("msg.x > bigint '5'", + values("msg"))) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y")), + values("msg")))))); + + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT b.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x + b.msg.x < bigint '10' limit 100", + anyTree(join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + values("msg")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))); + } + + @Test + public void testDereferencePushdownSort() + { + ImmutableList orderBy = ImmutableList.of(sort("b_x", ASCENDING, LAST)); + assertPlan("WITH t(msg) AS ( SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE))))) " + + "SELECT a.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y WHERE a.msg.x < bigint '10' ORDER BY b.msg.x", + output(ImmutableList.of("expr"), + project(ImmutableMap.of("expr", expression("a_x")), + exchange(LOCAL, GATHER, orderBy, + sort(orderBy, + exchange(LOCAL, REPARTITION, + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x")), + filter("msg.x < bigint '10'", + values("msg"))) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))))))); + } + + @Test + public void testDereferencePushdownUnnest() + { + assertPlan("WITH t(msg, array) AS (SELECT * FROM (VALUES ROW(CAST(ROW(1, 2.0) AS ROW(x BIGINT, y DOUBLE)), ARRAY[1, 2, 3]))) " + + "SELECT a.msg.x FROM t a JOIN t b ON a.msg.y = b.msg.y CROSS JOIN UNNEST (a.array) WHERE a.msg.x + b.msg.x < bigint '10'", + output(ImmutableList.of("expr"), + project(ImmutableMap.of("expr", expression("a_x")), + unnest( + join(INNER, ImmutableList.of(equiJoinClause("a_y", "b_y")), + Optional.of("a_x + b_x < bigint '10'"), + anyTree( + project(ImmutableMap.of("a_y", expression("msg.y"), "a_x", expression("msg.x"), "a_z", expression("array")), + values("msg", "array")) + ), anyTree( + project(ImmutableMap.of("b_y", expression("msg.y"), "b_x", expression("msg.x")), + values("msg")))))))); + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java index 81aae0689f841..9d8525891c1ea 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/ExpressionVerifier.java @@ -21,6 +21,7 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DecimalLiteral; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -34,8 +35,10 @@ import com.facebook.presto.sql.tree.Node; import com.facebook.presto.sql.tree.NotExpression; import com.facebook.presto.sql.tree.NullLiteral; +import com.facebook.presto.sql.tree.SearchedCaseExpression; import com.facebook.presto.sql.tree.SimpleCaseExpression; import com.facebook.presto.sql.tree.StringLiteral; +import com.facebook.presto.sql.tree.SubscriptExpression; import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; @@ -411,6 +414,43 @@ protected Boolean visitInListExpression(InListExpression actual, Node expected) return process(actual.getValues(), expectedInList.getValues()); } + @Override + protected Boolean visitDereferenceExpression(DereferenceExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof DereferenceExpression)) { + return false; + } + + DereferenceExpression expected = (DereferenceExpression) expectedExpression; + if (actual.getField().equals(expected.getField())) { + return process(actual.getBase(), expected.getBase()); + } + return false; + } + + @Override + protected Boolean visitSubscriptExpression(SubscriptExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SubscriptExpression)) { + return false; + } + + SubscriptExpression expected = (SubscriptExpression) expectedExpression; + + return process(actual.getBase(), expected.getBase()) && process(actual.getIndex(), expected.getIndex()); + } + + @Override + protected Boolean visitSearchedCaseExpression(SearchedCaseExpression actual, Node expectedExpression) + { + if (!(expectedExpression instanceof SearchedCaseExpression)) { + return false; + } + + SearchedCaseExpression expected = (SearchedCaseExpression) expectedExpression; + return process(actual.getDefaultValue(), expected.getDefaultValue()) && process(actual.getWhenClauses(), expected.getWhenClauses()); + } + private boolean process(List actuals, List expecteds) { if (actuals.size() != expecteds.size()) { diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index c174fe18a7994..2f1429f1eafbc 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.RowType; import com.facebook.presto.sql.planner.LiteralInterpreter; import com.facebook.presto.sql.relational.FunctionResolution; import com.facebook.presto.sql.tree.ArithmeticBinaryExpression; @@ -32,6 +33,7 @@ import com.facebook.presto.sql.tree.CoalesceExpression; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.DecimalLiteral; +import com.facebook.presto.sql.tree.DereferenceExpression; import com.facebook.presto.sql.tree.DoubleLiteral; import com.facebook.presto.sql.tree.FunctionCall; import com.facebook.presto.sql.tree.GenericLiteral; @@ -50,6 +52,7 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.facebook.presto.sql.tree.TryExpression; import com.facebook.presto.sql.tree.WhenClause; +import com.google.common.base.Preconditions; import io.airlift.slice.Slice; import java.util.List; @@ -69,6 +72,7 @@ import static com.facebook.presto.spi.function.OperatorType.NOT_EQUAL; import static com.facebook.presto.spi.function.OperatorType.SUBTRACT; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.COALESCE; +import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.DEREFERENCE; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IN; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL; import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.SWITCH; @@ -322,6 +326,28 @@ protected Boolean visitBooleanLiteral(BooleanLiteral expected, RowExpression act return compareLiteral(expected, actual); } + @Override + protected Boolean visitDereferenceExpression(DereferenceExpression expected, RowExpression actual) + { + if (!(actual instanceof SpecialFormExpression) || !(((SpecialFormExpression) actual).getForm().equals(DEREFERENCE))) { + return false; + } + SpecialFormExpression actualDereference = (SpecialFormExpression) actual; + if (actualDereference.getArguments().size() == 2 && + actualDereference.getArguments().get(0).getType() instanceof RowType && + actualDereference.getArguments().get(1) instanceof ConstantExpression) { + RowType rowType = (RowType) actualDereference.getArguments().get(0).getType(); + Object value = LiteralInterpreter.evaluate(TEST_SESSION.toConnectorSession(), (ConstantExpression) actualDereference.getArguments().get(1)); + Preconditions.checkState(value instanceof Long); + long index = (Long) value; + Preconditions.checkState(index >= 0 && index < rowType.getFields().size()); + RowType.Field field = rowType.getFields().get((int) index); + Preconditions.checkState(field.getName().isPresent()); + return expected.getField().getValue().equals(field.getName().get()) && process(expected.getBase(), actualDereference.getArguments().get(0)); + } + return false; + } + private static String getValueFromLiteral(Node expression) { if (expression instanceof LongLiteral) {