diff --git a/presto-docs/src/main/sphinx/admin/properties-session.rst b/presto-docs/src/main/sphinx/admin/properties-session.rst index 5d1b00792402b..3af8ab1c0d1d7 100644 --- a/presto-docs/src/main/sphinx/admin/properties-session.rst +++ b/presto-docs/src/main/sphinx/admin/properties-session.rst @@ -401,6 +401,22 @@ performance by allowing the aggregation to pre-reduce data before the join is pe The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.push-partial-aggregation-through-join\`\``. +``push_projection_through_cross_join`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +When enabled, pushes projection expressions through cross join nodes so that each +expression is evaluated only on the side of the cross join that provides its input +variables. This reduces the number of columns flowing through the cross join and +avoids recomputing expressions on the multiplied output rows. + +Only deterministic expressions are pushed. Expressions that reference variables from +both sides of the cross join, or constant expressions, remain above the join. + +The corresponding configuration property is :ref:`admin/properties:\`\`optimizer.push-projection-through-cross-join\`\``. + ``push_table_write_through_union`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 274701aade9b0..ff90ef4297004 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -969,6 +969,22 @@ performance by allowing the aggregation to pre-reduce data before the join is pe The corresponding session property is :ref:`admin/properties-session:\`\`push_partial_aggregation_through_join\`\``. 
+``optimizer.push-projection-through-cross-join`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +When enabled, pushes projection expressions through cross join nodes so that each +expression is evaluated only on the side of the cross join that provides its input +variables. This reduces the number of columns flowing through the cross join and +avoids recomputing expressions on the multiplied output rows. + +Only deterministic expressions are pushed. Expressions that reference variables from +both sides of the cross join, or constant expressions, remain above the join. + +The corresponding session property is :ref:`admin/properties-session:\`\`push_projection_through_cross_join\`\``. + ``optimizer.push-table-write-through-union`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/BenchmarkPushProjectionThroughCrossJoin.java b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/BenchmarkPushProjectionThroughCrossJoin.java new file mode 100644 index 0000000000000..51a71f8aab1bb --- /dev/null +++ b/presto-hive/src/test/java/com/facebook/presto/hive/benchmark/BenchmarkPushProjectionThroughCrossJoin.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.facebook.presto.hive.benchmark;
+
+import org.testng.annotations.Test;
+
+/**
+ * Benchmarks the PushProjectionThroughCrossJoin optimization by running one
+ * query under two scenarios: the {@code push_projection_through_cross_join}
+ * session property set to {@code false} (baseline) and to {@code true}.
+ *
+ * <p>Uses a real CROSS JOIN between lineitem and nation so that the
+ * plan produces a JoinNode with isCrossJoin(). CROSS JOIN UNNEST produces
+ * an UnnestNode instead, which this rule does not target.
+ *
+ * <p>Run via:
+ * <pre>
+ * mvn test -pl presto-hive \
+ *   -Dtest=BenchmarkPushProjectionThroughCrossJoin \
+ *   -DfailIfNoTests=false
+ * </pre>
+ */
+public final class BenchmarkPushProjectionThroughCrossJoin
+{
+    // Session property toggled between the two scenarios; also used as the name of the "enabled" scenario.
+    private static final String SESSION_PROPERTY = "push_projection_through_cross_join";
+
+    // Mixes projections over lineitem only, over nation only, and one ("mixed") over both tables.
+    private static final String QUERY = String.join(
+            "",
+            "SELECT ",
+            " regexp_replace(l.comment, '[aeiou]', '*') AS l_redacted, ",
+            " regexp_extract(l.comment, '\\w+') AS l_first_word, ",
+            " upper(reverse(l.shipinstruct)) AS l_instruct, ",
+            " regexp_replace(n.comment, '[aeiou]', '*') AS n_redacted, ",
+            " upper(reverse(n.name)) AS n_reversed, ",
+            " length(l.comment) + n.nationkey AS mixed ",
+            "FROM lineitem l ",
+            "CROSS JOIN nation n");
+
+    /**
+     * Registers the baseline and the optimized scenario, then runs the query
+     * through the runner's verification entry point.
+     */
+    @Test
+    public void benchmark()
+            throws Exception
+    {
+        try (HiveDistributedBenchmarkRunner runner = new HiveDistributedBenchmarkRunner(3, 5)) {
+            registerScenario(runner, "baseline", false);
+            registerScenario(runner, SESSION_PROPERTY, true);
+
+            runner.runWithVerification(QUERY);
+        }
+    }
+
+    /**
+     * Adds a scenario that sets the session property to the given value.
+     */
+    private static void registerScenario(HiveDistributedBenchmarkRunner runner, String scenarioName, boolean pushEnabled)
+            throws Exception
+    {
+        runner.addScenario(scenarioName, builder -> {
+            builder.setSystemProperty(SESSION_PROPERTY, Boolean.toString(pushEnabled));
+        });
+    }
+
+    /**
+     * Allows running the benchmark directly, outside of TestNG.
+     */
+    public static void main(String[] args)
+            throws Exception
+    {
+        new BenchmarkPushProjectionThroughCrossJoin().benchmark();
+    }
+}
diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index 3723bc8e99017..27bb54ebc040f 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -179,6 +179,7 @@ public final class SystemSessionProperties public static final String SIMPLIFY_AGGREGATIONS_OVER_CONSTANT = "simplify_aggregations_over_constant"; public static final String PUSH_PARTIAL_AGGREGATION_THROUGH_JOIN = "push_partial_aggregation_through_join"; public static final String PRE_AGGREGATE_BEFORE_GROUPING_SETS = "pre_aggregate_before_grouping_sets"; + public static final String PUSH_PROJECTION_THROUGH_CROSS_JOIN = "push_projection_through_cross_join"; public
static final String PARSE_DECIMAL_LITERALS_AS_DOUBLE = "parse_decimal_literals_as_double"; public static final String FORCE_SINGLE_NODE_OUTPUT = "force_single_node_output"; public static final String FILTER_AND_PROJECT_MIN_OUTPUT_PAGE_SIZE = "filter_and_project_min_output_page_size"; @@ -961,6 +962,11 @@ public SystemSessionProperties( "Pre-aggregate data before GroupId node to reduce row multiplication in grouping sets queries", featuresConfig.isPreAggregateBeforeGroupingSets(), false), + booleanProperty( + PUSH_PROJECTION_THROUGH_CROSS_JOIN, + "Push projections that reference only one side of a cross join below the join to evaluate on fewer rows", + featuresConfig.isPushProjectionThroughCrossJoin(), + false), booleanProperty( PARSE_DECIMAL_LITERALS_AS_DOUBLE, "Parse decimal literals as DOUBLE instead of DECIMAL", @@ -2689,6 +2695,11 @@ public static boolean isPreAggregateBeforeGroupingSets(Session session) return session.getSystemProperty(PRE_AGGREGATE_BEFORE_GROUPING_SETS, Boolean.class); } + public static boolean isPushProjectionThroughCrossJoin(Session session) + { + return session.getSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, Boolean.class); + } + public static boolean isParseDecimalLiteralsAsDouble(Session session) { return session.getSystemProperty(PARSE_DECIMAL_LITERALS_AS_DOUBLE, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 152be4481c424..312fe74faeebf 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -161,6 +161,7 @@ public class FeaturesConfig private boolean pushdownThroughUnnest; private boolean simplifyAggregationsOverConstant; private boolean preAggregateBeforeGroupingSets; + private boolean pushProjectionThroughCrossJoin; private double 
memoryRevokingTarget = 0.5; private double memoryRevokingThreshold = 0.9; private boolean useMarkDistinct = true; @@ -1757,6 +1758,18 @@ public FeaturesConfig setPreAggregateBeforeGroupingSets(boolean preAggregateBefo return this; } + public boolean isPushProjectionThroughCrossJoin() + { + return pushProjectionThroughCrossJoin; + } + + @Config("optimizer.push-projection-through-cross-join") + public FeaturesConfig setPushProjectionThroughCrossJoin(boolean pushProjectionThroughCrossJoin) + { + this.pushProjectionThroughCrossJoin = pushProjectionThroughCrossJoin; + return this; + } + public boolean isForceSingleNodeOutput() { return forceSingleNodeOutput; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java index ff53bf73799c9..0ed9aa2d82db0 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlanOptimizers.java @@ -106,6 +106,7 @@ import com.facebook.presto.sql.planner.iterative.rule.PushOffsetThroughProject; import com.facebook.presto.sql.planner.iterative.rule.PushPartialAggregationThroughExchange; import com.facebook.presto.sql.planner.iterative.rule.PushPartialAggregationThroughJoin; +import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughCrossJoin; import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughExchange; import com.facebook.presto.sql.planner.iterative.rule.PushProjectionThroughUnion; import com.facebook.presto.sql.planner.iterative.rule.PushRemoteExchangeThroughAssignUniqueId; @@ -362,6 +363,7 @@ public PlanOptimizers( ImmutableSet.of( new PushProjectionThroughUnion(), new PushProjectionThroughExchange(), + new PushProjectionThroughCrossJoin(metadata.getFunctionAndTypeManager()), new PushdownThroughUnnest(metadata.getFunctionAndTypeManager()))); IterativeOptimizer 
simplifyRowExpressionOptimizer = new IterativeOptimizer( diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughCrossJoin.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughCrossJoin.java new file mode 100644 index 0000000000000..5d4671bbb08a4 --- /dev/null +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/PushProjectionThroughCrossJoin.java @@ -0,0 +1,312 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.facebook.presto.sql.planner.iterative.rule;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.matching.Captures;
+import com.facebook.presto.matching.Pattern;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.spi.plan.Assignments;
+import com.facebook.presto.spi.plan.JoinNode;
+import com.facebook.presto.spi.plan.PlanNode;
+import com.facebook.presto.spi.plan.ProjectNode;
+import com.facebook.presto.spi.relation.DeterminismEvaluator;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.VariableReferenceExpression;
+import com.facebook.presto.sql.planner.iterative.Rule;
+import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static com.facebook.presto.SystemSessionProperties.isPushProjectionThroughCrossJoin;
+import static com.facebook.presto.sql.planner.VariablesExtractor.extractUnique;
+import static com.facebook.presto.sql.planner.plan.Patterns.project;
+
+/**
+ * Pushes projections that reference only one side of a cross join
+ * below that join, so that expressions are evaluated on fewer rows.
+ *
+ * <p>Handles cascading projections: walks through intermediate
+ * ProjectNodes between the matched project and the cross join,
+ * pushing single-side assignments from all levels below the join.
+ *
+ * <p>Transforms:
+ * <pre>
+ * Project(a_expr = f(a), b_expr = g(b), mixed = h(a, b))
+ *   CrossJoin
+ *     Left(a)
+ *     Right(b)
+ * </pre>
+ * to:
+ * <pre>
+ * Project(a_expr, b_expr, mixed = h(a, b))
+ *   CrossJoin
+ *     Project(a_expr = f(a), a)
+ *       Left(a)
+ *     Project(b_expr = g(b), b)
+ *       Right(b)
+ * </pre>
+ *
+ * <p>Also handles cascading projections:
+ * <pre>
+ * Project(y = g(x))
+ *   Project(x = f(a), b = b)
+ *     CrossJoin
+ *       Left(a)
+ *       Right(b)
+ * </pre>
+ * to:
+ * <pre>
+ * Project(y = y)
+ *   CrossJoin
+ *     Project(y = g(x))
+ *       Project(x = f(a))
+ *         Left(a)
+ *     Right(b)
+ * </pre>
+ *
+ * <p>Plain variable references (identities and renames) are never pushed.
+ * The output of a rename such as {@code x = a} is defined above the join,
+ * so expressions that reference {@code x} are kept above the join as well.
+ *
+ * <p>Only fires when there is at least one non-identity, deterministic
+ * assignment that exclusively references variables from a single side.
+ */
+public class PushProjectionThroughCrossJoin
+        implements Rule<ProjectNode>
+{
+    private static final Pattern<ProjectNode> PATTERN = project();
+
+    private final DeterminismEvaluator determinismEvaluator;
+
+    public PushProjectionThroughCrossJoin(FunctionAndTypeManager functionAndTypeManager)
+    {
+        this.determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
+    }
+
+    @Override
+    public Pattern<ProjectNode> getPattern()
+    {
+        return PATTERN;
+    }
+
+    @Override
+    public boolean isEnabled(Session session)
+    {
+        return isPushProjectionThroughCrossJoin(session);
+    }
+
+    /**
+     * Rewrites a chain of one or more ProjectNodes sitting directly on a cross join.
+     *
+     * @param project the top-most project of the chain (the node matched by the pattern)
+     * @param captures unused
+     * @param context supplies the lookup used to resolve sources and the plan node id allocator
+     * @return the rewritten plan, or an empty result when the chain does not end in a cross join
+     *         or when no assignment can be pushed
+     */
+    @Override
+    public Result apply(ProjectNode project, Captures captures, Context context)
+    {
+        // Walk through intermediate ProjectNodes to find a CrossJoin
+        List<ProjectNode> chain = new ArrayList<>();
+        chain.add(project);
+        PlanNode current = context.getLookup().resolve(project.getSource());
+        while (current instanceof ProjectNode) {
+            chain.add((ProjectNode) current);
+            current = context.getLookup().resolve(((ProjectNode) current).getSource());
+        }
+
+        if (!(current instanceof JoinNode) || !((JoinNode) current).isCrossJoin()) {
+            return Result.empty();
+        }
+
+        JoinNode crossJoin = (JoinNode) current;
+
+        Set<VariableReferenceExpression> leftVariables = ImmutableSet.copyOf(
+                crossJoin.getLeft().getOutputVariables());
+        Set<VariableReferenceExpression> rightVariables = ImmutableSet.copyOf(
+                crossJoin.getRight().getOutputVariables());
+
+        // Variables that are (or, after the rewrite, will be) produced below the join on each side
+        Set<VariableReferenceExpression> effectiveLeft = new HashSet<>(leftVariables);
+        Set<VariableReferenceExpression> effectiveRight = new HashSet<>(rightVariables);
+
+        // Process levels from bottom (closest to cross join) to top
+        // chain is [top, ..., bottom], so iterate in reverse; the three lists below are
+        // therefore ordered [bottom, ..., top]
+        List<Assignments> leftPushLevels = new ArrayList<>();
+        List<Assignments> rightPushLevels = new ArrayList<>();
+        List<Assignments> residualLevels = new ArrayList<>();
+        boolean anyPushed = false;
+
+        for (int chainIdx = chain.size() - 1; chainIdx >= 0; chainIdx--) {
+            ProjectNode proj = chain.get(chainIdx);
+
+            Assignments.Builder leftPush = Assignments.builder();
+            Assignments.Builder rightPush = Assignments.builder();
+            Assignments.Builder residualBuild = Assignments.builder();
+
+            for (Map.Entry<VariableReferenceExpression, RowExpression> entry : proj.getAssignments().entrySet()) {
+                VariableReferenceExpression outputVar = entry.getKey();
+                RowExpression expression = entry.getValue();
+
+                // Plain variable references stay above the join. For a true identity (v = v)
+                // the side membership of v is already recorded. For a rename (x = a) the
+                // output x is defined only above the join, so it must NOT be added to either
+                // side: otherwise an expression over x from a higher level would be pushed
+                // below the join, where x does not exist.
+                if (expression instanceof VariableReferenceExpression) {
+                    residualBuild.put(outputVar, expression);
+                    continue;
+                }
+
+                // Non-deterministic expressions must not be pushed below the cross join
+                if (!determinismEvaluator.isDeterministic(expression)) {
+                    residualBuild.put(outputVar, expression);
+                    continue;
+                }
+
+                Set<VariableReferenceExpression> referencedVars = extractUnique(expression);
+                boolean refsLeft = referencedVars.stream().anyMatch(effectiveLeft::contains);
+                boolean refsRight = referencedVars.stream().anyMatch(effectiveRight::contains);
+                // If any referenced variable is defined above the cross join (not from either side),
+                // the expression cannot be pushed below
+                boolean refsAbove = referencedVars.stream()
+                        .anyMatch(v -> !effectiveLeft.contains(v) && !effectiveRight.contains(v));
+
+                if (refsAbove) {
+                    residualBuild.put(outputVar, expression);
+                }
+                else if (refsLeft && !refsRight) {
+                    leftPush.put(outputVar, expression);
+                    residualBuild.put(outputVar, outputVar);
+                    effectiveLeft.add(outputVar);
+                    anyPushed = true;
+                }
+                else if (refsRight && !refsLeft) {
+                    rightPush.put(outputVar, expression);
+                    residualBuild.put(outputVar, outputVar);
+                    effectiveRight.add(outputVar);
+                    anyPushed = true;
+                }
+                else {
+                    // References both sides or is a constant: keep above
+                    residualBuild.put(outputVar, expression);
+                }
+            }
+
+            leftPushLevels.add(leftPush.build());
+            rightPushLevels.add(rightPush.build());
+            residualLevels.add(residualBuild.build());
+        }
+
+        if (!anyPushed) {
+            return Result.empty();
+        }
+
+        // Stack the pushed projections on each input of the cross join
+        PlanNode newLeft = stackPushedProjections(crossJoin.getLeft(), leftPushLevels, context);
+        PlanNode newRight = stackPushedProjections(crossJoin.getRight(), rightPushLevels, context);
+
+        // Build new cross join
+        ImmutableList.Builder<VariableReferenceExpression> newJoinOutputs = ImmutableList.builder();
+        newJoinOutputs.addAll(newLeft.getOutputVariables());
+        newJoinOutputs.addAll(newRight.getOutputVariables());
+
+        JoinNode newCrossJoin = new JoinNode(
+                crossJoin.getSourceLocation(),
+                context.getIdAllocator().getNextId(),
+                crossJoin.getType(),
+                newLeft,
+                newRight,
+                crossJoin.getCriteria(),
+                newJoinOutputs.build(),
+                crossJoin.getFilter(),
+                crossJoin.getLeftHashVariable(),
+                crossJoin.getRightHashVariable(),
+                crossJoin.getDistributionType(),
+                crossJoin.getDynamicFilters());
+
+        // Build residual chain above cross join
+        List<Set<VariableReferenceExpression>> neededAbove = variablesNeededAbove(residualLevels);
+        PlanNode result = newCrossJoin;
+        for (int i = 0; i < residualLevels.size(); i++) {
+            boolean isTopLevel = (i == residualLevels.size() - 1);
+
+            // An intermediate level must forward every variable that a higher level reads and
+            // that now originates below it (i.e. was pushed under the join from a higher level).
+            // The top level is kept as is so that the output of the rewritten plan is unchanged.
+            Assignments res = isTopLevel
+                    ? residualLevels.get(i)
+                    : withPassThroughs(residualLevels.get(i), result.getOutputVariables(), neededAbove.get(i));
+
+            // Skip intermediate levels that are all identity (no-op pass-through)
+            if (!isTopLevel && isAllIdentity(res)) {
+                continue;
+            }
+
+            result = new ProjectNode(
+                    project.getSourceLocation(),
+                    context.getIdAllocator().getNextId(),
+                    result,
+                    res,
+                    project.getLocality());
+        }
+
+        return Result.ofPlanNode(result);
+    }
+
+    /**
+     * Stacks one ProjectNode per non-empty level on top of {@code source}. Each project
+     * evaluates the pushed assignments of its level and passes through every variable of
+     * its own source that the level does not redefine.
+     *
+     * @param pushLevels pushed assignments ordered from the level closest to the join upwards
+     */
+    private static PlanNode stackPushedProjections(PlanNode source, List<Assignments> pushLevels, Context context)
+    {
+        PlanNode result = source;
+        for (Assignments pushed : pushLevels) {
+            if (pushed.isEmpty()) {
+                continue;
+            }
+            Assignments.Builder all = Assignments.builder();
+            all.putAll(pushed);
+            for (VariableReferenceExpression v : result.getOutputVariables()) {
+                if (!pushed.getMap().containsKey(v)) {
+                    all.put(v, v);
+                }
+            }
+            result = new ProjectNode(
+                    result.getSourceLocation(),
+                    context.getIdAllocator().getNextId(),
+                    result,
+                    all.build(),
+                    ProjectNode.Locality.LOCAL);
+        }
+        return result;
+    }
+
+    /**
+     * For every residual level, collects the variables referenced by the expressions of all
+     * residual levels strictly above it.
+     *
+     * @param residualLevels residual assignments ordered bottom to top
+     * @return a list parallel to {@code residualLevels}; the entry of the top level is empty
+     */
+    private static List<Set<VariableReferenceExpression>> variablesNeededAbove(List<Assignments> residualLevels)
+    {
+        List<Set<VariableReferenceExpression>> neededAbove = new ArrayList<>();
+        Set<VariableReferenceExpression> accumulated = new HashSet<>();
+        // Walk top-down; inserting at the front keeps the result aligned with residualLevels.
+        // The chain is short, so the quadratic insertion cost is irrelevant.
+        for (int i = residualLevels.size() - 1; i >= 0; i--) {
+            neededAbove.add(0, ImmutableSet.copyOf(accumulated));
+            for (RowExpression expression : residualLevels.get(i).getExpressions()) {
+                accumulated.addAll(extractUnique(expression));
+            }
+        }
+        return neededAbove;
+    }
+
+    /**
+     * Returns {@code residual} extended with an identity assignment for every variable that is
+     * available from the source, is needed by a higher level, and is not already assigned.
+     */
+    private static Assignments withPassThroughs(
+            Assignments residual,
+            List<VariableReferenceExpression> available,
+            Set<VariableReferenceExpression> neededAbove)
+    {
+        Assignments.Builder builder = Assignments.builder();
+        builder.putAll(residual);
+        for (VariableReferenceExpression variable : available) {
+            if (neededAbove.contains(variable) && !residual.getMap().containsKey(variable)) {
+                builder.put(variable, variable);
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Returns true when every assignment maps a variable to itself.
+     */
+    private static boolean isAllIdentity(Assignments assignments)
+    {
+        return assignments.entrySet().stream().allMatch(
+                e -> e.getValue() instanceof VariableReferenceExpression
+                        && e.getValue().equals(e.getKey()));
+    }
+}
diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index b71ffb0988527..da3ebe7b54621 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -142,6 +142,7 @@ public void testDefaults() .setPushdownThroughUnnest(false) .setSimplifyAggregationsOverConstant(false) .setPreAggregateBeforeGroupingSets(false) + .setPushProjectionThroughCrossJoin(false) .setForceSingleNodeOutput(true) .setPagesIndexEagerCompactionEnabled(false) .setFilterAndProjectMinOutputPageSize(new DataSize(500, KILOBYTE)) @@ -368,6 +369,7 @@ public void testExplicitPropertyMappings() .put("optimizer.pushdown-through-unnest", "true")
.put("optimizer.simplify-aggregations-over-constant", "true") .put("optimizer.pre-aggregate-before-grouping-sets", "true") + .put("optimizer.push-projection-through-cross-join", "true") .put("optimizer.aggregation-partition-merging", "top_down") .put("optimizer.local-exchange-parent-preference-strategy", "automatic") .put("experimental.spill-enabled", "true") @@ -605,6 +607,7 @@ public void testExplicitPropertyMappings() .setPushdownThroughUnnest(true) .setSimplifyAggregationsOverConstant(true) .setPreAggregateBeforeGroupingSets(true) + .setPushProjectionThroughCrossJoin(true) .setSpillEnabled(true) .setJoinSpillingEnabled(false) .setSpillerSpillPaths("/tmp/custom/spill/path1,/tmp/custom/spill/path2") diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughCrossJoin.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughCrossJoin.java new file mode 100644 index 0000000000000..d78b689809caa --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestPushProjectionThroughCrossJoin.java @@ -0,0 +1,507 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.spi.plan.Assignments; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.facebook.presto.sql.planner.iterative.rule.test.BaseRuleTest; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.PUSH_PROJECTION_THROUGH_CROSS_JOIN; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.join; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.project; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.values; + +public class TestPushProjectionThroughCrossJoin + extends BaseRuleTest +{ + @Test + public void testPushLeftOnlyProjection() + { + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + .put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .matches( + project( + ImmutableMap.of( + "a_plus_1", expression("a_plus_1"), + "b", expression("b")), + join( + project( + ImmutableMap.of("a_plus_1", expression("a + BIGINT '1'")), + values("a")), + values("b")))); + } + + @Test + public void testPushRightOnlyProjection() + { + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + 
.on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(a, a) + .put(p.variable("b_plus_1", BIGINT), p.rowExpression("b + BIGINT '1'")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .matches( + project( + ImmutableMap.of( + "a", expression("a"), + "b_plus_1", expression("b_plus_1")), + join( + values("a"), + project( + ImmutableMap.of("b_plus_1", expression("b + BIGINT '1'")), + values("b"))))); + } + + @Test + public void testPushBothSides() + { + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + .put(p.variable("b_plus_1", BIGINT), p.rowExpression("b + BIGINT '1'")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .matches( + project( + ImmutableMap.of( + "a_plus_1", expression("a_plus_1"), + "b_plus_1", expression("b_plus_1")), + join( + project( + ImmutableMap.of("a_plus_1", expression("a + BIGINT '1'")), + values("a")), + project( + ImmutableMap.of("b_plus_1", expression("b + BIGINT '1'")), + values("b"))))); + } + + @Test + public void testDoesNotFireOnMixedProjections() + { + // All projections reference both sides — nothing to push + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_b", BIGINT), p.rowExpression("a + b")) + .build(), + p.join( 
+ JoinType.INNER, + p.values(a), + p.values(b))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnConstantProjection() + { + // Constant expressions don't reference either side — nothing to push + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("const_val", BIGINT), p.rowExpression("BIGINT '42'")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnIdentityOnly() + { + // All projections are identity — nothing to push + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(a, a) + .put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireWhenDisabled() + { + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "false") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + .put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotFireOnNonCrossJoin() + { + // Join with criteria is not a cross join + tester().assertThat(new 
PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + .put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b), + new com.facebook.presto.spi.plan.EquiJoinClause(a, b))); + }) + .doesNotFire(); + } + + @Test + public void testMixedPushableAndNonPushable() + { + // One left-only, one mixed — only the left-only should be pushed + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + .put(p.variable("a_plus_b", BIGINT), p.rowExpression("a + b")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .matches( + project( + ImmutableMap.of( + "a_plus_1", expression("a_plus_1"), + "a_plus_b", expression("a + b")), + join( + project( + ImmutableMap.of( + "a_plus_1", expression("a + BIGINT '1'"), + "a", expression("a")), + values("a")), + values("b")))); + } + + @Test + public void testDoesNotFireOnJoinWithFilter() + { + // Inner join with a filter (no equi-join keys) is not a cross join + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + 
.put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b), + p.rowExpression("a > b"))); + }) + .doesNotFire(); + } + + @Test + public void testDoesNotPushNonDeterministicExpression() + { + // random() is non-deterministic — pushing it below the cross join would change + // semantics (computed once then replicated vs computed per output row) + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("rand_val", DOUBLE), p.rowExpression("random()")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .doesNotFire(); + } + + @Test + public void testPushesDeterministicButKeepsNonDeterministic() + { + // a + 1 is deterministic and references only left — should be pushed + // random() is non-deterministic — should stay above + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("a_plus_1", BIGINT), p.rowExpression("a + BIGINT '1'")) + .put(p.variable("rand_val", DOUBLE), p.rowExpression("random()")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b))); + }) + .matches( + project( + ImmutableMap.of( + "a_plus_1", expression("a_plus_1"), + "rand_val", expression("random()")), + join( + project( + ImmutableMap.of("a_plus_1", expression("a + BIGINT '1'")), + values("a")), + values("b")))); + } + + @Test + public void testCascadingProjectionsBothPushLeft() + { + // Project(y = x + 1) -> Project(x = a + 1, b = b) -> CrossJoin + // Both x = a + 1 and y = 
x + 1 are transitively left-only + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("y", BIGINT), p.rowExpression("x + BIGINT '1'")) + .build(), + p.project( + Assignments.builder() + .put(x, p.rowExpression("a + BIGINT '1'")) + .put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b)))); + }) + .matches( + project( + ImmutableMap.of("y", expression("y")), + join( + project( + ImmutableMap.of("y", expression("x + BIGINT '1'")), + project( + ImmutableMap.of("x", expression("a + BIGINT '1'")), + values("a"))), + values("b")))); + } + + @Test + public void testCascadingProjectionsPushBothSides() + { + // Intermediate pushes to left and right, top pushes transitively to both sides + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression y = p.variable("y", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("x_plus_1", BIGINT), p.rowExpression("x + BIGINT '1'")) + .put(p.variable("y_plus_1", BIGINT), p.rowExpression("y + BIGINT '1'")) + .build(), + p.project( + Assignments.builder() + .put(x, p.rowExpression("a + BIGINT '1'")) + .put(y, p.rowExpression("b + BIGINT '1'")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b)))); + }) + .matches( + project( + ImmutableMap.of( + "x_plus_1", expression("x_plus_1"), + "y_plus_1", expression("y_plus_1")), + join( + project( 
+ ImmutableMap.of("x_plus_1", expression("x + BIGINT '1'")), + project( + ImmutableMap.of("x", expression("a + BIGINT '1'")), + values("a"))), + project( + ImmutableMap.of("y_plus_1", expression("y + BIGINT '1'")), + project( + ImmutableMap.of("y", expression("b + BIGINT '1'")), + values("b")))))); + } + + @Test + public void testCascadingProjectionsIntermediateMixedKeepsAbove() + { + // Intermediate has left-only (x = a + 1) and mixed (w = a + b) assignments. + // Top references w which is defined above cross join, so it stays above. + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + VariableReferenceExpression x = p.variable("x", BIGINT); + VariableReferenceExpression w = p.variable("w", BIGINT); + + return p.project( + Assignments.builder() + .put(p.variable("result", BIGINT), p.rowExpression("x + w")) + .build(), + p.project( + Assignments.builder() + .put(x, p.rowExpression("a + BIGINT '1'")) + .put(w, p.rowExpression("a + b")) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b)))); + }) + .matches( + project( + ImmutableMap.of("result", expression("x + w")), + project( + ImmutableMap.of( + "x", expression("x"), + "w", expression("a + b")), + join( + project( + ImmutableMap.of("x", expression("a + BIGINT '1'")), + values("a")), + values("b"))))); + } + + @Test + public void testCascadingDoesNotFireOnAllIdentity() + { + // All cascading projects are identity — nothing to push + tester().assertThat(new PushProjectionThroughCrossJoin(getFunctionManager())) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .on(p -> { + VariableReferenceExpression a = p.variable("a", BIGINT); + VariableReferenceExpression b = p.variable("b", BIGINT); + + return p.project( + Assignments.builder() + .put(a, a) + .put(b, b) + 
.build(), + p.project( + Assignments.builder() + .put(a, a) + .put(b, b) + .build(), + p.join( + JoinType.INNER, + p.values(a), + p.values(b)))); + }) + .doesNotFire(); + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index f90a9eb7380f7..0280253afcac9 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -76,6 +76,7 @@ import static com.facebook.presto.SystemSessionProperties.PRE_PROCESS_METADATA_CALLS; import static com.facebook.presto.SystemSessionProperties.PULL_EXPRESSION_FROM_LAMBDA_ENABLED; import static com.facebook.presto.SystemSessionProperties.PUSH_DOWN_FILTER_EXPRESSION_EVALUATION_THROUGH_CROSS_JOIN; +import static com.facebook.presto.SystemSessionProperties.PUSH_PROJECTION_THROUGH_CROSS_JOIN; import static com.facebook.presto.SystemSessionProperties.PUSH_REMOTE_EXCHANGE_THROUGH_GROUP_ID; import static com.facebook.presto.SystemSessionProperties.QUICK_DISTINCT_LIMIT_ENABLED; import static com.facebook.presto.SystemSessionProperties.RANDOMIZE_NULL_SOURCE_KEY_IN_SEMI_JOIN_STRATEGY; @@ -1501,6 +1502,34 @@ public void testGroupingSetsWithPreAggregation() } } + @Test + public void testPushProjectionThroughCrossJoin() + { + Session enabled = Session.builder(getSession()) + .setSystemProperty(PUSH_PROJECTION_THROUGH_CROSS_JOIN, "true") + .build(); + Session disabled = getSession(); + + // Use real CROSS JOINs (JoinNode with empty equi-join criteria). + // CROSS JOIN UNNEST produces an UnnestNode, not a JoinNode. 
+ String[] queries = { + // Left-only projections pushed below cross join + "SELECT n.nationkey * 2, r.regionkey FROM nation n CROSS JOIN region r", + // Right-only projection + "SELECT n.nationkey, r.regionkey * 10 FROM nation n CROSS JOIN region r", + // Both sides have pushable projections + "SELECT length(n.name), r.regionkey * 10 FROM nation n CROSS JOIN region r", + // Mixed: some pushable, some not + "SELECT n.nationkey * 2, n.nationkey + r.regionkey, r.name FROM nation n CROSS JOIN region r", + // Multiple expressions per side + "SELECT n.nationkey * 2, length(n.name), r.regionkey * 10, length(r.name) FROM nation n CROSS JOIN region r", + }; + + for (String query : queries) { + assertQueryWithSameQueryRunner(enabled, query, disabled); + } + } + @Test public void testGroupingWithFortyArguments() {