From 52819a60af85b0bb81b75d31ebe78d5c7ce29201 Mon Sep 17 00:00:00 2001 From: Feilong Liu Date: Mon, 2 Jun 2025 16:16:24 -0700 Subject: [PATCH] Add an optimizer to add distinct below build side of semi join --- .../presto/SystemSessionProperties.java | 12 +- .../presto/sql/analyzer/FeaturesConfig.java | 14 ++ .../presto/sql/planner/PlanOptimizers.java | 7 + .../rule/AddDistinctForSemiJoinBuild.java | 95 +++++++++++ .../sql/analyzer/TestFeaturesConfig.java | 3 + .../rule/TestAddDistinctForSemiJoinBuild.java | 160 ++++++++++++++++++ .../presto/tests/AbstractTestQueries.java | 51 ++++++ 7 files changed, 341 insertions(+), 1 deletion(-) create mode 100644 presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index 4c5627e629ee8..7cf2891231282 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -332,6 +332,7 @@ public final class SystemSessionProperties public static final String ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID = "add_exchange_below_partial_aggregation_over_group_id"; public static final String QUERY_CLIENT_TIMEOUT = "query_client_timeout"; public static final String REWRITE_MIN_MAX_BY_TO_TOP_N = "rewrite_min_max_by_to_top_n"; + public static final String ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD = "add_distinct_below_semi_join_build"; // TODO: Native execution related session properties that are temporarily put here. They will be relocated in the future. public static final String NATIVE_AGGREGATION_SPILL_ALL = "native_aggregation_spill_all"; @@ -1906,7 +1907,11 @@ public SystemSessionProperties( queryManagerConfig.getClientTimeout(), false, value -> Duration.valueOf((String) value), - Duration::toString)); + Duration::toString), + booleanProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, + "Add distinct aggregation below semi join build", + featuresConfig.isAddDistinctBelowSemiJoinBuild(), + false)); } public static boolean isSpoolingOutputBufferEnabled(Session session) @@ -3238,6 +3243,11 @@ public static boolean isEnabledAddExchangeBelowGroupId(Session session) return session.getSystemProperty(ADD_EXCHANGE_BELOW_PARTIAL_AGGREGATION_OVER_GROUP_ID, Boolean.class); } + public static boolean isAddDistinctBelowSemiJoinBuildEnabled(Session session) + { + return session.getSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, Boolean.class); + } + public static boolean isCanonicalizedJsonExtract(Session session) { return session.getSystemProperty(CANONICALIZED_JSON_EXTRACT, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 9cd78bd997f4c..ac1983b87353c 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -303,6 +303,7 @@ public class FeaturesConfig private boolean nativeExecutionTypeRewriteEnabled; private String expressionOptimizerName = DEFAULT_EXPRESSION_OPTIMIZER_NAME; private boolean addExchangeBelowPartialAggregationOverGroupId; + private boolean addDistinctBelowSemiJoinBuild; public enum PartitioningPrecisionStrategy { @@ -3014,4 +3015,17 @@ public boolean getAddExchangeBelowPartialAggregationOverGroupId() { return addExchangeBelowPartialAggregationOverGroupId; } + + @Config("optimizer.add-distinct-below-semi-join-build") + @ConfigDescription("Add a distinct aggregation below build side of semi join") + public FeaturesConfig setAddDistinctBelowSemiJoinBuild(boolean addDistinctBelowSemiJoinBuild) + { + this.addDistinctBelowSemiJoinBuild = addDistinctBelowSemiJoinBuild; + return this; + } + + public boolean isAddDistinctBelowSemiJoinBuild() + { + return addDistinctBelowSemiJoinBuild; + } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index 3dfca1a161312..b017a7e1c8247 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -28,6 +28,7 @@ import com.facebook.presto.sql.planner.iterative.IterativeOptimizer; import com.facebook.presto.sql.planner.iterative.Rule; import com.facebook.presto.sql.planner.iterative.properties.LogicalPropertiesProviderImpl; +import com.facebook.presto.sql.planner.iterative.rule.AddDistinctForSemiJoinBuild; import com.facebook.presto.sql.planner.iterative.rule.AddExchangesBelowPartialAggregationOverGroupIdRuleSet; import com.facebook.presto.sql.planner.iterative.rule.AddIntermediateAggregations; import com.facebook.presto.sql.planner.iterative.rule.AddNotNullFiltersToJoinNode; @@ -593,6 +594,12 @@ public PlanOptimizers( statsCalculator, estimatedExchangesCostCalculator, ImmutableSet.of(new LeftJoinNullFilterToSemiJoin(metadata.getFunctionAndTypeManager()))), + new IterativeOptimizer( + metadata, + ruleStats, + statsCalculator, + estimatedExchangesCostCalculator, + ImmutableSet.of(new AddDistinctForSemiJoinBuild())), new KeyBasedSampler(metadata), new IterativeOptimizer( metadata, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java new file mode 100644 index 0000000000000..fd6a471f7e078 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/AddDistinctForSemiJoinBuild.java @@ -0,0 +1,95 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.spi.plan.PlanNode; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.spi.plan.SemiJoinNode; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.Rule; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.isAddDistinctBelowSemiJoinBuildEnabled; +import static com.facebook.presto.spi.plan.AggregationNode.isDistinct; +import static com.facebook.presto.spi.plan.AggregationNode.singleGroupingSet; +import static com.facebook.presto.sql.planner.plan.Patterns.semiJoin; + +public class AddDistinctForSemiJoinBuild + implements Rule +{ + @Override + public Pattern getPattern() + { + return semiJoin(); + } + + @Override + public boolean isEnabled(Session session) + { + return isAddDistinctBelowSemiJoinBuildEnabled(session); + } + + @Override + public Result apply(SemiJoinNode node, Captures captures, Context context) + { + PlanNode filterSource = context.getLookup().resolve(node.getFilteringSource()); + VariableReferenceExpression filteringSourceVariable = node.getFilteringSourceJoinVariable(); + if (isOutputDistinct(filterSource, filteringSourceVariable, context)) { + return Result.empty(); + } + AggregationNode.GroupingSetDescriptor groupingSetDescriptor = singleGroupingSet(ImmutableList.of(node.getFilteringSourceJoinVariable())); + AggregationNode distinctAggregation = new AggregationNode( + node.getSourceLocation(), + context.getIdAllocator().getNextId(), + filterSource, + ImmutableMap.of(), + groupingSetDescriptor, + ImmutableList.of(), + AggregationNode.Step.SINGLE, + Optional.empty(), + Optional.empty(), + Optional.empty()); + + return Result.ofPlanNode(node.replaceChildren(ImmutableList.of(node.getSource(), distinctAggregation))); + } + + boolean isOutputDistinct(PlanNode node, VariableReferenceExpression output, Context context) + { + if (node instanceof AggregationNode) { + AggregationNode aggregationNode = (AggregationNode) node; + return isDistinct(aggregationNode) && aggregationNode.getGroupingKeys().size() == 1 && aggregationNode.getGroupingKeys().contains(output); + } + else if (node instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) node; + RowExpression inputExpression = projectNode.getAssignments().get(output); + if (inputExpression instanceof VariableReferenceExpression) { + return isOutputDistinct(context.getLookup().resolve(projectNode.getSource()), (VariableReferenceExpression) inputExpression, context); + } + return false; + } + else if (node instanceof FilterNode) { + return isOutputDistinct(context.getLookup().resolve(((FilterNode) node).getSource()), output, context); + } + return false; + } +} diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index ac534e726be16..4cc65d98f5a10 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -257,6 +257,7 @@ public void testDefaults() .setExpressionOptimizerName("default") .setExcludeInvalidWorkerSessionProperties(false) .setAddExchangeBelowPartialAggregationOverGroupId(false) + .setAddDistinctBelowSemiJoinBuild(false) .setInnerJoinPushdownEnabled(false) .setInEqualityJoinPushdownEnabled(false) .setRewriteMinMaxByToTopNEnabled(false) @@ -467,6 +468,7 @@ public void testExplicitPropertyMappings() .put("enhanced-cte-scheduling-enabled", "false") .put("expression-optimizer-name", "custom") .put("exclude-invalid-worker-session-properties", "true") + .put("optimizer.add-distinct-below-semi-join-build", "true") .put("optimizer.add-exchange-below-partial-aggregation-over-group-id", "true") .build(); @@ -670,6 +672,7 @@ public void testExplicitPropertyMappings() .setExpressionOptimizerName("custom") .setExcludeInvalidWorkerSessionProperties(true) .setAddExchangeBelowPartialAggregationOverGroupId(true) + .setAddDistinctBelowSemiJoinBuild(true) .setInEqualityJoinPushdownEnabled(true) .setRewriteMinMaxByToTopNEnabled(true) .setInnerJoinPushdownEnabled(true) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java new file mode 100644 index 0000000000000..da2ead84c3d8b --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestAddDistinctForSemiJoinBuild.java @@ -0,0 +1,160 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static com.facebook.presto.SystemSessionProperties.ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.semiJoin; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; +import static com.facebook.presto.sql.planner.iterative.rule.test.PlanBuilder.assignment; + +public class TestAddDistinctForSemiJoinBuild + extends BaseRuleTest +{ + @Test + public void testTrigger() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.values(filteringSourceJoinVariable)); + }).matches( + semiJoin( + "sourceJoinVariable", + "filteringSourceJoinVariable", + "semiJoinOutput", + values("sourceJoinVariable"), + aggregation( + singleGroupingSet("filteringSourceJoinVariable"), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.SINGLE, + values("filteringSourceJoinVariable")))); + } + + @Test + public void testTriggerOverNonQualifiedDistinctAggregation() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + VariableReferenceExpression col1 = p.variable("col1"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.aggregation((a) -> a + .singleGroupingSet(filteringSourceJoinVariable, col1) + .step(AggregationNode.Step.SINGLE) + .source(p.values(filteringSourceJoinVariable, col1)))); + }).matches( + semiJoin( + "sourceJoinVariable", + "filteringSourceJoinVariable", + "semiJoinOutput", + values("sourceJoinVariable"), + aggregation( + singleGroupingSet("filteringSourceJoinVariable"), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.SINGLE, + aggregation( + singleGroupingSet("filteringSourceJoinVariable", "col1"), + ImmutableMap.of(), + ImmutableMap.of(), + Optional.empty(), + AggregationNode.Step.SINGLE, + values("filteringSourceJoinVariable", "col1"))))); + } + + @Test + public void testNotTriggerOverDistinct() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.aggregation((a) -> a + .singleGroupingSet(filteringSourceJoinVariable) + .step(AggregationNode.Step.SINGLE) + .source(p.values(filteringSourceJoinVariable)))); + }).doesNotFire(); + } + + @Test + public void testNotTriggerOverDistinctUnderProject() + { + tester().assertThat(new AddDistinctForSemiJoinBuild()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .on(p -> + { + VariableReferenceExpression sourceJoinVariable = p.variable("sourceJoinVariable"); + VariableReferenceExpression filteringSourceJoinVariable = p.variable("filteringSourceJoinVariable"); + VariableReferenceExpression semiJoinOutput = p.variable("semiJoinOutput"); + VariableReferenceExpression col1 = p.variable("col1"); + return p.semiJoin( + sourceJoinVariable, + filteringSourceJoinVariable, + semiJoinOutput, + Optional.empty(), + Optional.empty(), + p.values(sourceJoinVariable), + p.project( + assignment(filteringSourceJoinVariable, p.rowExpression("col1")), + p.aggregation((a) -> a + .singleGroupingSet(col1) + .step(AggregationNode.Step.SINGLE) + .source(p.values(col1))))); + }).doesNotFire(); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 87acd583c63a1..f2652436e8ec0 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -51,6 +51,7 @@ import java.util.regex.Pattern; import java.util.stream.IntStream; +import static com.facebook.presto.SystemSessionProperties.ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD; import static com.facebook.presto.SystemSessionProperties.ADD_PARTIAL_NODE_FOR_ROW_NUMBER_WITH_LIMIT; import static com.facebook.presto.SystemSessionProperties.ENABLE_INTERMEDIATE_AGGREGATIONS; import static com.facebook.presto.SystemSessionProperties.FIELD_NAMES_IN_JSON_CAST_ENABLED; @@ -7989,6 +7990,56 @@ public void testRemoveCrossJoinWithSingleRowConstantInput() assertQuery(enableOptimization, "select orderkey, col1 from orders cross join (select cast(col as varchar) col1 from (values 1) t(col))"); } + @Test + public void testAddDistinctForSemiJoinBuild() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "true") + .build(); + Session disabled = Session.builder(getSession()) + .setSystemProperty(ADD_DISTINCT_BELOW_SEMI_JOIN_BUILD, "false") + .build(); + @Language("SQL") String sql = "SELECT * FROM customer c WHERE custkey in ( SELECT custkey FROM orders o WHERE o.orderdate > date('1995-01-01'))"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT * FROM customer c WHERE custkey in ( SELECT distinct custkey FROM orders o WHERE o.orderdate > date('1995-01-01'))"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT *\n" + + "FROM customer c\n" + + "WHERE c.custkey IN (\n" + + " SELECT o.custkey\n" + + " FROM orders o\n" + + " WHERE o.totalprice > 1000\n" + + ")"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT c.name\n" + + "FROM customer c\n" + + "WHERE c.custkey IN (\n" + + " SELECT o.custkey\n" + + " FROM orders o\n" + + ")"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT s.name\n" + + "FROM supplier s\n" + + "WHERE s.suppkey IN (\n" + + " SELECT l.suppkey\n" + + " FROM lineitem l\n" + + ")"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT c.name FROM customer c WHERE c.custkey IN ( SELECT o.custkey FROM orders o JOIN lineitem l ON o.orderkey = l.orderkey WHERE l.partkey > 1235 )"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + sql = "SELECT p.name\n" + + "FROM part p\n" + + "WHERE p.partkey IN (\n" + + " SELECT l.partkey\n" + + " FROM lineitem l\n" + + " JOIN orders o ON l.orderkey = o.orderkey\n" + + " JOIN customer c ON o.custkey = c.custkey\n" + + " JOIN nation n ON c.nationkey = n.nationkey\n" + + " WHERE n.name = 'UNITED STATES'\n" + + ")"; + assertQueryWithSameQueryRunner(enabled, sql, disabled); + } + /** * When optimize_hash_generation is enabled, the "hash_code" operator is used for * hashing join/group by values. When it is disabled Type.hash() is used.