diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index 6d59895ce5cd..f694f4fce205 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -37,6 +37,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -120,6 +121,9 @@ public PlanWithConsumedDynamicFilters visitJoin(JoinNode node, Set allow PlanNode left = leftResult.getNode(); PlanNode right = rightResult.getNode(); if (!left.equals(node.getLeft()) || !right.equals(node.getRight()) || !dynamicFilters.equals(node.getDynamicFilters())) { + Optional filter = node + .getFilter().map(this::removeAllDynamicFilters) // no DF support at Join operators. + .filter(expression -> !expression.equals(TRUE_LITERAL)); return new PlanWithConsumedDynamicFilters(new JoinNode( node.getId(), node.getType(), @@ -127,7 +131,7 @@ public PlanWithConsumedDynamicFilters visitJoin(JoinNode node, Set allow right, node.getCriteria(), node.getOutputSymbols(), - node.getFilter(), + filter, node.getLeftHashSymbol(), node.getRightHashSymbol(), node.getDistributionType(), diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java index 789d26644f20..59c89f069f7d 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/sanity/DynamicFiltersChecker.java @@ -13,6 +13,7 @@ */ package io.prestosql.sql.planner.sanity; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import io.prestosql.Session; import io.prestosql.execution.warnings.WarningCollector; @@ -72,12 +73,21 @@ public Set visitJoin(JoinNode node, Void context) { Set currentJoinDynamicFilters = node.getDynamicFilters().keySet(); Set consumedProbeSide = node.getLeft().accept(this, context); - verify(difference(currentJoinDynamicFilters, consumedProbeSide).isEmpty(), - "Dynamic filters present in join were not fully consumed by it's probe side."); + Set unconsumedByProbeSide = difference(currentJoinDynamicFilters, consumedProbeSide); + verify(unconsumedByProbeSide.isEmpty(), + "Dynamic filters %s present in join were not fully consumed by it's probe side.", unconsumedByProbeSide); Set consumedBuildSide = node.getRight().accept(this, context); - verify(intersection(currentJoinDynamicFilters, consumedBuildSide).isEmpty(), - "Dynamic filters present in join were consumed by it's build side."); + Set unconsumedByBuildSide = intersection(currentJoinDynamicFilters, consumedBuildSide); + verify(unconsumedByBuildSide.isEmpty(), + "Dynamic filters %s present in join were consumed by it's build side.", unconsumedByBuildSide); + + List nonPushedDownFilters = node + .getFilter() + .map(DynamicFilters::extractDynamicFilters) + .map(DynamicFilters.ExtractResult::getDynamicConjuncts) + .orElse(ImmutableList.of()); + verify(nonPushedDownFilters.isEmpty(), "Dynamic filters %s present in join filter predicate were not pushed down.", nonPushedDownFilters); Set unmatched = new HashSet<>(consumedBuildSide); unmatched.addAll(consumedProbeSide); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java index 7f6d5a509866..f007328af249 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestDynamicFilter.java @@ -35,6 +35,7 @@ import static io.prestosql.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.equiJoinClause; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.exchange; +import static io.prestosql.sql.planner.assertions.PlanMatchPattern.expression; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.filter; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.join; import static io.prestosql.sql.planner.assertions.PlanMatchPattern.node; @@ -259,6 +260,34 @@ public void testNestedDynamicFiltersRemoval() tableScan("orders", ImmutableMap.of("ORDERS_CK27", "clerk"))))), metadata))); } + @Test + public void testNonPushedDownJoinFilterRemoval() + { + assertPlan( + "SELECT 1 FROM part t0, part t1, part t2 " + + "WHERE t0.partkey = t1.partkey AND t0.partkey = t2.partkey " + + "AND t0.size + t1.size = t2.size", + noJoinReordering(), + anyTree( + join(INNER, + ImmutableList.of(equiJoinClause("K0", "K2"), equiJoinClause("S", "V2")), + project( + project(ImmutableMap.of("S", expression("V0 + V1")), + join( + INNER, + ImmutableList.of(equiJoinClause("K0", "K1")), + project( + node(FilterNode.class, + tableScan("part", ImmutableMap.of("K0", "partkey", "V0", "size")))), + exchange( + project( + node(FilterNode.class, + tableScan("part", ImmutableMap.of("K1", "partkey", "V1", "size")))))))), + exchange( + project( + tableScan("part", ImmutableMap.of("K2", "partkey", "V2", "size"))))))); + } + private Session noJoinReordering() { return Session.builder(getQueryRunner().getDefaultSession()) diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestDynamicFiltersChecker.java b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestDynamicFiltersChecker.java index c68ae573d0e5..e39395faf63c 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestDynamicFiltersChecker.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/sanity/TestDynamicFiltersChecker.java @@ -77,7 +77,7 @@ public void setup() ordersTableScanNode = builder.tableScan(ordersTableHandle, ImmutableList.of(ordersOrderKeySymbol), ImmutableMap.of(ordersOrderKeySymbol, new TpchColumnHandle("orderkey", BIGINT))); } - @Test(expectedExceptions = VerifyException.class, expectedExceptionsMessageRegExp = "Dynamic filters present in join were not fully consumed by it's probe side.") + @Test(expectedExceptions = VerifyException.class, expectedExceptionsMessageRegExp = "Dynamic filters \\[DF\\] present in join were not fully consumed by it's probe side.") public void testUnconsumedDynamicFilterInJoin() { PlanNode root = builder.join( @@ -93,7 +93,7 @@ public void testUnconsumedDynamicFilterInJoin() validatePlan(root); } - @Test(expectedExceptions = VerifyException.class, expectedExceptionsMessageRegExp = "Dynamic filters present in join were consumed by it's build side.") + @Test(expectedExceptions = VerifyException.class, expectedExceptionsMessageRegExp = "Dynamic filters \\[DF\\] present in join were consumed by it's build side.") public void testDynamicFilterConsumedOnBuildSide() { PlanNode root = builder.join( diff --git a/presto-memory/src/test/java/io/prestosql/plugin/memory/TestMemorySmoke.java b/presto-memory/src/test/java/io/prestosql/plugin/memory/TestMemorySmoke.java index f91e1c64117d..8a8455f9d6bd 100644 --- a/presto-memory/src/test/java/io/prestosql/plugin/memory/TestMemorySmoke.java +++ b/presto-memory/src/test/java/io/prestosql/plugin/memory/TestMemorySmoke.java @@ -33,6 +33,7 @@ import static com.google.common.collect.ImmutableSet.toImmutableSet; import static io.prestosql.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING; import static io.prestosql.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; +import static io.prestosql.SystemSessionProperties.JOIN_REORDERING_STRATEGY; import static io.prestosql.testing.assertions.Assert.assertEquals; import static java.lang.String.format; import static org.testng.Assert.assertTrue; @@ -136,6 +137,25 @@ public void testJoinDynamicFilteringSingleValue() assertEquals(rowsRead, ImmutableSet.of(6L, buildSideRowsCount)); } + @Test + public void testJoinDynamicFilteringMultiJoin() + { + assertUpdate("CREATE TABLE t0 (k0 integer, v0 real)"); + assertUpdate("CREATE TABLE t1 (k1 integer, v1 real)"); + assertUpdate("CREATE TABLE t2 (k2 integer, v2 real)"); + assertUpdate("INSERT INTO t0 VALUES (1, 1.0)", 1); + assertUpdate("INSERT INTO t1 VALUES (1, 2.0)", 1); + assertUpdate("INSERT INTO t2 VALUES (1, 3.0)", 1); + + String query = "SELECT k0, k1, k2 FROM t0, t1, t2 WHERE (k0 = k1) AND (k0 = k2) AND (v0 + v1 = v2)"; + Session session = Session.builder(getSession()) + .setSystemProperty(ENABLE_DYNAMIC_FILTERING, "true") + .setSystemProperty(JOIN_DISTRIBUTION_TYPE, FeaturesConfig.JoinDistributionType.BROADCAST.name()) + .setSystemProperty(JOIN_REORDERING_STRATEGY, FeaturesConfig.JoinReorderingStrategy.NONE.name()) + .build(); + assertQuery(session, query, "SELECT 1, 1, 1"); + } + @Test public void testCreateTableWithNoData() {