diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/LogicalPropertiesImpl.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/LogicalPropertiesImpl.java index a699a35133fec..01336be7a3a60 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/LogicalPropertiesImpl.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/properties/LogicalPropertiesImpl.java @@ -190,13 +190,7 @@ private boolean keyRequirementSatisfied(Key keyRequirement) if (maxCardProperty.isAtMostOne()) { return true; } - Optional normalizedKeyRequirement = getNormalizedKey(keyRequirement, equivalenceClassProperty); - if (normalizedKeyRequirement.isPresent()) { - return keyProperty.satisfiesKeyRequirement(keyRequirement); - } - else { - return false; - } + return getNormalizedKey(keyRequirement, equivalenceClassProperty).filter(keyProperty::satisfiesKeyRequirement).isPresent(); } @Override diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLogicalPropertyPropagation.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLogicalPropertyPropagation.java index 3759e5018adef..44a7aa7ffe9fd 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLogicalPropertyPropagation.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestLogicalPropertyPropagation.java @@ -22,6 +22,7 @@ import com.facebook.presto.spi.constraints.PrimaryKeyConstraint; import com.facebook.presto.spi.constraints.TableConstraint; import com.facebook.presto.spi.constraints.UniqueConstraint; +import com.facebook.presto.spi.plan.AggregationNode; import com.facebook.presto.spi.plan.Assignments; import com.facebook.presto.spi.plan.EquiJoinClause; import com.facebook.presto.spi.plan.FilterNode; @@ -71,6 +72,7 @@ import static com.facebook.presto.sql.relational.Expressions.constant; import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Collections.emptyList; +import static org.testng.Assert.assertTrue; public class TestLogicalPropertyPropagation extends BaseRuleTest @@ -200,6 +202,58 @@ void testValuesNodeLogicalProperties() .matches(expectedLogicalProperties); } + @Test + public void testKeyNormalization() + { + tester().assertThat(new NoOpRule(), logicalPropertiesProvider) + .on(p -> { + TableScanNode customerTableScan = p.tableScan( + customerTableHandle, + ImmutableList.of(customerCustKeyVariable), + ImmutableMap.of(customerCustKeyVariable, customerCustKeyColumn), + TupleDomain.none(), + TupleDomain.none(), + tester().getTableConstraints(customerTableHandle)); + + TableScanNode ordersTableScan = p.tableScan( + ordersTableHandle, + ImmutableList.of(ordersCustKeyVariable), + ImmutableMap.of(ordersCustKeyVariable, ordersCustKeyColumn), + TupleDomain.none(), + TupleDomain.none(), + tester().getTableConstraints(ordersTableHandle)); + + TableScanNode lineitemTableScan = p.tableScan( + lineitemTableHandle, + ImmutableList.of(lineitemOrderkeyVariable), + ImmutableMap.of(lineitemOrderkeyVariable, lineitemOrderkeyColumn), + TupleDomain.none(), + TupleDomain.none(), + tester().getTableConstraints(lineitemTableHandle)); + + JoinNode ordersCustomerJoin = p.join(JoinType.INNER, + ordersTableScan, + customerTableScan, + new EquiJoinClause(ordersCustKeyVariable, customerCustKeyVariable)); + + AggregationNode aggregation = p.aggregation(builder -> builder + .singleGroupingSet(ordersCustKeyVariable) + .source(p.join(JoinType.INNER, + ordersCustomerJoin, + lineitemTableScan, + new EquiJoinClause(customerCustKeyVariable, lineitemOrderkeyVariable)))); + return aggregation; + }).assertLogicalProperties(groupProperties -> { + // SINGLE aggregation on ordersCustKeyVariable => this is a key + assertTrue(groupProperties.isDistinct(ImmutableSet.of(ordersCustKeyVariable))); + // Since ordersCustKeyVariable == customerCustKeyVariable, customerCustKeyVariable is a key as well + // This is derived through the equivalence classes + assertTrue(groupProperties.isDistinct(ImmutableSet.of(customerCustKeyVariable))); + // Same holds true for lineitemOrderkeyVariable + assertTrue(groupProperties.isDistinct(ImmutableSet.of(lineitemOrderkeyVariable))); + }); + } + @Test public void testTableScanNodeLogicalProperties() { diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java index 8bfca678a940d..34cecae8d8c23 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/test/RuleAssert.java @@ -50,6 +50,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Stream; @@ -58,6 +59,7 @@ import static com.facebook.presto.sql.planner.planPrinter.PlanPrinter.textLogicalPlan; import static com.facebook.presto.transaction.TransactionBuilder.transaction; import static com.google.common.base.Preconditions.checkState; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static org.testng.Assert.fail; @@ -131,7 +133,7 @@ public PlanNode get() TypeProvider types = ruleApplication.types; if (!ruleApplication.wasRuleApplied()) { - fail(String.format( + fail(format( "%s did not fire for:\n%s", rule.getClass().getName(), formatPlan(plan, types))); @@ -145,7 +147,7 @@ public void doesNotFire() RuleApplication ruleApplication = applyRule(); if (ruleApplication.wasRuleApplied()) { - fail(String.format( + fail(format( "Expected %s to not fire for:\n%s", rule.getClass().getName(), inTransaction(session -> textLogicalPlan(plan, ruleApplication.types, StatsAndCosts.empty(), metadata.getFunctionAndTypeManager(), session, 2)))); @@ -158,7 +160,7 @@ public void matches(PlanMatchPattern pattern) TypeProvider types = ruleApplication.types; if (!ruleApplication.wasRuleApplied()) { - fail(String.format( + fail(format( "%s did not fire for:\n%s", rule.getClass().getName(), formatPlan(plan, types))); @@ -167,14 +169,14 @@ public void matches(PlanMatchPattern pattern) PlanNode actual = ruleApplication.getTransformedPlan(); if (actual == plan) { // plans are not comparable, so we can only ensure they are not the same instance - fail(String.format( + fail(format( "%s: rule fired but return the original plan:\n%s", rule.getClass().getName(), formatPlan(plan, types))); } if (!ImmutableSet.copyOf(plan.getOutputVariables()).equals(ImmutableSet.copyOf(actual.getOutputVariables()))) { - fail(String.format( + fail(format( "%s: output schema of transformed and original plans are not equivalent\n" + "\texpected: %s\n" + "\tactual: %s", @@ -189,28 +191,35 @@ public void matches(PlanMatchPattern pattern) }); } - public void matches(LogicalProperties expectedLogicalProperties) + public void assertLogicalProperties(Consumer matcher) { RuleApplication ruleApplication = applyRule(); TypeProvider types = ruleApplication.types; if (!ruleApplication.wasRuleApplied()) { - fail(String.format( + fail(format( "%s did not fire for:\n%s", rule.getClass().getName(), formatPlan(plan, types))); } - // ensure that the logical properties of the root group are equivalent to the expected logical properties LogicalProperties rootNodeLogicalProperties = ruleApplication.getMemo().getLogicalProperties(ruleApplication.getMemo().getRootGroup()).get(); - if (!((LogicalPropertiesImpl) rootNodeLogicalProperties).equals((LogicalPropertiesImpl) expectedLogicalProperties)) { - fail(String.format( - "Logical properties of root node doesn't match expected logical properties\n" + - "\texpected: %s\n" + - "\tactual: %s", - expectedLogicalProperties, - rootNodeLogicalProperties)); - } + matcher.accept(rootNodeLogicalProperties); + } + + public void matches(LogicalProperties expectedLogicalProperties) + { + // Ensure that the logical properties of the root group are equivalent to the expected logical properties + assertLogicalProperties(rootNodeLogicalProperties -> { + if (!((LogicalPropertiesImpl) rootNodeLogicalProperties).equals((LogicalPropertiesImpl) expectedLogicalProperties)) { + fail(format( + "Logical properties of root node doesn't match expected logical properties\n" + + "\texpected: %s\n" + + "\tactual: %s", + expectedLogicalProperties, + rootNodeLogicalProperties)); + } + }); } private RuleApplication applyRule()