diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index bec2860357021..fc05f98200c99 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -119,6 +119,7 @@ import com.facebook.presto.sql.planner.optimizations.PruneUnreferencedOutputs; import com.facebook.presto.sql.planner.optimizations.PushdownSubfields; import com.facebook.presto.sql.planner.optimizations.ReplicateSemiJoinInDelete; +import com.facebook.presto.sql.planner.optimizations.RowExpressionPredicatePushDown; import com.facebook.presto.sql.planner.optimizations.SetFlatteningOptimizer; import com.facebook.presto.sql.planner.optimizations.StatsRecordingPlanOptimizer; import com.facebook.presto.sql.planner.optimizations.TransformQuantifiedComparisonApplyToLateralJoin; @@ -256,6 +257,7 @@ public PlanOptimizers( new SimplifyRowExpressions(metadata).rules()); PlanOptimizer predicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new PredicatePushDown(metadata, sqlParser)); + PlanOptimizer rowExpressionPredicatePushDown = new StatsRecordingPlanOptimizer(optimizerStats, new RowExpressionPredicatePushDown(metadata, sqlParser)); builder.add( // Clean up all the sugar in expressions, e.g. AtTimeZone, must be run before all the other optimizers @@ -472,17 +474,6 @@ public PlanOptimizers( ImmutableSet.of(new PushTableWriteThroughUnion()))); // Must run before AddExchanges builder.add(new StatsRecordingPlanOptimizer(optimizerStats, new AddExchanges(metadata, sqlParser))); } - //noinspection UnusedAssignment - estimatedExchangesCostCalculator = null; // Prevent accidental use after AddExchanges - - builder.add( - new IterativeOptimizer( - ruleStats, - statsCalculator, - costCalculator, - ImmutableSet.of(new RemoveEmptyDelete()))); // Run RemoveEmptyDelete after table scan is removed by PickTableLayout/AddExchanges - - builder.add(predicatePushDown); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate // TODO: move this before optimization if possible!! // Replace all expressions with row expressions @@ -493,6 +484,17 @@ public PlanOptimizers( new TranslateExpressions(metadata, sqlParser).rules())); // After this point, all planNodes should not contain OriginalExpression + //noinspection UnusedAssignment + estimatedExchangesCostCalculator = null; // Prevent accidental use after AddExchanges + + builder.add( + new IterativeOptimizer( + ruleStats, + statsCalculator, + costCalculator, + ImmutableSet.of(new RemoveEmptyDelete()))); // Run RemoveEmptyDelete after table scan is removed by PickTableLayout/AddExchanges + + builder.add(rowExpressionPredicatePushDown); // Run predicate push down one more time in case we can leverage new information from layouts' effective predicate builder.add(simplifyRowExpressionOptimizer); // Should be always run after PredicatePushDown builder.add(projectionPushDown); builder.add(inlineProjections); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RowExpressionPredicatePushDown.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RowExpressionPredicatePushDown.java new file mode 100644 index 0000000000000..4f128c5dff5e1 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/RowExpressionPredicatePushDown.java @@ -0,0 +1,1327 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.execution.warnings.WarningCollector; +import com.facebook.presto.expressions.LogicalRowExpressions; +import com.facebook.presto.expressions.RowExpressionNodeInliner; +import com.facebook.presto.metadata.FunctionManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.function.OperatorType; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.PlanNodeIdAllocator; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.spi.type.TypeManager; +import com.facebook.presto.sql.parser.SqlParser; +import com.facebook.presto.sql.planner.PlanVariableAllocator; +import com.facebook.presto.sql.planner.RowExpressionEqualityInference; +import com.facebook.presto.sql.planner.RowExpressionPredicateExtractor; +import com.facebook.presto.sql.planner.RowExpressionVariableInliner; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.VariablesExtractor; +import com.facebook.presto.sql.planner.plan.AssignUniqueId; +import com.facebook.presto.sql.planner.plan.ExchangeNode; +import com.facebook.presto.sql.planner.plan.GroupIdNode; +import com.facebook.presto.sql.planner.plan.JoinNode; +import com.facebook.presto.sql.planner.plan.MarkDistinctNode; +import com.facebook.presto.sql.planner.plan.SampleNode; +import com.facebook.presto.sql.planner.plan.SemiJoinNode; +import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; +import com.facebook.presto.sql.planner.plan.SortNode; +import com.facebook.presto.sql.planner.plan.SpatialJoinNode; +import com.facebook.presto.sql.planner.plan.UnionNode; +import com.facebook.presto.sql.planner.plan.UnnestNode; +import com.facebook.presto.sql.planner.plan.WindowNode; +import com.facebook.presto.sql.relational.Expressions; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.facebook.presto.sql.relational.RowExpressionDomainTranslator; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.EnumSet; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static com.facebook.presto.expressions.LogicalRowExpressions.FALSE_CONSTANT; +import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT; +import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts; +import static com.facebook.presto.spi.function.OperatorType.EQUAL; +import static com.facebook.presto.spi.type.BigintType.BIGINT; +import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; +import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes; +import static com.facebook.presto.sql.planner.plan.AssignmentUtils.identityAssignments; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.PARTITIONED; +import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.FULL; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; +import static com.facebook.presto.sql.planner.plan.JoinNode.Type.RIGHT; +import static com.facebook.presto.sql.relational.Expressions.call; +import static com.facebook.presto.sql.relational.Expressions.constant; +import static com.facebook.presto.sql.relational.Expressions.constantNull; +import static com.facebook.presto.sql.relational.Expressions.uniqueSubExpressions; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Predicates.in; +import static com.google.common.base.Predicates.not; +import static com.google.common.base.Verify.verify; +import static com.google.common.collect.Iterables.filter; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class RowExpressionPredicatePushDown + implements PlanOptimizer +{ + private final Metadata metadata; + private final RowExpressionPredicateExtractor effectivePredicateExtractor; + private final SqlParser sqlParser; + + public RowExpressionPredicatePushDown(Metadata metadata, SqlParser sqlParser) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.effectivePredicateExtractor = new RowExpressionPredicateExtractor(new RowExpressionDomainTranslator(metadata), metadata.getFunctionManager(), metadata.getTypeManager()); + this.sqlParser = requireNonNull(sqlParser, "sqlParser is null"); + } + + @Override + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, PlanVariableAllocator variableAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + { + requireNonNull(plan, "plan is null"); + requireNonNull(session, "session is null"); + requireNonNull(types, "types is null"); + requireNonNull(idAllocator, "idAllocator is null"); + + return SimplePlanRewriter.rewriteWith( + new Rewriter(variableAllocator, idAllocator, metadata, effectivePredicateExtractor, sqlParser, session), + plan, + TRUE_CONSTANT); + } + + private static class Rewriter + extends SimplePlanRewriter + { + private final PlanVariableAllocator variableAllocator; + private final PlanNodeIdAllocator idAllocator; + private final Metadata metadata; + private final RowExpressionPredicateExtractor effectivePredicateExtractor; + private final Session session; + private final ExpressionEquivalence expressionEquivalence; + private final RowExpressionDeterminismEvaluator determinismEvaluator; + private final LogicalRowExpressions logicalRowExpressions; + private final TypeManager typeManager; + private final FunctionManager functionManager; + + private Rewriter( + PlanVariableAllocator variableAllocator, + PlanNodeIdAllocator idAllocator, + Metadata metadata, + RowExpressionPredicateExtractor effectivePredicateExtractor, + SqlParser sqlParser, + Session session) + { + this.variableAllocator = requireNonNull(variableAllocator, "variableAllocator is null"); + this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.effectivePredicateExtractor = requireNonNull(effectivePredicateExtractor, "effectivePredicateExtractor is null"); + this.session = requireNonNull(session, "session is null"); + this.expressionEquivalence = new ExpressionEquivalence(metadata, sqlParser); + this.determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata); + this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, new FunctionResolution(metadata.getFunctionManager()), metadata.getFunctionManager()); + this.typeManager = metadata.getTypeManager(); + this.functionManager = metadata.getFunctionManager(); + } + + @Override + public PlanNode visitPlan(PlanNode node, RewriteContext context) + { + PlanNode rewrittenNode = context.defaultRewrite(node, TRUE_CONSTANT); + if (!context.get().equals(TRUE_CONSTANT)) { + // Drop in a FilterNode b/c we cannot push our predicate down any further + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, context.get()); + } + return rewrittenNode; + } + + @Override + public PlanNode visitExchange(ExchangeNode node, RewriteContext context) + { + boolean modified = false; + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < node.getSources().size(); i++) { + Map outputsToInputs = new HashMap<>(); + for (int index = 0; index < node.getInputs().get(i).size(); index++) { + outputsToInputs.put( + node.getOutputVariables().get(index), + node.getInputs().get(i).get(index)); + } + + RowExpression sourcePredicate = RowExpressionVariableInliner.inlineVariables(outputsToInputs, context.get()); + PlanNode source = node.getSources().get(i); + PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); + if (rewrittenSource != source) { + modified = true; + } + builder.add(rewrittenSource); + } + + if (modified) { + return new ExchangeNode( + node.getId(), + node.getType(), + node.getScope(), + node.getPartitioningScheme(), + builder.build(), + node.getInputs(), + node.getOrderingScheme()); + } + + return node; + } + + @Override + public PlanNode visitWindow(WindowNode node, RewriteContext context) + { + // TODO: This could be broader. We can push down conjucts if they are constant for all rows in a window partition. + // The simplest way to guarantee this is if the conjucts are deterministic functions of the partitioning variables. + // This can leave out cases where they're both functions of some set of common expressions and the partitioning + // function is injective, but that's a rare case. The majority of window nodes are expected to be partitioned by + // pre-projected variables. + Predicate isSupported = conjunct -> + determinismEvaluator.isDeterministic(conjunct) && + VariablesExtractor.extractUnique(conjunct).stream().allMatch(node.getPartitionBy()::contains); + + Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(isSupported)); + + PlanNode rewrittenNode = context.defaultRewrite(node, logicalRowExpressions.combineConjuncts(conjuncts.get(true))); + + if (!conjuncts.get(false).isEmpty()) { + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(conjuncts.get(false))); + } + + return rewrittenNode; + } + + @Override + public PlanNode visitProject(ProjectNode node, RewriteContext context) + { + Set deterministicVariables = node.getAssignments().entrySet().stream() + .filter(entry -> determinismEvaluator.isDeterministic(entry.getValue())) + .map(Map.Entry::getKey) + .collect(Collectors.toSet()); + + Predicate deterministic = conjunct -> deterministicVariables.containsAll(VariablesExtractor.extractUnique(conjunct)); + + Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(deterministic)); + + // Push down conjuncts from the inherited predicate that only depend on deterministic assignments with + // certain limitations. + List deterministicConjuncts = conjuncts.get(true); + + // We partition the expressions in the deterministicConjuncts into two lists, and only inline the + // expressions that are in the inlining targets list. + Map> inlineConjuncts = deterministicConjuncts.stream() + .collect(Collectors.partitioningBy(expression -> isInliningCandidate(expression, node))); + + List inlinedDeterministicConjuncts = inlineConjuncts.get(true).stream() + .map(entry -> RowExpressionVariableInliner.inlineVariables(node.getAssignments().getMap(), entry)) + .collect(Collectors.toList()); + + PlanNode rewrittenNode = context.defaultRewrite(node, logicalRowExpressions.combineConjuncts(inlinedDeterministicConjuncts)); + + // All deterministic conjuncts that contains non-inlining targets, and non-deterministic conjuncts, + // if any, will be in the filter node. + List nonInliningConjuncts = inlineConjuncts.get(false); + nonInliningConjuncts.addAll(conjuncts.get(false)); + + if (!nonInliningConjuncts.isEmpty()) { + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(nonInliningConjuncts)); + } + + return rewrittenNode; + } + + private boolean isInliningCandidate(RowExpression expression, ProjectNode node) + { + // TryExpressions should not be pushed down. However they are now being handled as lambda + // passed to a FunctionCall now and should not affect predicate push down. So we want to make + // sure the conjuncts are not TryExpressions. + FunctionResolution functionResolution = new FunctionResolution(functionManager); + verify(uniqueSubExpressions(expression) + .stream() + .noneMatch(subExpression -> subExpression instanceof CallExpression && + functionResolution.isTryFunction(((CallExpression) subExpression).getFunctionHandle()))); + + // candidate symbols for inlining are + // 1. references to simple constants + // 2. references to complex expressions that appear only once + // which come from the node, as opposed to an enclosing scope. + Set childOutputSet = ImmutableSet.copyOf(node.getOutputVariables()); + Map dependencies = VariablesExtractor.extractAll(expression).stream() + .filter(childOutputSet::contains) + .collect(Collectors.groupingBy(identity(), Collectors.counting())); + + return dependencies.entrySet().stream() + .allMatch(entry -> entry.getValue() == 1 || node.getAssignments().get(entry.getKey()) instanceof ConstantExpression); + } + + @Override + public PlanNode visitGroupId(GroupIdNode node, RewriteContext context) + { + Map commonGroupingVariableMapping = node.getGroupingColumns().entrySet().stream() + .filter(entry -> node.getCommonGroupingColumns().contains(entry.getKey())) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + Predicate pushdownEligiblePredicate = conjunct -> VariablesExtractor.extractUnique(conjunct).stream() + .allMatch(commonGroupingVariableMapping.keySet()::contains); + + Map> conjuncts = extractConjuncts(context.get()).stream().collect(Collectors.partitioningBy(pushdownEligiblePredicate)); + + // Push down conjuncts from the inherited predicate that apply to common grouping symbols + PlanNode rewrittenNode = context.defaultRewrite(node, RowExpressionVariableInliner.inlineVariables(commonGroupingVariableMapping, logicalRowExpressions.combineConjuncts(conjuncts.get(true)))); + + // All other conjuncts, if any, will be in the filter node. + if (!conjuncts.get(false).isEmpty()) { + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(conjuncts.get(false))); + } + + return rewrittenNode; + } + + @Override + public PlanNode visitMarkDistinct(MarkDistinctNode node, RewriteContext context) + { + Set pushDownableVariables = ImmutableSet.copyOf(node.getDistinctVariables()); + Map> conjuncts = extractConjuncts(context.get()).stream() + .collect(Collectors.partitioningBy(conjunct -> pushDownableVariables.containsAll(VariablesExtractor.extractUnique(conjunct)))); + + PlanNode rewrittenNode = context.defaultRewrite(node, logicalRowExpressions.combineConjuncts(conjuncts.get(true))); + + if (!conjuncts.get(false).isEmpty()) { + rewrittenNode = new FilterNode(idAllocator.getNextId(), rewrittenNode, logicalRowExpressions.combineConjuncts(conjuncts.get(false))); + } + return rewrittenNode; + } + + @Override + public PlanNode visitSort(SortNode node, RewriteContext context) + { + return context.defaultRewrite(node, context.get()); + } + + @Override + public PlanNode visitUnion(UnionNode node, RewriteContext context) + { + boolean modified = false; + ImmutableList.Builder builder = ImmutableList.builder(); + for (int i = 0; i < node.getSources().size(); i++) { + RowExpression sourcePredicate = RowExpressionVariableInliner.inlineVariables(node.sourceVariableMap(i), context.get()); + PlanNode source = node.getSources().get(i); + PlanNode rewrittenSource = context.rewrite(source, sourcePredicate); + if (rewrittenSource != source) { + modified = true; + } + builder.add(rewrittenSource); + } + + if (modified) { + return new UnionNode(node.getId(), builder.build(), node.getVariableMapping()); + } + + return node; + } + + @Deprecated + @Override + public PlanNode visitFilter(FilterNode node, RewriteContext context) + { + PlanNode rewrittenPlan = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(node.getPredicate(), context.get())); + if (!(rewrittenPlan instanceof FilterNode)) { + return rewrittenPlan; + } + + FilterNode rewrittenFilterNode = (FilterNode) rewrittenPlan; + if (!areExpressionsEquivalent(rewrittenFilterNode.getPredicate(), node.getPredicate()) + || node.getSource() != rewrittenFilterNode.getSource()) { + return rewrittenPlan; + } + + return node; + } + + @Override + public PlanNode visitJoin(JoinNode node, RewriteContext context) + { + RowExpression inheritedPredicate = context.get(); + + // See if we can rewrite outer joins in terms of a plain inner join + node = tryNormalizeToOuterToInnerJoin(node, inheritedPredicate); + + RowExpression leftEffectivePredicate = effectivePredicateExtractor.extract(node.getLeft()); + RowExpression rightEffectivePredicate = effectivePredicateExtractor.extract(node.getRight()); + RowExpression joinPredicate = extractJoinPredicate(node); + + RowExpression leftPredicate; + RowExpression rightPredicate; + RowExpression postJoinPredicate; + RowExpression newJoinPredicate; + + switch (node.getType()) { + case INNER: + InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin(inheritedPredicate, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + node.getLeft().getOutputVariables()); + leftPredicate = innerJoinPushDownResult.getLeftPredicate(); + rightPredicate = innerJoinPushDownResult.getRightPredicate(); + postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate(); + newJoinPredicate = innerJoinPushDownResult.getJoinPredicate(); + break; + case LEFT: + OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + node.getLeft().getOutputVariables()); + leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate(); + rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate(); + postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate(); + newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate(); + break; + case RIGHT: + OuterJoinPushDownResult rightOuterJoinPushDownResult = processLimitedOuterJoin(inheritedPredicate, + rightEffectivePredicate, + leftEffectivePredicate, + joinPredicate, + node.getRight().getOutputVariables()); + leftPredicate = rightOuterJoinPushDownResult.getInnerJoinPredicate(); + rightPredicate = rightOuterJoinPushDownResult.getOuterJoinPredicate(); + postJoinPredicate = rightOuterJoinPushDownResult.getPostJoinPredicate(); + newJoinPredicate = rightOuterJoinPushDownResult.getJoinPredicate(); + break; + case FULL: + leftPredicate = TRUE_CONSTANT; + rightPredicate = TRUE_CONSTANT; + postJoinPredicate = inheritedPredicate; + newJoinPredicate = joinPredicate; + break; + default: + throw new UnsupportedOperationException("Unsupported join type: " + node.getType()); + } + + newJoinPredicate = simplifyExpression(newJoinPredicate); + // TODO: find a better way to directly optimize FALSE LITERAL in join predicate + if (newJoinPredicate.equals(FALSE_CONSTANT)) { + newJoinPredicate = buildEqualsExpression(functionManager, constant(0L, BIGINT), constant(1L, BIGINT)); + } + + PlanNode leftSource = context.rewrite(node.getLeft(), leftPredicate); + PlanNode rightSource = context.rewrite(node.getRight(), rightPredicate); + + PlanNode output = node; + + // Create identity projections for all existing symbols + Assignments.Builder leftProjections = Assignments.builder() + .putAll(identityAssignments(node.getLeft().getOutputVariables())); + + Assignments.Builder rightProjections = Assignments.builder() + .putAll(identityAssignments(node.getRight().getOutputVariables())); + + // Create new projections for the new join clauses + List equiJoinClauses = new ArrayList<>(); + ImmutableList.Builder joinFilterBuilder = ImmutableList.builder(); + for (RowExpression conjunct : extractConjuncts(newJoinPredicate)) { + if (joinEqualityExpression(node.getLeft().getOutputVariables()).test(conjunct)) { + boolean alignedComparison = Iterables.all(VariablesExtractor.extractUnique(getLeft(conjunct)), in(node.getLeft().getOutputVariables())); + RowExpression leftExpression = (alignedComparison) ? getLeft(conjunct) : getRight(conjunct); + RowExpression rightExpression = (alignedComparison) ? getRight(conjunct) : getLeft(conjunct); + + VariableReferenceExpression leftVariable = variableForExpression(leftExpression); + if (!node.getLeft().getOutputVariables().contains(leftVariable)) { + leftProjections.put(leftVariable, leftExpression); + } + + VariableReferenceExpression rightVariable = variableForExpression(rightExpression); + if (!node.getRight().getOutputVariables().contains(rightVariable)) { + rightProjections.put(rightVariable, rightExpression); + } + + equiJoinClauses.add(new JoinNode.EquiJoinClause(leftVariable, rightVariable)); + } + else { + joinFilterBuilder.add(conjunct); + } + } + + Optional newJoinFilter = Optional.of(logicalRowExpressions.combineConjuncts(joinFilterBuilder.build())); + if (newJoinFilter.get() == TRUE_CONSTANT) { + newJoinFilter = Optional.empty(); + } + + if (node.getType() == INNER && newJoinFilter.isPresent() && equiJoinClauses.isEmpty()) { + // if we do not have any equi conjunct we do not pushdown non-equality condition into + // inner join, so we plan execution as nested-loops-join followed by filter instead + // hash join. + // todo: remove the code when we have support for filter function in nested loop join + postJoinPredicate = logicalRowExpressions.combineConjuncts(postJoinPredicate, newJoinFilter.get()); + newJoinFilter = Optional.empty(); + } + + boolean filtersEquivalent = + newJoinFilter.isPresent() == node.getFilter().isPresent() && + (!newJoinFilter.isPresent() || areExpressionsEquivalent(newJoinFilter.get(), node.getFilter().get())); + + if (leftSource != node.getLeft() || + rightSource != node.getRight() || + !filtersEquivalent || + !ImmutableSet.copyOf(equiJoinClauses).equals(ImmutableSet.copyOf(node.getCriteria()))) { + leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); + rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); + + // if the distribution type is already set, make sure that changes from PredicatePushDown + // don't make the join node invalid. + Optional distributionType = node.getDistributionType(); + if (node.getDistributionType().isPresent()) { + if (node.getType().mustPartition()) { + distributionType = Optional.of(PARTITIONED); + } + if (node.getType().mustReplicate(equiJoinClauses)) { + distributionType = Optional.of(REPLICATED); + } + } + + output = new JoinNode( + node.getId(), + node.getType(), + leftSource, + rightSource, + equiJoinClauses, + ImmutableList.builder() + .addAll(leftSource.getOutputVariables()) + .addAll(rightSource.getOutputVariables()) + .build(), + newJoinFilter, + node.getLeftHashVariable(), + node.getRightHashVariable(), + distributionType); + } + + if (!postJoinPredicate.equals(TRUE_CONSTANT)) { + output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate); + } + + if (!node.getOutputVariables().equals(output.getOutputVariables())) { + output = new ProjectNode(idAllocator.getNextId(), output, identityAssignments(node.getOutputVariables())); + } + + return output; + } + + private static RowExpression getLeft(RowExpression expression) + { + checkArgument(expression instanceof CallExpression && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression"); + return ((CallExpression) expression).getArguments().get(0); + } + + private static RowExpression getRight(RowExpression expression) + { + checkArgument(expression instanceof CallExpression && ((CallExpression) expression).getArguments().size() == 2, "must be binary call expression"); + return ((CallExpression) expression).getArguments().get(1); + } + + @Override + public PlanNode visitSpatialJoin(SpatialJoinNode node, RewriteContext context) + { + RowExpression inheritedPredicate = context.get(); + + // See if we can rewrite left join in terms of a plain inner join + if (node.getType() == SpatialJoinNode.Type.LEFT && canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate)) { + node = new SpatialJoinNode( + node.getId(), + SpatialJoinNode.Type.INNER, + node.getLeft(), + node.getRight(), + node.getOutputVariables(), + node.getFilter(), + node.getLeftPartitionVariable(), + node.getRightPartitionVariable(), + node.getKdbTree()); + } + + RowExpression leftEffectivePredicate = effectivePredicateExtractor.extract(node.getLeft()); + RowExpression rightEffectivePredicate = effectivePredicateExtractor.extract(node.getRight()); + RowExpression joinPredicate = node.getFilter(); + + RowExpression leftPredicate; + RowExpression rightPredicate; + RowExpression postJoinPredicate; + RowExpression newJoinPredicate; + + switch (node.getType()) { + case INNER: + InnerJoinPushDownResult innerJoinPushDownResult = processInnerJoin( + inheritedPredicate, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + node.getLeft().getOutputVariables()); + leftPredicate = innerJoinPushDownResult.getLeftPredicate(); + rightPredicate = innerJoinPushDownResult.getRightPredicate(); + postJoinPredicate = innerJoinPushDownResult.getPostJoinPredicate(); + newJoinPredicate = innerJoinPushDownResult.getJoinPredicate(); + break; + case LEFT: + OuterJoinPushDownResult leftOuterJoinPushDownResult = processLimitedOuterJoin( + inheritedPredicate, + leftEffectivePredicate, + rightEffectivePredicate, + joinPredicate, + node.getLeft().getOutputVariables()); + leftPredicate = leftOuterJoinPushDownResult.getOuterJoinPredicate(); + rightPredicate = leftOuterJoinPushDownResult.getInnerJoinPredicate(); + postJoinPredicate = leftOuterJoinPushDownResult.getPostJoinPredicate(); + newJoinPredicate = leftOuterJoinPushDownResult.getJoinPredicate(); + break; + default: + throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType()); + } + + newJoinPredicate = simplifyExpression(newJoinPredicate); + verify(!newJoinPredicate.equals(FALSE_CONSTANT), "Spatial join predicate is missing"); + + PlanNode leftSource = context.rewrite(node.getLeft(), leftPredicate); + PlanNode rightSource = context.rewrite(node.getRight(), rightPredicate); + + PlanNode output = node; + if (leftSource != node.getLeft() || + rightSource != node.getRight() || + !areExpressionsEquivalent(newJoinPredicate, joinPredicate)) { + // Create identity projections for all existing symbols + Assignments.Builder leftProjections = Assignments.builder() + .putAll(identityAssignments(node.getLeft().getOutputVariables())); + + Assignments.Builder rightProjections = Assignments.builder() + .putAll(identityAssignments(node.getRight().getOutputVariables())); + + leftSource = new ProjectNode(idAllocator.getNextId(), leftSource, leftProjections.build()); + rightSource = new ProjectNode(idAllocator.getNextId(), rightSource, rightProjections.build()); + + output = new SpatialJoinNode( + node.getId(), + node.getType(), + leftSource, + rightSource, + node.getOutputVariables(), + newJoinPredicate, + node.getLeftPartitionVariable(), + node.getRightPartitionVariable(), + node.getKdbTree()); + } + + if (!postJoinPredicate.equals(TRUE_CONSTANT)) { + output = new FilterNode(idAllocator.getNextId(), output, postJoinPredicate); + } + + return output; + } + + private VariableReferenceExpression variableForExpression(RowExpression expression) + { + if (expression instanceof VariableReferenceExpression) { + return (VariableReferenceExpression) expression; + } + + return variableAllocator.newVariable(expression); + } + + private OuterJoinPushDownResult processLimitedOuterJoin(RowExpression inheritedPredicate, RowExpression outerEffectivePredicate, RowExpression innerEffectivePredicate, RowExpression joinPredicate, Collection outerVariables) + { + checkArgument(Iterables.all(VariablesExtractor.extractUnique(outerEffectivePredicate), in(outerVariables)), "outerEffectivePredicate must only contain variables from outerVariables"); + checkArgument(Iterables.all(VariablesExtractor.extractUnique(innerEffectivePredicate), not(in(outerVariables))), "innerEffectivePredicate must not contain variables from outerVariables"); + + ImmutableList.Builder outerPushdownConjuncts = ImmutableList.builder(); + ImmutableList.Builder innerPushdownConjuncts = ImmutableList.builder(); + ImmutableList.Builder postJoinConjuncts = ImmutableList.builder(); + ImmutableList.Builder joinConjuncts = ImmutableList.builder(); + + // Strip out non-deterministic conjuncts + postJoinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic))); + inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate); + + outerEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(outerEffectivePredicate); + innerEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(innerEffectivePredicate); + joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(determinismEvaluator::isDeterministic))); + joinPredicate = logicalRowExpressions.filterDeterministicConjuncts(joinPredicate); + + // Generate equality inferences + RowExpressionEqualityInference inheritedInference = createEqualityInference(inheritedPredicate); + RowExpressionEqualityInference outerInference = createEqualityInference(inheritedPredicate, outerEffectivePredicate); + + RowExpressionEqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(outerVariables)); + RowExpression outerOnlyInheritedEqualities = logicalRowExpressions.combineConjuncts(equalityPartition.getScopeEqualities()); + RowExpressionEqualityInference potentialNullSymbolInference = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, innerEffectivePredicate, joinPredicate); + + // See if we can push inherited predicates down + for (RowExpression conjunct : nonInferrableConjuncts(inheritedPredicate)) { + RowExpression outerRewritten = outerInference.rewriteExpression(conjunct, in(outerVariables)); + if (outerRewritten != null) { + outerPushdownConjuncts.add(outerRewritten); + + // A conjunct can only be pushed down into an inner side if it can be rewritten in terms of the outer side + RowExpression innerRewritten = potentialNullSymbolInference.rewriteExpression(outerRewritten, not(in(outerVariables))); + if (innerRewritten != null) { + innerPushdownConjuncts.add(innerRewritten); + } + } + else { + postJoinConjuncts.add(conjunct); + } + } + // Add the equalities from the inferences back in + outerPushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); + postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); + postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); + + // See if we can push down any outer effective predicates to the inner side + for (RowExpression conjunct : nonInferrableConjuncts(outerEffectivePredicate)) { + RowExpression rewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerVariables))); + if (rewritten != null) { + innerPushdownConjuncts.add(rewritten); + } + } + + // See if we can push down join predicates to the inner side + for (RowExpression conjunct : nonInferrableConjuncts(joinPredicate)) { + RowExpression innerRewritten = potentialNullSymbolInference.rewriteExpression(conjunct, not(in(outerVariables))); + if (innerRewritten != null) { + innerPushdownConjuncts.add(innerRewritten); + } + else { + joinConjuncts.add(conjunct); + } + } + + // Push outer and join equalities into the inner side. For example: + // SELECT * FROM nation LEFT OUTER JOIN region ON nation.regionkey = region.regionkey and nation.name = region.name WHERE nation.name = 'blah' + + RowExpressionEqualityInference potentialNullSymbolInferenceWithoutInnerInferred = createEqualityInference(outerOnlyInheritedEqualities, outerEffectivePredicate, joinPredicate); + innerPushdownConjuncts.addAll(potentialNullSymbolInferenceWithoutInnerInferred.generateEqualitiesPartitionedBy(not(in(outerVariables))).getScopeEqualities()); + + // TODO: we can further improve simplifying the equalities by considering other relationships from the outer side + RowExpressionEqualityInference.EqualityPartition joinEqualityPartition = createEqualityInference(joinPredicate).generateEqualitiesPartitionedBy(not(in(outerVariables))); + innerPushdownConjuncts.addAll(joinEqualityPartition.getScopeEqualities()); + joinConjuncts.addAll(joinEqualityPartition.getScopeComplementEqualities()) + .addAll(joinEqualityPartition.getScopeStraddlingEqualities()); + + return new OuterJoinPushDownResult(logicalRowExpressions.combineConjuncts(outerPushdownConjuncts.build()), + logicalRowExpressions.combineConjuncts(innerPushdownConjuncts.build()), + logicalRowExpressions.combineConjuncts(joinConjuncts.build()), + logicalRowExpressions.combineConjuncts(postJoinConjuncts.build())); + } + + private static class OuterJoinPushDownResult + { + private final RowExpression outerJoinPredicate; + private final RowExpression innerJoinPredicate; + private final RowExpression joinPredicate; + private final RowExpression postJoinPredicate; + + private OuterJoinPushDownResult(RowExpression outerJoinPredicate, RowExpression innerJoinPredicate, RowExpression joinPredicate, RowExpression postJoinPredicate) + { + this.outerJoinPredicate = outerJoinPredicate; + this.innerJoinPredicate = innerJoinPredicate; + this.joinPredicate = joinPredicate; + this.postJoinPredicate = postJoinPredicate; + } + + private RowExpression getOuterJoinPredicate() + { + return outerJoinPredicate; + } + + private RowExpression getInnerJoinPredicate() + { + return innerJoinPredicate; + } + + public RowExpression getJoinPredicate() + { + return joinPredicate; + } + + private RowExpression getPostJoinPredicate() + { + return postJoinPredicate; + } + } + + private InnerJoinPushDownResult processInnerJoin(RowExpression inheritedPredicate, RowExpression leftEffectivePredicate, RowExpression rightEffectivePredicate, RowExpression joinPredicate, Collection leftVariables) + { + checkArgument(Iterables.all(VariablesExtractor.extractUnique(leftEffectivePredicate), in(leftVariables)), "leftEffectivePredicate must only contain variables from leftVariables"); + checkArgument(Iterables.all(VariablesExtractor.extractUnique(rightEffectivePredicate), not(in(leftVariables))), "rightEffectivePredicate must not contain variables from leftVariables"); + + ImmutableList.Builder leftPushDownConjuncts = ImmutableList.builder(); + ImmutableList.Builder rightPushDownConjuncts = ImmutableList.builder(); + ImmutableList.Builder joinConjuncts = ImmutableList.builder(); + + // Strip out non-deterministic conjuncts + joinConjuncts.addAll(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic))); + inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate); + + joinConjuncts.addAll(filter(extractConjuncts(joinPredicate), not(determinismEvaluator::isDeterministic))); + joinPredicate = logicalRowExpressions.filterDeterministicConjuncts(joinPredicate); + + leftEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(leftEffectivePredicate); + rightEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(rightEffectivePredicate); + + // Generate equality inferences + RowExpressionEqualityInference allInference = new RowExpressionEqualityInference.Builder(functionManager, typeManager) + .addEqualityInference(inheritedPredicate, leftEffectivePredicate, rightEffectivePredicate, joinPredicate) + .build(); + RowExpressionEqualityInference allInferenceWithoutLeftInferred = new RowExpressionEqualityInference.Builder(functionManager, typeManager) + .addEqualityInference(inheritedPredicate, rightEffectivePredicate, joinPredicate) + .build(); + RowExpressionEqualityInference allInferenceWithoutRightInferred = new RowExpressionEqualityInference.Builder(functionManager, typeManager) + .addEqualityInference(inheritedPredicate, leftEffectivePredicate, joinPredicate) + .build(); + + // Sort through conjuncts in inheritedPredicate that were not used for inference + for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManager, typeManager).nonInferrableConjuncts(inheritedPredicate)) { + RowExpression leftRewrittenConjunct = allInference.rewriteExpression(conjunct, in(leftVariables)); + if (leftRewrittenConjunct != null) { + leftPushDownConjuncts.add(leftRewrittenConjunct); + } + + RowExpression rightRewrittenConjunct = allInference.rewriteExpression(conjunct, not(in(leftVariables))); + if (rightRewrittenConjunct != null) { + rightPushDownConjuncts.add(rightRewrittenConjunct); + } + + // Drop predicate after join only if unable to push down to either side + if (leftRewrittenConjunct == null && rightRewrittenConjunct == null) { + joinConjuncts.add(conjunct); + } + } + + // See if we can push the right effective predicate to the left side + for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManager, typeManager).nonInferrableConjuncts(rightEffectivePredicate)) { + RowExpression rewritten = allInference.rewriteExpression(conjunct, in(leftVariables)); + if (rewritten != null) { + leftPushDownConjuncts.add(rewritten); + } + } + + // See if we can push the left effective predicate to the right side + for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManager, typeManager).nonInferrableConjuncts(leftEffectivePredicate)) { + RowExpression rewritten = allInference.rewriteExpression(conjunct, not(in(leftVariables))); + if (rewritten != null) { + rightPushDownConjuncts.add(rewritten); + } + } + + // See if we can push any parts of the join predicates to either side + for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManager, typeManager).nonInferrableConjuncts(joinPredicate)) { + RowExpression leftRewritten = allInference.rewriteExpression(conjunct, in(leftVariables)); + if (leftRewritten != null) { + leftPushDownConjuncts.add(leftRewritten); + } + + RowExpression rightRewritten = allInference.rewriteExpression(conjunct, not(in(leftVariables))); + if (rightRewritten != null) { + rightPushDownConjuncts.add(rightRewritten); + } + + if (leftRewritten == null && rightRewritten == null) { + joinConjuncts.add(conjunct); + } + } + + // Add equalities from the inference back in + leftPushDownConjuncts.addAll(allInferenceWithoutLeftInferred.generateEqualitiesPartitionedBy(in(leftVariables)).getScopeEqualities()); + rightPushDownConjuncts.addAll(allInferenceWithoutRightInferred.generateEqualitiesPartitionedBy(not(in(leftVariables))).getScopeEqualities()); + joinConjuncts.addAll(allInference.generateEqualitiesPartitionedBy(in(leftVariables)::apply).getScopeStraddlingEqualities()); // scope straddling equalities get dropped in as part of the join predicate + + return new Rewriter.InnerJoinPushDownResult( + logicalRowExpressions.combineConjuncts(leftPushDownConjuncts.build()), + logicalRowExpressions.combineConjuncts(rightPushDownConjuncts.build()), + logicalRowExpressions.combineConjuncts(joinConjuncts.build()), TRUE_CONSTANT); + } + + private static class InnerJoinPushDownResult + { + private final RowExpression leftPredicate; + private final RowExpression rightPredicate; + private final RowExpression joinPredicate; + private final RowExpression postJoinPredicate; + + private InnerJoinPushDownResult(RowExpression leftPredicate, RowExpression rightPredicate, RowExpression joinPredicate, RowExpression postJoinPredicate) + { + this.leftPredicate = leftPredicate; + this.rightPredicate = rightPredicate; + this.joinPredicate = joinPredicate; + this.postJoinPredicate = postJoinPredicate; + } + + private RowExpression getLeftPredicate() + { + return leftPredicate; + } + + private RowExpression getRightPredicate() + { + return rightPredicate; + } + + private RowExpression getJoinPredicate() + { + return joinPredicate; + } + + private RowExpression getPostJoinPredicate() + { + return postJoinPredicate; + } + } + + private RowExpression extractJoinPredicate(JoinNode joinNode) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (JoinNode.EquiJoinClause equiJoinClause : joinNode.getCriteria()) { + builder.add(toRowExpression(equiJoinClause)); + } + joinNode.getFilter().ifPresent(builder::add); + return logicalRowExpressions.combineConjuncts(builder.build()); + } + + private RowExpression toRowExpression(JoinNode.EquiJoinClause equiJoinClause) + { + return buildEqualsExpression(functionManager, equiJoinClause.getLeft(), equiJoinClause.getRight()); + } + + private JoinNode tryNormalizeToOuterToInnerJoin(JoinNode node, RowExpression inheritedPredicate) + { + checkArgument(EnumSet.of(INNER, RIGHT, LEFT, FULL).contains(node.getType()), "Unsupported join type: %s", node.getType()); + + if (node.getType() == JoinNode.Type.INNER) { + return node; + } + + if (node.getType() == JoinNode.Type.FULL) { + boolean canConvertToLeftJoin = canConvertOuterToInner(node.getLeft().getOutputVariables(), inheritedPredicate); + boolean canConvertToRightJoin = canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate); + if (!canConvertToLeftJoin && !canConvertToRightJoin) { + return node; + } + if (canConvertToLeftJoin && canConvertToRightJoin) { + return new JoinNode(node.getId(), INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); + } + else { + return new JoinNode(node.getId(), canConvertToLeftJoin ? LEFT : RIGHT, + node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); + } + } + + if (node.getType() == JoinNode.Type.LEFT && !canConvertOuterToInner(node.getRight().getOutputVariables(), inheritedPredicate) || + node.getType() == JoinNode.Type.RIGHT && !canConvertOuterToInner(node.getLeft().getOutputVariables(), inheritedPredicate)) { + return node; + } + return new JoinNode(node.getId(), JoinNode.Type.INNER, node.getLeft(), node.getRight(), node.getCriteria(), node.getOutputVariables(), node.getFilter(), node.getLeftHashVariable(), node.getRightHashVariable(), node.getDistributionType()); + } + + private boolean canConvertOuterToInner(List innerVariablesForOuterJoin, RowExpression inheritedPredicate) + { + Set innerVariables = ImmutableSet.copyOf(innerVariablesForOuterJoin); + for (RowExpression conjunct : extractConjuncts(inheritedPredicate)) { + if (determinismEvaluator.isDeterministic(conjunct)) { + // Ignore a conjunct for this test if we can not deterministically get responses from it + RowExpression response = nullInputEvaluator(innerVariables, conjunct); + if (response == null || Expressions.isNull(response) || FALSE_CONSTANT.equals(response)) { + // If there is a single conjunct that returns FALSE or NULL given all NULL inputs for the inner side symbols of an outer join + // then this conjunct removes all effects of the outer join, and effectively turns this into an equivalent of an inner join. + // So, let's just rewrite this join as an INNER join + return true; + } + } + } + return false; + } + + // Temporary implementation for joins because the SimplifyExpressions optimizers can not run properly on join clauses + private RowExpression simplifyExpression(RowExpression expression) + { + return new RowExpressionOptimizer(metadata).optimize(expression, ExpressionOptimizer.Level.SERIALIZABLE, session.toConnectorSession()); + } + + private boolean areExpressionsEquivalent(RowExpression leftExpression, RowExpression rightExpression) + { + return expressionEquivalence.areExpressionsEquivalent(simplifyExpression(leftExpression), simplifyExpression(rightExpression)); + } + + /** + * Evaluates an expression's response to binding the specified input symbols to NULL + */ + private RowExpression nullInputEvaluator(final Collection nullSymbols, RowExpression expression) + { + expression = RowExpressionNodeInliner.replaceExpression(expression, nullSymbols.stream() + .collect(Collectors.toMap(identity(), variable -> constantNull(variable.getType())))); + return new RowExpressionOptimizer(metadata).optimize(expression, ExpressionOptimizer.Level.OPTIMIZED, session.toConnectorSession()); + } + + private Predicate joinEqualityExpression(final Collection leftVariables) + { + return expression -> { + // At this point in time, our join predicates need to be deterministic + if (determinismEvaluator.isDeterministic(expression) && isOperation(expression, EQUAL)) { + Set variables1 = VariablesExtractor.extractUnique(getLeft(expression)); + Set variables2 = VariablesExtractor.extractUnique(getRight(expression)); + if (variables1.isEmpty() || variables2.isEmpty()) { + return false; + } + return (Iterables.all(variables1, in(leftVariables)) && Iterables.all(variables2, not(in(leftVariables)))) || + (Iterables.all(variables2, in(leftVariables)) && Iterables.all(variables1, not(in(leftVariables)))); + } + return false; + }; + } + + private boolean isOperation(RowExpression expression, OperatorType type) + { + if (expression instanceof CallExpression) { + Optional operatorType = functionManager.getFunctionMetadata(((CallExpression) expression).getFunctionHandle()).getOperatorType(); + if (operatorType.isPresent()) { + return operatorType.get().equals(type); + } + } + return false; + } + + @Override + public PlanNode visitSemiJoin(SemiJoinNode node, RewriteContext context) + { + RowExpression inheritedPredicate = context.get(); + if (!extractConjuncts(inheritedPredicate).contains(node.getSemiJoinOutput())) { + return visitNonFilteringSemiJoin(node, context); + } + return visitFilteringSemiJoin(node, context); + } + + private PlanNode visitNonFilteringSemiJoin(SemiJoinNode node, RewriteContext context) + { + RowExpression inheritedPredicate = context.get(); + List sourceConjuncts = new ArrayList<>(); + List postJoinConjuncts = new ArrayList<>(); + + // TODO: see if there are predicates that can be inferred from the semi join output + + PlanNode rewrittenFilteringSource = context.defaultRewrite(node.getFilteringSource(), TRUE_CONSTANT); + + // Push inheritedPredicates down to the source if they don't involve the semi join output + RowExpressionEqualityInference inheritedInference = new RowExpressionEqualityInference.Builder(functionManager, typeManager) + .addEqualityInference(inheritedPredicate) + .build(); + for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManager, typeManager).nonInferrableConjuncts(inheritedPredicate)) { + RowExpression rewrittenConjunct = inheritedInference.rewriteExpressionAllowNonDeterministic(conjunct, in(node.getSource().getOutputVariables())); + // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down + if (rewrittenConjunct != null) { + sourceConjuncts.add(rewrittenConjunct); + } + else { + postJoinConjuncts.add(conjunct); + } + } + + // Add the inherited equality predicates back in + RowExpressionEqualityInference.EqualityPartition equalityPartition = inheritedInference.generateEqualitiesPartitionedBy(in(node.getSource() + .getOutputVariables())::apply); + sourceConjuncts.addAll(equalityPartition.getScopeEqualities()); + postJoinConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); + postJoinConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); + + PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(sourceConjuncts)); + + PlanNode output = node; + if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) { + output = new SemiJoinNode(node.getId(), rewrittenSource, rewrittenFilteringSource, node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable(), node.getSemiJoinOutput(), node.getSourceHashVariable(), node.getFilteringSourceHashVariable(), node.getDistributionType()); + } + if (!postJoinConjuncts.isEmpty()) { + output = new FilterNode(idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postJoinConjuncts)); + } + return output; + } + + private PlanNode visitFilteringSemiJoin(SemiJoinNode node, RewriteContext context) + { + RowExpression inheritedPredicate = context.get(); + RowExpression deterministicInheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate); + RowExpression sourceEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getSource())); + RowExpression filteringSourceEffectivePredicate = logicalRowExpressions.filterDeterministicConjuncts(effectivePredicateExtractor.extract(node.getFilteringSource())); + RowExpression joinExpression = buildEqualsExpression(functionManager, node.getSourceJoinVariable(), node.getFilteringSourceJoinVariable()); + + List sourceVariables = node.getSource().getOutputVariables(); + List filteringSourceVariables = node.getFilteringSource().getOutputVariables(); + + List sourceConjuncts = new ArrayList<>(); + List filteringSourceConjuncts = new ArrayList<>(); + List postJoinConjuncts = new ArrayList<>(); + + // Generate equality inferences + RowExpressionEqualityInference allInference = createEqualityInference(deterministicInheritedPredicate, sourceEffectivePredicate, filteringSourceEffectivePredicate, joinExpression); + RowExpressionEqualityInference allInferenceWithoutSourceInferred = createEqualityInference(deterministicInheritedPredicate, filteringSourceEffectivePredicate, joinExpression); + RowExpressionEqualityInference allInferenceWithoutFilteringSourceInferred = createEqualityInference(deterministicInheritedPredicate, sourceEffectivePredicate, joinExpression); + + // Push inheritedPredicates down to the source if they don't involve the semi join output + for (RowExpression conjunct : nonInferrableConjuncts(inheritedPredicate)) { + RowExpression rewrittenConjunct = allInference.rewriteExpressionAllowNonDeterministic(conjunct, in(sourceVariables)); + // Since each source row is reflected exactly once in the output, ok to push non-deterministic predicates down + if (rewrittenConjunct != null) { + sourceConjuncts.add(rewrittenConjunct); + } + else { + postJoinConjuncts.add(conjunct); + } + } + + // Push inheritedPredicates down to the filtering source if possible + for (RowExpression conjunct : nonInferrableConjuncts(deterministicInheritedPredicate)) { + RowExpression rewrittenConjunct = allInference.rewriteExpression(conjunct, in(filteringSourceVariables)); + // We cannot push non-deterministic predicates to filtering side. Each filtering side row have to be + // logically reevaluated for each source row. + if (rewrittenConjunct != null) { + filteringSourceConjuncts.add(rewrittenConjunct); + } + } + + // move effective predicate conjuncts source <-> filter + // See if we can push the filtering source effective predicate to the source side + for (RowExpression conjunct : nonInferrableConjuncts(filteringSourceEffectivePredicate)) { + RowExpression rewritten = allInference.rewriteExpression(conjunct, in(sourceVariables)); + if (rewritten != null) { + sourceConjuncts.add(rewritten); + } + } + + // See if we can push the source effective predicate to the filtering soruce side + for (RowExpression conjunct : nonInferrableConjuncts(sourceEffectivePredicate)) { + RowExpression rewritten = allInference.rewriteExpression(conjunct, in(filteringSourceVariables)); + if (rewritten != null) { + filteringSourceConjuncts.add(rewritten); + } + } + + // Add equalities from the inference back in + sourceConjuncts.addAll(allInferenceWithoutSourceInferred.generateEqualitiesPartitionedBy(in(sourceVariables)).getScopeEqualities()); + filteringSourceConjuncts.addAll(allInferenceWithoutFilteringSourceInferred.generateEqualitiesPartitionedBy(in(filteringSourceVariables)).getScopeEqualities()); + + PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(sourceConjuncts)); + PlanNode rewrittenFilteringSource = context.rewrite(node.getFilteringSource(), logicalRowExpressions.combineConjuncts(filteringSourceConjuncts)); + + PlanNode output = node; + if (rewrittenSource != node.getSource() || rewrittenFilteringSource != node.getFilteringSource()) { + output = new SemiJoinNode( + node.getId(), + rewrittenSource, + rewrittenFilteringSource, + node.getSourceJoinVariable(), + node.getFilteringSourceJoinVariable(), + node.getSemiJoinOutput(), + node.getSourceHashVariable(), + node.getFilteringSourceHashVariable(), + node.getDistributionType()); + } + if (!postJoinConjuncts.isEmpty()) { + output = new FilterNode(idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postJoinConjuncts)); + } + return output; + } + + private Iterable nonInferrableConjuncts(RowExpression inheritedPredicate) + { + return new RowExpressionEqualityInference.Builder(functionManager, typeManager) + .nonInferrableConjuncts(inheritedPredicate); + } + + private RowExpressionEqualityInference createEqualityInference(RowExpression... expressions) + { + return new RowExpressionEqualityInference.Builder(functionManager, typeManager) + .addEqualityInference(expressions) + .build(); + } + + @Override + public PlanNode visitAggregation(AggregationNode node, RewriteContext context) + { + if (node.hasEmptyGroupingSet()) { + // TODO: in case of grouping sets, we should be able to push the filters over grouping keys below the aggregation + // and also preserve the filter above the aggregation if it has an empty grouping set + return visitPlan(node, context); + } + + RowExpression inheritedPredicate = context.get(); + + RowExpressionEqualityInference equalityInference = createEqualityInference(inheritedPredicate); + + List pushdownConjuncts = new ArrayList<>(); + List postAggregationConjuncts = new ArrayList<>(); + + List groupingKeyVariables = node.getGroupingKeys(); + + // Strip out non-deterministic conjuncts + postAggregationConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic)))); + inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate); + + // Sort non-equality predicates by those that can be pushed down and those that cannot + for (RowExpression conjunct : nonInferrableConjuncts(inheritedPredicate)) { + if (node.getGroupIdVariable().isPresent() && VariablesExtractor.extractUnique(conjunct).contains(node.getGroupIdVariable().get())) { + // aggregation operator synthesizes outputs for group ids corresponding to the global grouping set (i.e., ()), so we + // need to preserve any predicates that evaluate the group id to run after the aggregation + // TODO: we should be able to infer if conditions on grouping() correspond to global grouping sets to determine whether + // we need to do this for each specific case + postAggregationConjuncts.add(conjunct); + continue; + } + + RowExpression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(groupingKeyVariables)); + if (rewrittenConjunct != null) { + pushdownConjuncts.add(rewrittenConjunct); + } + else { + postAggregationConjuncts.add(conjunct); + } + } + + // Add the equality predicates back in + RowExpressionEqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(groupingKeyVariables)::apply); + pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); + postAggregationConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); + postAggregationConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); + + PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(pushdownConjuncts)); + + PlanNode output = node; + if (rewrittenSource != node.getSource()) { + output = new AggregationNode(node.getId(), + rewrittenSource, + node.getAggregations(), + node.getGroupingSets(), + ImmutableList.of(), + node.getStep(), + node.getHashVariable(), + node.getGroupIdVariable()); + } + if (!postAggregationConjuncts.isEmpty()) { + output = new FilterNode(idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postAggregationConjuncts)); + } + return output; + } + + @Override + public PlanNode visitUnnest(UnnestNode node, RewriteContext context) + { + RowExpression inheritedPredicate = context.get(); + + RowExpressionEqualityInference equalityInference = createEqualityInference(inheritedPredicate); + + List pushdownConjuncts = new ArrayList<>(); + List postUnnestConjuncts = new ArrayList<>(); + + // Strip out non-deterministic conjuncts + postUnnestConjuncts.addAll(ImmutableList.copyOf(filter(extractConjuncts(inheritedPredicate), not(determinismEvaluator::isDeterministic)))); + inheritedPredicate = logicalRowExpressions.filterDeterministicConjuncts(inheritedPredicate); + + // Sort non-equality predicates by those that can be pushed down and those that cannot + for (RowExpression conjunct : nonInferrableConjuncts(inheritedPredicate)) { + RowExpression rewrittenConjunct = equalityInference.rewriteExpression(conjunct, in(node.getReplicateVariables())); + if (rewrittenConjunct != null) { + pushdownConjuncts.add(rewrittenConjunct); + } + else { + postUnnestConjuncts.add(conjunct); + } + } + + // Add the equality predicates back in + RowExpressionEqualityInference.EqualityPartition equalityPartition = equalityInference.generateEqualitiesPartitionedBy(in(node.getReplicateVariables())::apply); + pushdownConjuncts.addAll(equalityPartition.getScopeEqualities()); + postUnnestConjuncts.addAll(equalityPartition.getScopeComplementEqualities()); + postUnnestConjuncts.addAll(equalityPartition.getScopeStraddlingEqualities()); + + PlanNode rewrittenSource = context.rewrite(node.getSource(), logicalRowExpressions.combineConjuncts(pushdownConjuncts)); + + PlanNode output = node; + if (rewrittenSource != node.getSource()) { + output = new UnnestNode(node.getId(), rewrittenSource, node.getReplicateVariables(), node.getUnnestVariables(), node.getOrdinalityVariable()); + } + if (!postUnnestConjuncts.isEmpty()) { + output = new FilterNode(idAllocator.getNextId(), output, logicalRowExpressions.combineConjuncts(postUnnestConjuncts)); + } + return output; + } + + @Override + public PlanNode visitSample(SampleNode node, RewriteContext context) + { + return context.defaultRewrite(node, context.get()); + } + + @Override + public PlanNode visitTableScan(TableScanNode node, RewriteContext context) + { + RowExpression predicate = simplifyExpression(context.get()); + + if (!TRUE_CONSTANT.equals(predicate)) { + return new FilterNode(idAllocator.getNextId(), node, predicate); + } + + return node; + } + + @Override + public PlanNode visitAssignUniqueId(AssignUniqueId node, RewriteContext context) + { + Set predicateVariables = VariablesExtractor.extractUnique(context.get()); + checkState(!predicateVariables.contains(node.getIdVariable()), "UniqueId in predicate is not yet supported"); + return context.defaultRewrite(node, context.get()); + } + + private static CallExpression buildEqualsExpression(FunctionManager functionManager, RowExpression left, RowExpression right) + { + return call( + EQUAL.getFunctionName().getSuffix(), + functionManager.resolveOperator(EQUAL, fromTypes(left.getType(), right.getType())), + BOOLEAN, + left, + right); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StatsRecordingPlanOptimizer.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StatsRecordingPlanOptimizer.java index 1153545ff5215..f820faba717e1 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StatsRecordingPlanOptimizer.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/StatsRecordingPlanOptimizer.java @@ -20,6 +20,7 @@ import com.facebook.presto.sql.planner.OptimizerStatsRecorder; import com.facebook.presto.sql.planner.PlanVariableAllocator; import com.facebook.presto.sql.planner.TypeProvider; +import com.google.common.annotations.VisibleForTesting; import static java.util.Objects.requireNonNull; @@ -36,6 +37,12 @@ public StatsRecordingPlanOptimizer(OptimizerStatsRecorder stats, PlanOptimizer d stats.register(delegate); } + @VisibleForTesting + public PlanOptimizer getDelegate() + { + return delegate; + } + public final PlanNode optimize( PlanNode plan, Session session, diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java index 45283b9ad6267..f45cb9724b801 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestPredicatePushdown.java @@ -13,17 +13,17 @@ */ package com.facebook.presto.sql.planner; +import com.facebook.presto.spi.function.OperatorType; import com.facebook.presto.sql.planner.assertions.BasePlanTest; import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; import com.facebook.presto.sql.planner.iterative.rule.test.RuleTester; import com.facebook.presto.sql.planner.optimizations.PlanOptimizer; import com.facebook.presto.sql.planner.optimizations.PredicatePushDown; +import com.facebook.presto.sql.planner.optimizations.RowExpressionPredicatePushDown; +import com.facebook.presto.sql.planner.optimizations.StatsRecordingPlanOptimizer; import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.plan.JoinNode.EquiJoinClause; import com.facebook.presto.sql.planner.plan.WindowNode; -import com.facebook.presto.sql.tree.ComparisonExpression; -import com.facebook.presto.sql.tree.LongLiteral; -import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @@ -31,6 +31,7 @@ import java.util.List; import java.util.Optional; +import static com.facebook.presto.spi.type.IntegerType.INTEGER; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.any; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.assignUniqueId; @@ -49,7 +50,7 @@ import static com.facebook.presto.sql.planner.plan.JoinNode.DistributionType.REPLICATED; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.INNER; import static com.facebook.presto.sql.planner.plan.JoinNode.Type.LEFT; -import static com.facebook.presto.sql.tree.ComparisonExpression.Operator.EQUAL; +import static com.facebook.presto.sql.relational.Expressions.constant; public class TestPredicatePushdown extends BasePlanTest @@ -423,14 +424,26 @@ public void testNonDeterministicPredicateNotPushedDown() ImmutableMap.of("CUST_KEY", "custkey")))))))); } + @Override + protected void assertPlan(String sql, PlanMatchPattern pattern) + { + // TODO remove tests with filtered optimizer once we only have RowExpressionPredicatePushDown + // Currently we have mixture of Expression/RowExpression based push down, so we disable one of them to make sure test covers both code path. + assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern, planOptimizer -> !(planOptimizer instanceof StatsRecordingPlanOptimizer) || + !(((StatsRecordingPlanOptimizer) planOptimizer).getDelegate() instanceof PredicatePushDown)); + assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern, planOptimizer -> !(planOptimizer instanceof StatsRecordingPlanOptimizer) || + !(((StatsRecordingPlanOptimizer) planOptimizer).getDelegate() instanceof RowExpressionPredicatePushDown)); + assertPlan(sql, LogicalPlanner.Stage.OPTIMIZED_AND_VALIDATED, pattern); + } + @Test public void testPredicatePushDownCreatesValidJoin() { RuleTester tester = new RuleTester(); - tester.assertThat(new PredicatePushDown(tester.getMetadata(), tester.getSqlParser())) + tester.assertThat(new RowExpressionPredicatePushDown(tester.getMetadata(), tester.getSqlParser())) .on(p -> p.join(INNER, - p.filter(new ComparisonExpression(EQUAL, new SymbolReference("a1"), new LongLiteral("1")), + p.filter(p.comparison(OperatorType.EQUAL, p.variable("a1"), constant(1L, INTEGER)), p.values(p.variable("a1"))), p.values(p.variable("b1")), ImmutableList.of(new EquiJoinClause(p.variable("a1"), p.variable("b1"))), diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java index c48ffcedc41e9..7523539e2206c 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/assertions/RowExpressionVerifier.java @@ -208,6 +208,11 @@ protected Boolean visitComparisonExpression(ComparisonExpression expected, RowEx OperatorType actualOperatorType = functionMetadata.getOperatorType().get(); OperatorType expectedOperatorType = getOperatorType(expected.getOperator()); if (expectedOperatorType.equals(actualOperatorType)) { + if (actualOperatorType == EQUAL) { + return (process(expected.getLeft(), ((CallExpression) actual).getArguments().get(0)) && process(expected.getRight(), ((CallExpression) actual).getArguments().get(1))) + || (process(expected.getLeft(), ((CallExpression) actual).getArguments().get(1)) && process(expected.getRight(), ((CallExpression) actual).getArguments().get(0))); + } + // TODO support other comparison operators return process(expected.getLeft(), ((CallExpression) actual).getArguments().get(0)) && process(expected.getRight(), ((CallExpression) actual).getArguments().get(1)); } } diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java index 99a6285579a2e..4f4fc6d102d82 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/PlanBuilder.java @@ -293,6 +293,12 @@ public CallExpression binaryOperation(OperatorType operatorType, RowExpression l return call(operatorType.getOperator(), functionHandle, left.getType(), left, right); } + public CallExpression comparison(OperatorType operatorType, RowExpression left, RowExpression right) + { + FunctionHandle functionHandle = new FunctionResolution(metadata.getFunctionManager()).comparisonFunction(operatorType, left.getType(), right.getType()); + return call(operatorType.getOperator(), functionHandle, left.getType(), left, right); + } + public class AggregationBuilder { private final TypeProvider types;