diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java
index a1d5791fc16a9..97fb43ea2cf96 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/EffectivePredicateExtractor.java
@@ -70,6 +70,7 @@
*
* Note: non-deterministic predicates can not be pulled up (so they will be ignored)
*/
+@Deprecated
public class EffectivePredicateExtractor
{
private static final Predicate> VARIABLE_MATCHES_EXPRESSION =
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionEqualityInference.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionEqualityInference.java
index 557a451d546fc..aa4ab8937ee80 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionEqualityInference.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionEqualityInference.java
@@ -106,6 +106,13 @@ private RowExpressionEqualityInference(
this.derivedExpressions = ImmutableSet.copyOf(derivedExpressions);
}
+ public static RowExpressionEqualityInference createEqualityInference(Metadata metadata, RowExpression... equalityInferences)
+ {
+ return new Builder(metadata)
+ .addEqualityInference(equalityInferences)
+ .build();
+ }
+
/**
* Attempts to rewrite an RowExpression in terms of the symbols allowed by the symbol scope
* given the known equalities. Returns null if unsuccessful.
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java
new file mode 100644
index 0000000000000..46a5803aad2f3
--- /dev/null
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/RowExpressionPredicateExtractor.java
@@ -0,0 +1,424 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner;
+
+import com.facebook.presto.expressions.LogicalRowExpressions;
+import com.facebook.presto.metadata.FunctionManager;
+import com.facebook.presto.metadata.OperatorNotFoundException;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.FilterNode;
+import com.facebook.presto.spi.plan.LimitNode;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.TableScanNode;
+import com.facebook.presto.spi.plan.TopNNode;
+import com.facebook.presto.spi.relation.CallExpression;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.spi.type.TypeManager;
+import com.facebook.presto.sql.planner.plan.AssignUniqueId;
+import com.facebook.presto.sql.planner.plan.DistinctLimitNode;
+import com.facebook.presto.sql.planner.plan.ExchangeNode;
+import com.facebook.presto.sql.planner.plan.InternalPlanVisitor;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.facebook.presto.sql.planner.plan.SortNode;
+import com.facebook.presto.sql.planner.plan.SpatialJoinNode;
+import com.facebook.presto.sql.planner.plan.UnionNode;
+import com.facebook.presto.sql.planner.plan.WindowNode;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
+import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
+import com.google.common.collect.ImmutableBiMap;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
+import static com.facebook.presto.expressions.LogicalRowExpressions.extractConjuncts;
+import static com.facebook.presto.spi.function.OperatorType.EQUAL;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
+import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
+import static com.facebook.presto.sql.relational.Expressions.call;
+import static com.facebook.presto.sql.relational.Expressions.specialForm;
+import static com.google.common.base.Predicates.in;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+
+public class RowExpressionPredicateExtractor
+{
+ private final RowExpressionDomainTranslator domainTranslator;
+ private final FunctionManager functionManager;
+ private final TypeManager typeManager;
+
+ public RowExpressionPredicateExtractor(RowExpressionDomainTranslator domainTranslator, FunctionManager functionManager, TypeManager typeManager)
+ {
+ this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
+ this.functionManager = functionManager;
+ this.typeManager = typeManager;
+ }
+
+ public RowExpression extract(PlanNode node)
+ {
+ return node.accept(new Visitor(domainTranslator, functionManager, typeManager), null);
+ }
+
+ private static class Visitor
+ extends InternalPlanVisitor
+ {
+ private final RowExpressionDomainTranslator domainTranslator;
+ private final LogicalRowExpressions logicalRowExpressions;
+ private final RowExpressionDeterminismEvaluator determinismEvaluator;
+ private final TypeManager typeManager;
+ private final FunctionManager functionManger;
+
+ public Visitor(RowExpressionDomainTranslator domainTranslator, FunctionManager functionManager, TypeManager typeManager)
+ {
+ this.domainTranslator = requireNonNull(domainTranslator, "domainTranslator is null");
+ this.typeManager = requireNonNull(typeManager);
+ this.functionManger = requireNonNull(functionManager);
+ this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionManager);
+ this.logicalRowExpressions = new LogicalRowExpressions(determinismEvaluator, new FunctionResolution(functionManager), functionManager);
+ }
+
+ @Override
+ public RowExpression visitPlan(PlanNode node, Void context)
+ {
+ return TRUE_CONSTANT;
+ }
+
+ @Override
+ public RowExpression visitAggregation(AggregationNode node, Void context)
+ {
+ // GROUP BY () always produces a group, regardless of whether there's any
+ // input (unlike the case where there are group by keys, which produce
+ // no output if there's no input).
+ // Therefore, we can't say anything about the effective predicate of the
+ // output of such an aggregation.
+ if (node.getGroupingKeys().isEmpty()) {
+ return TRUE_CONSTANT;
+ }
+
+ RowExpression underlyingPredicate = node.getSource().accept(this, context);
+
+ return pullExpressionThroughVariables(underlyingPredicate, node.getGroupingKeys());
+ }
+
+ @Override
+ public RowExpression visitFilter(FilterNode node, Void context)
+ {
+ RowExpression underlyingPredicate = node.getSource().accept(this, context);
+
+ RowExpression predicate = node.getPredicate();
+
+ // Remove non-deterministic conjuncts
+ predicate = logicalRowExpressions.filterDeterministicConjuncts(predicate);
+
+ return logicalRowExpressions.combineConjuncts(predicate, underlyingPredicate);
+ }
+
+ @Override
+ public RowExpression visitExchange(ExchangeNode node, Void context)
+ {
+ return deriveCommonPredicates(node, source -> {
+ Map mappings = new HashMap<>();
+ for (int i = 0; i < node.getInputs().get(source).size(); i++) {
+ mappings.put(
+ node.getOutputVariables().get(i),
+ node.getInputs().get(source).get(i));
+ }
+ return mappings.entrySet();
+ });
+ }
+
+ @Override
+ public RowExpression visitProject(ProjectNode node, Void context)
+ {
+ // TODO: add simple algebraic solver for projection translation (right now only considers identity projections)
+
+ RowExpression underlyingPredicate = node.getSource().accept(this, context);
+
+ List projectionEqualities = node.getAssignments().getMap().entrySet().stream()
+ .filter(this::notIdentityAssignment)
+ .filter(this::canCompareEquity)
+ .map(this::toEquality)
+ .collect(toImmutableList());
+
+ return pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(
+ ImmutableList.builder()
+ .addAll(projectionEqualities)
+ .add(underlyingPredicate)
+ .build()),
+ node.getOutputVariables());
+ }
+
+ @Override
+ public RowExpression visitTopN(TopNNode node, Void context)
+ {
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitLimit(LimitNode node, Void context)
+ {
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitAssignUniqueId(AssignUniqueId node, Void context)
+ {
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitDistinctLimit(DistinctLimitNode node, Void context)
+ {
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitTableScan(TableScanNode node, Void context)
+ {
+ Map assignments = ImmutableBiMap.copyOf(node.getAssignments()).inverse();
+ return domainTranslator.toPredicate(node.getCurrentConstraint().simplify().transform(column -> assignments.containsKey(column) ? assignments.get(column) : null));
+ }
+
+ @Override
+ public RowExpression visitSort(SortNode node, Void context)
+ {
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitWindow(WindowNode node, Void context)
+ {
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitUnion(UnionNode node, Void context)
+ {
+ return deriveCommonPredicates(node, source -> node.outputMap(source).entries());
+ }
+
+ @Override
+ public RowExpression visitJoin(JoinNode node, Void context)
+ {
+ RowExpression leftPredicate = node.getLeft().accept(this, context);
+ RowExpression rightPredicate = node.getRight().accept(this, context);
+
+ List joinConjuncts = node.getCriteria().stream()
+ .map(this::toRowExpression)
+ .collect(toImmutableList());
+
+ switch (node.getType()) {
+ case INNER:
+ return pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(ImmutableList.builder()
+ .add(leftPredicate)
+ .add(rightPredicate)
+ .add(logicalRowExpressions.combineConjuncts(joinConjuncts))
+ .add(node.getFilter().orElse(TRUE_CONSTANT))
+ .build()), node.getOutputVariables());
+ case LEFT:
+ return logicalRowExpressions.combineConjuncts(ImmutableList.builder()
+ .add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables()))
+ .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
+ .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
+ .build());
+ case RIGHT:
+ return logicalRowExpressions.combineConjuncts(ImmutableList.builder()
+ .add(pullExpressionThroughVariables(rightPredicate, node.getOutputVariables()))
+ .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputVariables(), node.getLeft().getOutputVariables()::contains))
+ .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getLeft().getOutputVariables()::contains))
+ .build());
+ case FULL:
+ return logicalRowExpressions.combineConjuncts(ImmutableList.builder()
+ .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(leftPredicate), node.getOutputVariables(), node.getLeft().getOutputVariables()::contains))
+ .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
+ .addAll(pullNullableConjunctsThroughOuterJoin(joinConjuncts, node.getOutputVariables(), node.getLeft().getOutputVariables()::contains, node.getRight().getOutputVariables()::contains))
+ .build());
+ default:
+ throw new UnsupportedOperationException("Unknown join type: " + node.getType());
+ }
+ }
+
+ private Iterable pullNullableConjunctsThroughOuterJoin(List conjuncts, Collection outputVariables, Predicate... nullVariableScopes)
+ {
+ // Conjuncts without any symbol dependencies cannot be applied to the effective predicate (e.g. FALSE literal)
+ return conjuncts.stream()
+ .map(expression -> pullExpressionThroughVariables(expression, outputVariables))
+ .map(expression -> VariablesExtractor.extractAll(expression).isEmpty() ? TRUE_CONSTANT : expression)
+ .map(expressionOrNullVariables(nullVariableScopes))
+ .collect(toImmutableList());
+ }
+
+ public Function expressionOrNullVariables(final Predicate... nullVariableScopes)
+ {
+ return expression -> {
+ ImmutableList.Builder resultDisjunct = ImmutableList.builder();
+ resultDisjunct.add(expression);
+
+ for (Predicate nullVariableScope : nullVariableScopes) {
+ List variables = VariablesExtractor.extractUnique(expression).stream()
+ .filter(nullVariableScope)
+ .collect(toImmutableList());
+
+ if (Iterables.isEmpty(variables)) {
+ continue;
+ }
+
+ ImmutableList.Builder nullConjuncts = ImmutableList.builder();
+ for (VariableReferenceExpression variable : variables) {
+ nullConjuncts.add(specialForm(IS_NULL, BOOLEAN, variable));
+ }
+
+ resultDisjunct.add(logicalRowExpressions.and(nullConjuncts.build()));
+ }
+
+ return logicalRowExpressions.or(resultDisjunct.build());
+ };
+ }
+
+ @Override
+ public RowExpression visitSemiJoin(SemiJoinNode node, Void context)
+ {
+ // Filtering source does not change the effective predicate over the output symbols
+ return node.getSource().accept(this, context);
+ }
+
+ @Override
+ public RowExpression visitSpatialJoin(SpatialJoinNode node, Void context)
+ {
+ RowExpression leftPredicate = node.getLeft().accept(this, context);
+ RowExpression rightPredicate = node.getRight().accept(this, context);
+
+ switch (node.getType()) {
+ case INNER:
+ return logicalRowExpressions.combineConjuncts(ImmutableList.builder()
+ .add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables()))
+ .add(pullExpressionThroughVariables(rightPredicate, node.getOutputVariables()))
+ .build());
+ case LEFT:
+ return logicalRowExpressions.combineConjuncts(ImmutableList.builder()
+ .add(pullExpressionThroughVariables(leftPredicate, node.getOutputVariables()))
+ .addAll(pullNullableConjunctsThroughOuterJoin(extractConjuncts(rightPredicate), node.getOutputVariables(), node.getRight().getOutputVariables()::contains))
+ .build());
+ default:
+ throw new IllegalArgumentException("Unsupported spatial join type: " + node.getType());
+ }
+ }
+
+ private RowExpression toRowExpression(JoinNode.EquiJoinClause equiJoinClause)
+ {
+ return buildEqualsExpression(functionManger, equiJoinClause.getLeft(), equiJoinClause.getRight());
+ }
+
+ private RowExpression deriveCommonPredicates(PlanNode node, Function>> mapping)
+ {
+ // Find the predicates that can be pulled up from each source
+ List> sourceOutputConjuncts = new ArrayList<>();
+ for (int i = 0; i < node.getSources().size(); i++) {
+ RowExpression underlyingPredicate = node.getSources().get(i).accept(this, null);
+
+ List equalities = mapping.apply(i).stream()
+ .filter(this::notIdentityAssignment)
+ .filter(this::canCompareEquity)
+ .map(this::toEquality)
+ .collect(toImmutableList());
+
+ sourceOutputConjuncts.add(ImmutableSet.copyOf(extractConjuncts(pullExpressionThroughVariables(logicalRowExpressions.combineConjuncts(
+ ImmutableList.builder()
+ .addAll(equalities)
+ .add(underlyingPredicate)
+ .build()),
+ node.getOutputVariables()))));
+ }
+
+ // Find the intersection of predicates across all sources
+ // TODO: use a more precise way to determine overlapping conjuncts (e.g. commutative predicates)
+ Iterator> iterator = sourceOutputConjuncts.iterator();
+ Set potentialOutputConjuncts = iterator.next();
+ while (iterator.hasNext()) {
+ potentialOutputConjuncts = Sets.intersection(potentialOutputConjuncts, iterator.next());
+ }
+
+ return logicalRowExpressions.combineConjuncts(potentialOutputConjuncts);
+ }
+
+ private boolean notIdentityAssignment(Map.Entry entry)
+ {
+ return !entry.getKey().equals(entry.getValue());
+ }
+
+ private boolean canCompareEquity(Map.Entry entry)
+ {
+ try {
+ functionManger.resolveOperator(EQUAL, fromTypes(entry.getKey().getType(), entry.getValue().getType()));
+ return true;
+ }
+ catch (OperatorNotFoundException e) {
+ return false;
+ }
+ }
+
+ private RowExpression toEquality(Map.Entry entry)
+ {
+ return buildEqualsExpression(functionManger, entry.getKey(), entry.getValue());
+ }
+
+ private static CallExpression buildEqualsExpression(FunctionManager functionManager, RowExpression left, RowExpression right)
+ {
+ return call(
+ EQUAL.getFunctionName().getSuffix(),
+ functionManager.resolveOperator(EQUAL, fromTypes(left.getType(), right.getType())),
+ BOOLEAN,
+ left,
+ right);
+ }
+
+ private RowExpression pullExpressionThroughVariables(RowExpression expression, Collection variables)
+ {
+ RowExpressionEqualityInference equalityInference = new RowExpressionEqualityInference.Builder(functionManger, typeManager)
+ .addEqualityInference(expression)
+ .build();
+
+ ImmutableList.Builder effectiveConjuncts = ImmutableList.builder();
+ for (RowExpression conjunct : new RowExpressionEqualityInference.Builder(functionManger, typeManager).nonInferrableConjuncts(expression)) {
+ if (determinismEvaluator.isDeterministic(conjunct)) {
+ RowExpression rewritten = equalityInference.rewriteExpression(conjunct, in(variables));
+ if (rewritten != null) {
+ effectiveConjuncts.add(rewritten);
+ }
+ }
+ }
+
+ effectiveConjuncts.addAll(equalityInference.generateEqualitiesPartitionedBy(in(variables)).getScopeEqualities());
+
+ return logicalRowExpressions.combineConjuncts(effectiveConjuncts.build());
+ }
+ }
+}
diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java
index bda32840eb556..d326243774a8e 100644
--- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java
+++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java
@@ -96,6 +96,14 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi
return canonicalizedLeft.equals(canonicalizedRight);
}
+ public boolean areExpressionsEquivalent(RowExpression leftExpression, RowExpression rightExpression)
+ {
+ RowExpression canonicalizedLeft = leftExpression.accept(canonicalizationVisitor, null);
+ RowExpression canonicalizedRight = rightExpression.accept(canonicalizationVisitor, null);
+
+ return canonicalizedLeft.equals(canonicalizedRight);
+ }
+
private RowExpression toRowExpression(Session session, Expression expression, Map variableInput, TypeProvider types)
{
// replace qualified names with input references since row expressions do not support these
diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionPredicateExtractor.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionPredicateExtractor.java
new file mode 100644
index 0000000000000..588ca62b8dc46
--- /dev/null
+++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/TestRowExpressionPredicateExtractor.java
@@ -0,0 +1,799 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.sql.planner;
+
+import com.facebook.presto.expressions.LogicalRowExpressions;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.metadata.MetadataManager;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ConnectorId;
+import com.facebook.presto.spi.TableHandle;
+import com.facebook.presto.spi.block.SortOrder;
+import com.facebook.presto.spi.function.OperatorType;
+import com.facebook.presto.spi.plan.AggregationNode;
+import com.facebook.presto.spi.plan.FilterNode;
+import com.facebook.presto.spi.plan.LimitNode;
+import com.facebook.presto.spi.plan.Ordering;
+import com.facebook.presto.spi.plan.OrderingScheme;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.PlanNodeId;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.plan.TableScanNode;
+import com.facebook.presto.spi.plan.TopNNode;
+import com.facebook.presto.spi.predicate.Domain;
+import com.facebook.presto.spi.predicate.TupleDomain;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.plan.JoinNode;
+import com.facebook.presto.sql.planner.plan.SemiJoinNode;
+import com.facebook.presto.sql.planner.plan.SortNode;
+import com.facebook.presto.sql.planner.plan.UnionNode;
+import com.facebook.presto.sql.planner.plan.WindowNode;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
+import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
+import com.facebook.presto.testing.TestingHandle;
+import com.facebook.presto.testing.TestingMetadata;
+import com.facebook.presto.testing.TestingTransactionHandle;
+import com.google.common.base.Preconditions;
+import com.google.common.base.Predicates;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableListMultimap;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.UUID;
+
+import static com.facebook.presto.expressions.LogicalRowExpressions.FALSE_CONSTANT;
+import static com.facebook.presto.expressions.LogicalRowExpressions.TRUE_CONSTANT;
+import static com.facebook.presto.expressions.LogicalRowExpressions.and;
+import static com.facebook.presto.expressions.LogicalRowExpressions.or;
+import static com.facebook.presto.spi.function.OperatorType.EQUAL;
+import static com.facebook.presto.spi.function.OperatorType.GREATER_THAN;
+import static com.facebook.presto.spi.function.OperatorType.LESS_THAN;
+import static com.facebook.presto.spi.function.OperatorType.LESS_THAN_OR_EQUAL;
+import static com.facebook.presto.spi.plan.AggregationNode.globalAggregation;
+import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet;
+import static com.facebook.presto.spi.plan.LimitNode.Step.FINAL;
+import static com.facebook.presto.spi.relation.SpecialFormExpression.Form.IS_NULL;
+import static com.facebook.presto.spi.type.BigintType.BIGINT;
+import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.spi.type.DoubleType.DOUBLE;
+import static com.facebook.presto.sql.analyzer.TypeSignatureProvider.fromTypes;
+import static com.facebook.presto.sql.planner.RowExpressionEqualityInference.Builder.nonInferrableConjuncts;
+import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment;
+import static com.facebook.presto.sql.planner.optimizations.AggregationNodeUtils.count;
+import static com.facebook.presto.sql.relational.Expressions.call;
+import static com.facebook.presto.sql.relational.Expressions.constant;
+import static com.facebook.presto.sql.relational.Expressions.specialForm;
+import static org.testng.Assert.assertEquals;
+
+public class TestRowExpressionPredicateExtractor
+{
+ private static final TableHandle DUAL_TABLE_HANDLE = new TableHandle(
+ new ConnectorId("test"),
+ new TestingMetadata.TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.empty());
+
+ private static final TableHandle DUAL_TABLE_HANDLE_WITH_LAYOUT = new TableHandle(
+ new ConnectorId("test"),
+ new TestingMetadata.TestingTableHandle(),
+ TestingTransactionHandle.create(),
+ Optional.of(TestingHandle.INSTANCE));
+
+ private static final VariableReferenceExpression AV = new VariableReferenceExpression("a", BIGINT);
+ private static final VariableReferenceExpression BV = new VariableReferenceExpression("b", BIGINT);
+ private static final VariableReferenceExpression CV = new VariableReferenceExpression("c", BIGINT);
+ private static final VariableReferenceExpression DV = new VariableReferenceExpression("d", BIGINT);
+ private static final VariableReferenceExpression EV = new VariableReferenceExpression("e", BIGINT);
+ private static final VariableReferenceExpression FV = new VariableReferenceExpression("f", BIGINT);
+ private static final VariableReferenceExpression GV = new VariableReferenceExpression("g", BIGINT);
+
+ private final Metadata metadata = MetadataManager.createTestMetadataManager();
+ private final LogicalRowExpressions logicalRowExpressions = new LogicalRowExpressions(
+ new RowExpressionDeterminismEvaluator(metadata.getFunctionManager()),
+ new FunctionResolution(metadata.getFunctionManager()),
+ metadata.getFunctionManager());
+ private final RowExpressionPredicateExtractor effectivePredicateExtractor = new RowExpressionPredicateExtractor(
+ new RowExpressionDomainTranslator(metadata),
+ metadata.getFunctionManager(),
+ metadata.getTypeManager());
+
+ private Map scanAssignments;
+ private TableScanNode baseTableScan;
+
+ @BeforeMethod
+ public void setUp()
+ {
+ scanAssignments = ImmutableMap.builder()
+ .put(AV, new TestingMetadata.TestingColumnHandle("a"))
+ .put(BV, new TestingMetadata.TestingColumnHandle("b"))
+ .put(CV, new TestingMetadata.TestingColumnHandle("c"))
+ .put(DV, new TestingMetadata.TestingColumnHandle("d"))
+ .put(EV, new TestingMetadata.TestingColumnHandle("e"))
+ .put(FV, new TestingMetadata.TestingColumnHandle("f"))
+ .build();
+
+ Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV, DV, EV, FV)));
+ baseTableScan = new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE,
+ ImmutableList.copyOf(assignments.keySet()),
+ assignments,
+ TupleDomain.all(),
+ TupleDomain.all());
+ }
+
+ @Test
+ public void testAggregation()
+ {
+ PlanNode node = new AggregationNode(newId(),
+ filter(baseTableScan,
+ and(
+ equals(AV, DV),
+ equals(BV, EV),
+ equals(CV, FV),
+ lessThan(DV, bigintLiteral(10)),
+ lessThan(CV, DV),
+ greaterThan(AV, bigintLiteral(2)),
+ equals(EV, FV))),
+ ImmutableMap.of(
+ CV, count(metadata.getFunctionManager()),
+ DV, count(metadata.getFunctionManager())),
+ singleGroupingSet(ImmutableList.of(AV, BV, CV)),
+ ImmutableList.of(),
+ AggregationNode.Step.FINAL,
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Rewrite in terms of group by symbols
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(
+ lessThan(AV, bigintLiteral(10)),
+ lessThan(BV, AV),
+ greaterThan(AV, bigintLiteral(2)),
+ equals(BV, CV)));
+ }
+
+ @Test
+ public void testGroupByEmpty()
+ {
+ PlanNode node = new AggregationNode(
+ newId(),
+ filter(baseTableScan, FALSE_CONSTANT),
+ ImmutableMap.of(),
+ globalAggregation(),
+ ImmutableList.of(),
+ AggregationNode.Step.FINAL,
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ assertEquals(effectivePredicate, TRUE_CONSTANT);
+ }
+
+ @Test
+ public void testFilter()
+ {
+ PlanNode node = filter(baseTableScan,
+ and(
+ greaterThan(AV, call(metadata.getFunctionManager(), "rand", DOUBLE, ImmutableList.of())),
+ lessThan(BV, bigintLiteral(10))));
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Non-deterministic functions should be purged
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(lessThan(BV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testProject()
+ {
+ PlanNode node = new ProjectNode(newId(),
+ filter(baseTableScan,
+ and(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10)))),
+ assignment(DV, AV, EV, CV));
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Rewrite in terms of project output symbols
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(
+ lessThan(DV, bigintLiteral(10)),
+ equals(DV, EV)));
+ }
+
+ @Test
+ public void testTopN()
+ {
+ PlanNode node = new TopNNode(newId(),
+ filter(baseTableScan,
+ and(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10)))),
+ 1, new OrderingScheme(ImmutableList.of(new Ordering(AV, SortOrder.ASC_NULLS_FIRST))), TopNNode.Step.PARTIAL);
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Pass through
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testLimit()
+ {
+ PlanNode node = new LimitNode(newId(),
+ filter(baseTableScan,
+ and(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10)))),
+ 1,
+ FINAL);
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Pass through
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testSort()
+ {
+ PlanNode node = new SortNode(newId(),
+ filter(baseTableScan,
+ and(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10)))),
+ new OrderingScheme(ImmutableList.of(new Ordering(AV, SortOrder.ASC_NULLS_LAST))),
+ false);
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Pass through
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testWindow()
+ {
+ PlanNode node = new WindowNode(newId(),
+ filter(baseTableScan,
+ and(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10)))),
+ new WindowNode.Specification(
+ ImmutableList.of(AV),
+ Optional.of(new OrderingScheme(
+ ImmutableList.of(new Ordering(AV, SortOrder.ASC_NULLS_LAST))))),
+ ImmutableMap.of(),
+ Optional.empty(),
+ ImmutableSet.of(),
+ 0);
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Pass through
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(
+ equals(AV, BV),
+ equals(BV, CV),
+ lessThan(CV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testTableScan()
+ {
+ // Effective predicate is True if there is no effective predicate
+ Map assignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV, DV)));
+ PlanNode node = new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE,
+ ImmutableList.copyOf(assignments.keySet()),
+ assignments,
+ TupleDomain.all(),
+ TupleDomain.all());
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+ assertEquals(effectivePredicate, TRUE_CONSTANT);
+
+ node = new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE_WITH_LAYOUT,
+ ImmutableList.copyOf(assignments.keySet()),
+ assignments,
+ TupleDomain.none(),
+ TupleDomain.all());
+ effectivePredicate = effectivePredicateExtractor.extract(node);
+ assertEquals(effectivePredicate, FALSE_CONSTANT);
+
+ node = new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE_WITH_LAYOUT,
+ ImmutableList.copyOf(assignments.keySet()),
+ assignments,
+ TupleDomain.withColumnDomains(ImmutableMap.of(scanAssignments.get(AV), Domain.singleValue(BIGINT, 1L))),
+ TupleDomain.all());
+ effectivePredicate = effectivePredicateExtractor.extract(node);
+ assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(1L), AV)));
+
+ node = new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE_WITH_LAYOUT,
+ ImmutableList.copyOf(assignments.keySet()),
+ assignments,
+ TupleDomain.withColumnDomains(ImmutableMap.of(
+ scanAssignments.get(AV), Domain.singleValue(BIGINT, 1L),
+ scanAssignments.get(BV), Domain.singleValue(BIGINT, 2L))),
+ TupleDomain.all());
+ effectivePredicate = effectivePredicateExtractor.extract(node);
+ assertEquals(normalizeConjuncts(effectivePredicate), normalizeConjuncts(equals(bigintLiteral(2L), BV), equals(bigintLiteral(1L), AV)));
+
+ node = new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE,
+ ImmutableList.copyOf(assignments.keySet()),
+ assignments,
+ TupleDomain.all(),
+ TupleDomain.all());
+ effectivePredicate = effectivePredicateExtractor.extract(node);
+ assertEquals(effectivePredicate, TRUE_CONSTANT);
+ }
+
+ @Test
+ public void testUnion()
+ {
+ ImmutableListMultimap variableMapping = ImmutableListMultimap.of(AV, BV, AV, CV, AV, EV);
+ PlanNode node = new UnionNode(newId(),
+ ImmutableList.of(
+ filter(baseTableScan, greaterThan(AV, bigintLiteral(10))),
+ filter(baseTableScan, and(greaterThan(AV, bigintLiteral(10)), lessThan(AV, bigintLiteral(100)))),
+ filter(baseTableScan, and(greaterThan(AV, bigintLiteral(10)), lessThan(AV, bigintLiteral(100))))),
+ variableMapping);
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Only the common conjuncts can be inferred through a Union
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(greaterThan(AV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testInnerJoin()
+ {
+ ImmutableList.Builder criteriaBuilder = ImmutableList.builder();
+ criteriaBuilder.add(new JoinNode.EquiJoinClause(AV, DV));
+ criteriaBuilder.add(new JoinNode.EquiJoinClause(BV, EV));
+ List criteria = criteriaBuilder.build();
+
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ FilterNode left = filter(leftScan,
+ and(
+ lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ equals(GV, bigintLiteral(10))));
+ FilterNode right = filter(rightScan,
+ and(
+ equals(DV, EV),
+ lessThan(FV, bigintLiteral(100))));
+
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.INNER,
+ left,
+ right,
+ criteria,
+ ImmutableList.builder()
+ .addAll(left.getOutputVariables())
+ .addAll(right.getOutputVariables())
+ .build(),
+ Optional.of(lessThanOrEqual(BV, EV)),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // All predicates having output symbol should be carried through
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ equals(DV, EV),
+ lessThan(FV, bigintLiteral(100)),
+ equals(AV, DV),
+ equals(BV, EV),
+ lessThanOrEqual(BV, EV)));
+ }
+
+ @Test
+ public void testInnerJoinPropagatesPredicatesViaEquiConditions()
+ {
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ FilterNode left = filter(leftScan, equals(AV, bigintLiteral(10)));
+
+ // predicates on "a" column should be propagated to output symbols via join equi conditions
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.INNER,
+ left,
+ rightScan,
+ ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV)),
+ ImmutableList.builder()
+ .addAll(rightScan.getOutputVariables())
+ .build(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ assertEquals(
+ normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(equals(DV, bigintLiteral(10))));
+ }
+
+ @Test
+ public void testInnerJoinWithFalseFilter()
+ {
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.INNER,
+ leftScan,
+ rightScan,
+ ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV)),
+ ImmutableList.builder()
+ .addAll(leftScan.getOutputVariables())
+ .addAll(rightScan.getOutputVariables())
+ .build(),
+ Optional.of(FALSE_CONSTANT),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ assertEquals(effectivePredicate, FALSE_CONSTANT);
+ }
+
+ @Test
+ public void testLeftJoin()
+ {
+ ImmutableList.Builder criteriaBuilder = ImmutableList.builder();
+ criteriaBuilder.add(new JoinNode.EquiJoinClause(AV, DV));
+ criteriaBuilder.add(new JoinNode.EquiJoinClause(BV, EV));
+ List criteria = criteriaBuilder.build();
+
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ FilterNode left = filter(leftScan,
+ and(
+ lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ equals(GV, bigintLiteral(10))));
+ FilterNode right = filter(rightScan,
+ and(
+ equals(DV, EV),
+ lessThan(FV, bigintLiteral(100))));
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.LEFT,
+ left,
+ right,
+ criteria,
+ ImmutableList.builder()
+ .addAll(left.getOutputVariables())
+ .addAll(right.getOutputVariables())
+ .build(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // All right side symbols having output symbols should be checked against NULL
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ or(equals(DV, EV), and(isNull(DV), isNull(EV))),
+ or(lessThan(FV, bigintLiteral(100)), isNull(FV)),
+ or(equals(AV, DV), isNull(DV)),
+ or(equals(BV, EV), isNull(EV))));
+ }
+
+ @Test
+ public void testLeftJoinWithFalseInner()
+ {
+ List criteria = ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV));
+
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ FilterNode left = filter(leftScan,
+ and(
+ lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ equals(GV, bigintLiteral(10))));
+ FilterNode right = filter(rightScan, FALSE_CONSTANT);
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.LEFT,
+ left,
+ right,
+ criteria,
+ ImmutableList.builder()
+ .addAll(left.getOutputVariables())
+ .addAll(right.getOutputVariables())
+ .build(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // False literal on the right side should be ignored
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ or(equals(AV, DV), isNull(DV))));
+ }
+
+ @Test
+ public void testRightJoin()
+ {
+ ImmutableList.Builder criteriaBuilder = ImmutableList.builder();
+ criteriaBuilder.add(new JoinNode.EquiJoinClause(AV, DV));
+ criteriaBuilder.add(new JoinNode.EquiJoinClause(BV, EV));
+ List criteria = criteriaBuilder.build();
+
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ FilterNode left = filter(leftScan,
+ and(
+ lessThan(BV, AV),
+ lessThan(CV, bigintLiteral(10)),
+ equals(GV, bigintLiteral(10))));
+ FilterNode right = filter(rightScan,
+ and(
+ equals(DV, EV),
+ lessThan(FV, bigintLiteral(100))));
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.RIGHT,
+ left,
+ right,
+ criteria,
+ ImmutableList.builder()
+ .addAll(left.getOutputVariables())
+ .addAll(right.getOutputVariables())
+ .build(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // All left side symbols should be checked against NULL
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(or(lessThan(BV, AV), and(isNull(BV), isNull(AV))),
+ or(lessThan(CV, bigintLiteral(10)), isNull(CV)),
+ equals(DV, EV),
+ lessThan(FV, bigintLiteral(100)),
+ or(equals(AV, DV), isNull(AV)),
+ or(equals(BV, EV), isNull(BV))));
+ }
+
+ @Test
+ public void testRightJoinWithFalseInner()
+ {
+ List criteria = ImmutableList.of(new JoinNode.EquiJoinClause(AV, DV));
+
+ Map leftAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(AV, BV, CV)));
+ TableScanNode leftScan = tableScanNode(leftAssignments);
+
+ Map rightAssignments = Maps.filterKeys(scanAssignments, Predicates.in(ImmutableList.of(DV, EV, FV)));
+ TableScanNode rightScan = tableScanNode(rightAssignments);
+
+ FilterNode left = filter(leftScan, FALSE_CONSTANT);
+ FilterNode right = filter(rightScan,
+ and(
+ equals(DV, EV),
+ lessThan(FV, bigintLiteral(100))));
+ PlanNode node = new JoinNode(newId(),
+ JoinNode.Type.RIGHT,
+ left,
+ right,
+ criteria,
+ ImmutableList.builder()
+ .addAll(left.getOutputVariables())
+ .addAll(right.getOutputVariables())
+ .build(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // False literal on the left side should be ignored
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(equals(DV, EV),
+ lessThan(FV, bigintLiteral(100)),
+ or(equals(AV, DV), isNull(AV))));
+ }
+
+ @Test
+ public void testSemiJoin()
+ {
+ PlanNode node = new SemiJoinNode(newId(),
+ filter(baseTableScan, and(greaterThan(AV, bigintLiteral(10)), lessThan(AV, bigintLiteral(100)))),
+ filter(baseTableScan, greaterThan(AV, bigintLiteral(5))),
+ AV, BV, CV,
+ Optional.empty(),
+ Optional.empty(),
+ Optional.empty());
+
+ RowExpression effectivePredicate = effectivePredicateExtractor.extract(node);
+
+ // Currently, only pull predicates through the source plan
+ assertEquals(normalizeConjuncts(effectivePredicate),
+ normalizeConjuncts(and(greaterThan(AV, bigintLiteral(10)), lessThan(AV, bigintLiteral(100)))));
+ }
+
+ private static TableScanNode tableScanNode(Map scanAssignments)
+ {
+ return new TableScanNode(
+ newId(),
+ DUAL_TABLE_HANDLE,
+ ImmutableList.copyOf(scanAssignments.keySet()),
+ scanAssignments,
+ TupleDomain.all(),
+ TupleDomain.all());
+ }
+
+ private static PlanNodeId newId()
+ {
+ return new PlanNodeId(UUID.randomUUID().toString());
+ }
+
+ private static FilterNode filter(PlanNode source, RowExpression predicate)
+ {
+ return new FilterNode(newId(), source, predicate);
+ }
+
+ private static RowExpression bigintLiteral(long number)
+ {
+ return constant(number, BIGINT);
+ }
+
+ private RowExpression equals(RowExpression expression1, RowExpression expression2)
+ {
+ return compare(EQUAL, expression1, expression2);
+ }
+
+ private RowExpression lessThan(RowExpression expression1, RowExpression expression2)
+ {
+ return compare(LESS_THAN, expression1, expression2);
+ }
+
+ private RowExpression lessThanOrEqual(RowExpression expression1, RowExpression expression2)
+ {
+ return compare(LESS_THAN_OR_EQUAL, expression1, expression2);
+ }
+
+ private RowExpression greaterThan(RowExpression expression1, RowExpression expression2)
+ {
+ return compare(GREATER_THAN, expression1, expression2);
+ }
+
+ private RowExpression compare(OperatorType type, RowExpression left, RowExpression right)
+ {
+ return call(
+ type.getFunctionName().getSuffix(),
+ metadata.getFunctionManager().resolveOperator(type, fromTypes(left.getType(), right.getType())),
+ BOOLEAN,
+ left,
+ right);
+ }
+
+ private static RowExpression isNull(RowExpression expression)
+ {
+ return specialForm(IS_NULL, BOOLEAN, expression);
+ }
+
+ private Set normalizeConjuncts(RowExpression... conjuncts)
+ {
+ return normalizeConjuncts(Arrays.asList(conjuncts));
+ }
+
+ private Set normalizeConjuncts(Collection conjuncts)
+ {
+ return normalizeConjuncts(logicalRowExpressions.combineConjuncts(conjuncts));
+ }
+
+ private Set normalizeConjuncts(RowExpression predicate)
+ {
+ // Normalize the predicate by identity so that the EqualityInference will produce stable rewrites in this test
+ // and thereby produce comparable Sets of conjuncts from this method.
+
+ // Equality inference rewrites and equality generation will always be stable across multiple runs in the same JVM
+ RowExpressionEqualityInference inference = RowExpressionEqualityInference.createEqualityInference(metadata, predicate);
+
+ Set rewrittenSet = new HashSet<>();
+ for (RowExpression expression : nonInferrableConjuncts(metadata, predicate)) {
+ RowExpression rewritten = inference.rewriteExpression(expression, Predicates.alwaysTrue());
+ Preconditions.checkState(rewritten != null, "Rewrite with full symbol scope should always be possible");
+ rewrittenSet.add(rewritten);
+ }
+ rewrittenSet.addAll(inference.generateEqualitiesPartitionedBy(Predicates.alwaysTrue()).getScopeEqualities());
+
+ return rewrittenSet;
+ }
+}