From 347142ee5a2250efd2fdb7695df0c8f726a9e34f Mon Sep 17 00:00:00 2001 From: Sreeni Viswanadha Date: Wed, 25 Mar 2026 10:54:05 -0700 Subject: [PATCH] feat(optimizer): Enhance PayloadJoinOptimizer with null-check skipping, chain flattening, and LOJ reordering --- .../optimizations/PayloadJoinOptimizer.java | 355 +++++++++++++++++- .../TestPayloadJoinOptimizer.java | 339 +++++++++++++++++ .../tests/AbstractTestDistributedQueries.java | 146 +++++++ 3 files changed, 826 insertions(+), 14 deletions(-) create mode 100644 presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPayloadJoinOptimizer.java diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PayloadJoinOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PayloadJoinOptimizer.java index aaf9c088a48e9..27786a1be44d1 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PayloadJoinOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PayloadJoinOptimizer.java @@ -31,17 +31,21 @@ import com.facebook.presto.spi.plan.PlanNodeIdAllocator; import com.facebook.presto.spi.plan.ProjectNode; import com.facebook.presto.spi.plan.TableScanNode; +import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.planner.TypeProvider; import com.facebook.presto.sql.planner.VariablesExtractor; import com.facebook.presto.sql.planner.plan.SimplePlanRewriter; import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slices; +import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -139,9 +143,16 @@ public PlanOptimizerResult optimize(PlanNode plan, Session session, TypeProvider { FunctionAndTypeManager functionAndTypeManager = metadata.getFunctionAndTypeManager(); if (isEnabled(session)) { + PlanNode flattenedPlan = flattenJoinChains(plan, idAllocator); Rewriter rewriter = new PayloadJoinOptimizer.Rewriter(session, this.metadata, types, functionAndTypeManager, idAllocator, variableAllocator); - PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, plan, new JoinContext()); - return PlanOptimizerResult.optimizerResult(rewrittenPlan, rewriter.isPlanChanged()); + PlanNode rewrittenPlan = SimplePlanRewriter.rewriteWith(rewriter, flattenedPlan, new JoinContext()); + if (rewriter.isPlanChanged()) { + return PlanOptimizerResult.optimizerResult(rewrittenPlan, true); + } + // Pre-pass may have restructured the plan, but if the main rewrite + // didn't fire, return the original plan to avoid plan changes that + // don't produce the payload join optimization. + return PlanOptimizerResult.optimizerResult(plan, false); } return PlanOptimizerResult.optimizerResult(plan, false); } @@ -356,6 +367,10 @@ private PlanNode rewriteScanFilterProject(PlanNode planNode, RewriteContext nonNullVars = extractNonNullVariablesFromScanFilterProject(planNode, joinKeys); + context.get().addNonNullKeys(nonNullVars); + List outputCols = planNode.getOutputVariables(); if (!ImmutableSet.copyOf(planNode.getOutputVariables()).containsAll(joinKeys)) { // not all join keys are in the plan node: check if there are any pushable projections @@ -406,6 +421,7 @@ private PlanNode transformJoin(JoinNode keysNode, JoinContext context) Set joinKeys = context.getJoinKeys(); Map joinKeyMap = context.getJoinKeyMap(); + Set nonNullKeys = context.getNonNullKeys(); checkState(null != payloadPlanNode, "Payload plannode not initialized"); checkState(null != joinKeyMap, "joinkey map not initialized"); @@ -415,8 +431,7 @@ private PlanNode transformJoin(JoinNode keysNode, JoinContext context) // build new assignments of the form "jk IS NULL as jk_NULL" Assignments.Builder assignments = Assignments.builder(); - ImmutableList.Builder coalesceComparisonBuilder = ImmutableList.builder(); - ImmutableList.Builder nullComparisonBuilder = ImmutableList.builder(); + ImmutableList.Builder joinPredicateBuilder = ImmutableList.builder(); List joinOutputCols = keysNode.getOutputVariables(); @@ -427,25 +442,43 @@ private PlanNode transformJoin(JoinNode keysNode, JoinContext context) for (VariableReferenceExpression var : joinKeys) { VariableReferenceExpression newVar = joinKeyMap.get(var); - VariableReferenceExpression isNullVar = variableAllocator.newVariable(var.getName() + "_NULL", BOOLEAN); - assignments.put(isNullVar, specialForm(IS_NULL, BOOLEAN, ImmutableList.of(var))); - - // construct predicate of the form "coalesce(newVar, 0) = coalesce(var, 0)" - RowExpression coalesceComp = equalityPredicate(functionResolution, coalesceToZero(newVar), coalesceToZero(var)); - RowExpression nullComp = equalityPredicate(functionResolution, specialForm(IS_NULL, BOOLEAN, ImmutableList.of(newVar)), isNullVar); - nullComparisonBuilder.add(nullComp); - coalesceComparisonBuilder.add(coalesceComp); + if (nonNullKeys.contains(var)) { + // Key is guaranteed non-null: use direct equality + joinPredicateBuilder.add(equalityPredicate(functionResolution, newVar, var)); + } + else { + // Key may be null: use IS_NULL comparison + COALESCE comparison + VariableReferenceExpression isNullVar = variableAllocator.newVariable(var.getName() + "_NULL", BOOLEAN); + assignments.put(isNullVar, specialForm(IS_NULL, BOOLEAN, ImmutableList.of(var))); + + RowExpression coalesceComp = equalityPredicate(functionResolution, coalesceToZero(newVar), coalesceToZero(var)); + RowExpression nullComp = equalityPredicate(functionResolution, specialForm(IS_NULL, BOOLEAN, ImmutableList.of(newVar)), isNullVar); + joinPredicateBuilder.add(nullComp); + joinPredicateBuilder.add(coalesceComp); + } } ProjectNode projectNode = new ProjectNode(planNodeIdAllocator.getNextId(), keysNode, assignments.build()); List resultOutputCols = Stream.concat(payloadPlanNode.getOutputVariables().stream(), projectNode.getOutputVariables().stream()).collect(toImmutableList()); - List joinCriteria = Stream.concat(nullComparisonBuilder.build().stream(), coalesceComparisonBuilder.build().stream()).collect(toImmutableList()); + List joinCriteria = joinPredicateBuilder.build(); + + // If all keys are non-null and all key expressions are deterministic, + // use INNER join (every payload row matches exactly one distinct key). + // Non-deterministic keys (e.g., random()) are computed separately in the + // cloned payload and distinct-keys subtrees, so values may differ and + // an INNER join could incorrectly drop rows. + RowExpressionDeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager); + Map projectionsToPush = context.getProjectionsToPush(); + boolean allKeysDeterministic = joinKeys.stream() + .allMatch(key -> !projectionsToPush.containsKey(key) || determinismEvaluator.isDeterministic(projectionsToPush.get(key))); + boolean allKeysNonNull = nonNullKeys.containsAll(joinKeys); + JoinType rejoinType = (allKeysNonNull && allKeysDeterministic) ? JoinType.INNER : JoinType.LEFT; return new JoinNode( keysNode.getSourceLocation(), planNodeIdAllocator.getNextId(), - JoinType.LEFT, + rejoinType, payloadPlanNode, projectNode, ImmutableList.of(), @@ -510,6 +543,44 @@ private boolean supportedJoinKeyTypes(Set joinKeys) { return joinKeys.stream().allMatch(key -> key.getType() instanceof VarcharType || isNumericType(key.getType())); } + + private Set extractNonNullVariablesFromScanFilterProject(PlanNode node, Set joinKeys) + { + FunctionResolution functionResolution = new FunctionResolution(functionAndTypeManager.getFunctionAndTypeResolver()); + ImmutableSet.Builder nonNullVars = ImmutableSet.builder(); + extractNonNullVariablesRecursive(node, joinKeys, functionResolution, nonNullVars); + return nonNullVars.build(); + } + + private void extractNonNullVariablesRecursive(PlanNode node, Set joinKeys, + FunctionResolution functionResolution, ImmutableSet.Builder result) + { + if (node instanceof FilterNode) { + FilterNode filterNode = (FilterNode) node; + RowExpression predicate = filterNode.getPredicate(); + for (RowExpression conjunct : LogicalRowExpressions.extractConjuncts(predicate)) { + if (conjunct instanceof CallExpression) { + CallExpression call = (CallExpression) conjunct; + if (functionResolution.isNotFunction(call.getFunctionHandle()) + && call.getArguments().size() == 1 + && call.getArguments().get(0) instanceof SpecialFormExpression + && ((SpecialFormExpression) call.getArguments().get(0)).getForm() == IS_NULL + && ((SpecialFormExpression) call.getArguments().get(0)).getArguments().size() == 1 + && ((SpecialFormExpression) call.getArguments().get(0)).getArguments().get(0) instanceof VariableReferenceExpression) { + VariableReferenceExpression variable = (VariableReferenceExpression) ((SpecialFormExpression) call.getArguments().get(0)).getArguments().get(0); + if (joinKeys.contains(variable)) { + result.add(variable); + } + } + } + } + extractNonNullVariablesRecursive(filterNode.getSource(), joinKeys, functionResolution, result); + } + else if (node instanceof ProjectNode) { + extractNonNullVariablesRecursive(((ProjectNode) node).getSource(), joinKeys, functionResolution, result); + } + // TableScanNode is a leaf — nothing to do + } } private static RowExpression zeroForType(Type type) @@ -522,11 +593,256 @@ private static RowExpression zeroForType(Type type) return constant(Slices.utf8Slice(""), VarcharType.VARCHAR); } + /** + * Pre-pass: flatten LOJ chains by removing identity projections and hoisting cross joins + * above the LOJ chain. This allows the payload join optimization to handle more join patterns. + */ + private static PlanNode flattenJoinChains(PlanNode node, PlanNodeIdAllocator idAllocator) + { + List children = node.getSources(); + ImmutableList.Builder newChildrenBuilder = ImmutableList.builder(); + boolean childChanged = false; + for (PlanNode child : children) { + PlanNode newChild = flattenJoinChains(child, idAllocator); + if (newChild != child) { + childChanged = true; + } + newChildrenBuilder.add(newChild); + } + + PlanNode current = childChanged ? replaceChildren(node, newChildrenBuilder.build()) : node; + + if (current instanceof JoinNode && ((JoinNode) current).getType() == LEFT) { + PlanNode flattened = flattenLeftChain((JoinNode) current, idAllocator); + if (flattened instanceof JoinNode && ((JoinNode) flattened).getType() == LEFT) { + return reorderLeftJoinChain((JoinNode) flattened, idAllocator); + } + return flattened; + } + + return current; + } + + private static PlanNode flattenLeftChain(JoinNode joinNode, PlanNodeIdAllocator idAllocator) + { + PlanNode left = joinNode.getLeft(); + + // Case 1: Left child is an identity projection - remove it + if (left instanceof ProjectNode && isIdentityProjection((ProjectNode) left)) { + PlanNode projectSource = ((ProjectNode) left).getSource(); + List newOutput = Stream.concat( + projectSource.getOutputVariables().stream(), + joinNode.getRight().getOutputVariables().stream()) + .collect(toImmutableList()); + + JoinNode newJoin = new JoinNode( + joinNode.getSourceLocation(), + idAllocator.getNextId(), + joinNode.getType(), + projectSource, + joinNode.getRight(), + joinNode.getCriteria(), + newOutput, + joinNode.getFilter(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), + joinNode.getDistributionType(), + joinNode.getDynamicFilters()); + + return flattenLeftChain(newJoin, idAllocator); + } + + // Case 2: Left child is a cross join - hoist it above the LOJ + if (left instanceof JoinNode && ((JoinNode) left).isCrossJoin()) { + JoinNode crossJoin = (JoinNode) left; + + Set lojLeftKeys = extractLeftJoinKeys(joinNode); + Set crossLeftCols = ImmutableSet.copyOf(crossJoin.getLeft().getOutputVariables()); + Set crossRightCols = ImmutableSet.copyOf(crossJoin.getRight().getOutputVariables()); + + PlanNode chainSide = null; + PlanNode crossSide = null; + + if (crossLeftCols.containsAll(lojLeftKeys)) { + chainSide = crossJoin.getLeft(); + crossSide = crossJoin.getRight(); + } + else if (crossRightCols.containsAll(lojLeftKeys)) { + chainSide = crossJoin.getRight(); + crossSide = crossJoin.getLeft(); + } + + if (chainSide != null) { + List lojOutput = Stream.concat( + chainSide.getOutputVariables().stream(), + joinNode.getRight().getOutputVariables().stream()) + .collect(toImmutableList()); + + JoinNode newLOJ = new JoinNode( + joinNode.getSourceLocation(), + idAllocator.getNextId(), + joinNode.getType(), + chainSide, + joinNode.getRight(), + joinNode.getCriteria(), + lojOutput, + joinNode.getFilter(), + joinNode.getLeftHashVariable(), + joinNode.getRightHashVariable(), + joinNode.getDistributionType(), + joinNode.getDynamicFilters()); + + PlanNode flattenedLOJ = flattenLeftChain(newLOJ, idAllocator); + + List crossOutput = Stream.concat( + flattenedLOJ.getOutputVariables().stream(), + crossSide.getOutputVariables().stream()) + .collect(toImmutableList()); + + return new JoinNode( + crossJoin.getSourceLocation(), + idAllocator.getNextId(), + crossJoin.getType(), + flattenedLOJ, + crossSide, + crossJoin.getCriteria(), + crossOutput, + crossJoin.getFilter(), + crossJoin.getLeftHashVariable(), + crossJoin.getRightHashVariable(), + crossJoin.getDistributionType(), + crossJoin.getDynamicFilters()); + } + } + + return joinNode; + } + + private static Set extractLeftJoinKeys(JoinNode joinNode) + { + ImmutableSet.Builder builder = ImmutableSet.builder(); + + for (EquiJoinClause clause : joinNode.getCriteria()) { + builder.add(clause.getLeft()); + } + + if (joinNode.getFilter().isPresent()) { + Set rightCols = ImmutableSet.copyOf(joinNode.getRight().getOutputVariables()); + for (VariableReferenceExpression var : VariablesExtractor.extractAll(joinNode.getFilter().get())) { + if (!rightCols.contains(var)) { + builder.add(var); + } + } + } + + return builder.build(); + } + + private static boolean isIdentityProjection(ProjectNode project) + { + return project.getAssignments().entrySet().stream() + .allMatch(entry -> entry.getValue().equals(entry.getKey())); + } + + /** + * Reorder LOJs in a chain so that base-keyed LOJs (keys from the base table) come first, + * and dependent LOJs (keys from other LOJ results) are pushed to the top. This maximizes + * the number of LOJs that the payload join optimization can handle. + */ + private static PlanNode reorderLeftJoinChain(JoinNode topJoin, PlanNodeIdAllocator idAllocator) + { + // Collect all LOJs in the chain (top to bottom order) + List joins = new ArrayList<>(); + PlanNode current = topJoin; + while (current instanceof JoinNode && ((JoinNode) current).getType() == LEFT) { + joins.add((JoinNode) current); + current = ((JoinNode) current).getLeft(); + } + PlanNode baseNode = current; + + if (joins.size() < 3) { + return topJoin; + } + + Set baseColumns = ImmutableSet.copyOf(baseNode.getOutputVariables()); + + // Classify in bottom-to-top order: base-keyed vs dependent + List baseKeyed = new ArrayList<>(); + List dependent = new ArrayList<>(); + boolean needsReorder = false; + boolean seenDependent = false; + + for (int i = joins.size() - 1; i >= 0; i--) { + JoinNode j = joins.get(i); + Set leftKeys = extractLeftJoinKeys(j); + if (baseColumns.containsAll(leftKeys)) { + baseKeyed.add(j); + if (seenDependent) { + needsReorder = true; + } + } + else { + dependent.add(j); + seenDependent = true; + } + } + + if (!needsReorder || baseKeyed.size() < 2) { + return topJoin; + } + + // Rebuild: base -> baseKeyed LOJs -> dependent LOJs + PlanNode result = baseNode; + for (JoinNode j : baseKeyed) { + List newOutput = Stream.concat( + result.getOutputVariables().stream(), + j.getRight().getOutputVariables().stream()) + .collect(toImmutableList()); + + result = new JoinNode( + j.getSourceLocation(), + idAllocator.getNextId(), + j.getType(), + result, + j.getRight(), + j.getCriteria(), + newOutput, + j.getFilter(), + j.getLeftHashVariable(), + j.getRightHashVariable(), + j.getDistributionType(), + j.getDynamicFilters()); + } + for (JoinNode j : dependent) { + List newOutput = Stream.concat( + result.getOutputVariables().stream(), + j.getRight().getOutputVariables().stream()) + .collect(toImmutableList()); + + result = new JoinNode( + j.getSourceLocation(), + idAllocator.getNextId(), + j.getType(), + result, + j.getRight(), + j.getCriteria(), + newOutput, + j.getFilter(), + j.getLeftHashVariable(), + j.getRightHashVariable(), + j.getDistributionType(), + j.getDynamicFilters()); + } + + return result; + } + private static class JoinContext { private Set joinKeys = new HashSet<>(); private Map joinKeyMap; private Map projectionsToPush = new HashMap<>(); + private Set nonNullKeys = new HashSet<>(); int numJoins; PlanNode payloadNode; @@ -580,6 +896,7 @@ public void reset() joinKeyMap = null; numJoins = 0; payloadNode = null; + nonNullKeys = new HashSet<>(); } public int getNumJoins() @@ -596,5 +913,15 @@ public boolean needsPayloadRejoin() { return payloadNode != null; } + + public Set getNonNullKeys() + { + return nonNullKeys; + } + + public void addNonNullKeys(Set keys) + { + nonNullKeys.addAll(keys); + } } } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPayloadJoinOptimizer.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPayloadJoinOptimizer.java new file mode 100644 index 0000000000000..a888736f09956 --- /dev/null +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestPayloadJoinOptimizer.java @@ -0,0 +1,339 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.optimizations; + +import com.facebook.presto.Session; +import com.facebook.presto.spi.plan.AggregationNode; +import com.facebook.presto.spi.plan.JoinNode; +import com.facebook.presto.spi.plan.JoinType; +import com.facebook.presto.spi.plan.ProjectNode; +import com.facebook.presto.sql.planner.Plan; +import com.facebook.presto.sql.planner.assertions.BasePlanTest; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_PAYLOAD_JOINS; +import static com.facebook.presto.SystemSessionProperties.REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN; +import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED; +import static com.facebook.presto.sql.planner.optimizations.PlanNodeSearcher.searchFrom; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestPayloadJoinOptimizer + extends BasePlanTest +{ + public TestPayloadJoinOptimizer() + { + super(ImmutableMap.of( + REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false")); + } + + private Session optimizedSession() + { + return Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "true") + .build(); + } + + private Session unoptimizedSession() + { + return Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "false") + .build(); + } + + @Test + public void testBasicPayloadJoinRewrite() + { + // A chain of 2+ LOJs should produce an AggregationNode (DISTINCT keys) + String sql = "SELECT l.* FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertFalse(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Optimized 2-LOJ chain should contain AggregationNode for DISTINCT keys"); + + Plan unoptimizedPlan = plan(sql, OPTIMIZED, true, unoptimizedSession()); + assertTrue(searchFrom(unoptimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Unoptimized plan should not contain AggregationNode"); + } + + @Test + public void testSingleLeftJoinNotRewritten() + { + // A single LOJ should NOT trigger the payload join optimization (requires 2+) + String sql = "SELECT l.* FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertTrue(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Single LOJ should not trigger payload join rewrite"); + } + + @Test + public void testNonNullKeysReduceProjections() + { + // When join keys have IS NOT NULL predicates, the rejoin uses direct equality + // instead of IS_NULL + COALESCE pairs, producing fewer ProjectNodes. + String sqlNonNull = "SELECT l.* FROM " + + "(SELECT * FROM lineitem WHERE orderkey IS NOT NULL AND partkey IS NOT NULL) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + String sqlNullable = "SELECT l.* FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan nonNullPlan = plan(sqlNonNull, OPTIMIZED, true, optimizedSession()); + Plan nullablePlan = plan(sqlNullable, OPTIMIZED, true, optimizedSession()); + + // Both should be rewritten + assertFalse(searchFrom(nonNullPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Non-null key query should be rewritten"); + assertFalse(searchFrom(nullablePlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Nullable key query should be rewritten"); + + // Non-null plan should have fewer ProjectNodes (no IS_NULL projections) + int nonNullProjects = searchFrom(nonNullPlan.getRoot()) + .where(n -> n instanceof ProjectNode) + .findAll().size(); + int nullableProjects = searchFrom(nullablePlan.getRoot()) + .where(n -> n instanceof ProjectNode) + .findAll().size(); + assertTrue(nonNullProjects < nullableProjects, + "Non-null plan should have fewer ProjectNodes. " + + "Non-null: " + nonNullProjects + ", Nullable: " + nullableProjects); + } + + @Test + public void testPartialNonNullKeys() + { + // Only one key is non-null — should still be rewritten, with mixed predicate styles + String sql = "SELECT l.* FROM " + + "(SELECT * FROM lineitem WHERE orderkey IS NOT NULL) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertFalse(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Partial non-null key query should be rewritten"); + } + + @Test + public void testInterveningIdentityProjection() + { + // Subquery wrapping generates an identity ProjectNode between LOJs. + // The pre-pass should remove it so the full chain is optimized. + String sql = "SELECT sub.*, s.name as s_name FROM " + + "(SELECT l.orderkey, l.partkey, l.suppkey, o.orderstatus " + + "FROM lineitem l LEFT JOIN orders o ON l.orderkey = o.orderkey) sub " + + "LEFT JOIN supplier s ON sub.suppkey = s.suppkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertFalse(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Chain with identity projection should be rewritten"); + } + + @Test + public void testCrossJoinHoisted() + { + // Cross join between LOJs should be hoisted above the chain + String sql = "SELECT l.orderkey, l.partkey, o.orderstatus, n.name as nation_name, p.brand " + + "FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "CROSS JOIN (SELECT name FROM nation WHERE name = 'JAPAN') n " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertFalse(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Chain with cross join should be rewritten after hoisting"); + } + + @Test + public void testAllBaseKeyedJoinsOptimized() + { + // When all LOJs in the chain are base-keyed, the full chain is optimized + String sql = "SELECT l.orderkey, l.partkey, l.suppkey, " + + "o.orderstatus, p.brand, s.name " + + "FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey " + + "LEFT JOIN supplier s ON l.suppkey = s.suppkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertFalse(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "All-base-keyed 3-LOJ chain should be rewritten"); + } + + @Test + public void testDependentJoinKeyFromRightSideAbortsRewrite() + { + // When a join key comes from the right side of a prior LOJ (dependent LOJ), + // the optimizer correctly handles the chain. The 2-LOJ chain with a dependent + // second LOJ may or may not be optimized depending on key provenance. + String sql = "SELECT l.orderkey, o.orderstatus, s.name " + + "FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN supplier s ON o.shippriority = s.suppkey"; + + // This query has 2 LOJs but the second one uses o.shippriority (from orders RHS). + // The optimizer aborts when it detects collected join keys in the RHS. + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + Plan unoptimizedPlan = plan(sql, OPTIMIZED, true, unoptimizedSession()); + + // Both plans should have the same number of AggregationNodes (none), + // confirming the optimizer correctly handles this case without crashing + assertEquals( + searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().size(), + searchFrom(unoptimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().size(), + "Dependent join key chain should produce same aggregation structure"); + } + + @Test + public void testNonIdentityProjectionPreservesChain() + { + // Non-identity projection computing a join key should not break the chain + String sql = "SELECT sub.*, s.name as s_name FROM " + + "(SELECT l.orderkey, l.partkey, l.suppkey + 0 as sk, o.orderstatus " + + "FROM lineitem l LEFT JOIN orders o ON l.orderkey = o.orderkey) sub " + + "LEFT JOIN supplier s ON sub.sk = s.suppkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + assertFalse(searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(), + "Non-identity projection with computed join key should be rewritten"); + } + + @Test + public void testAllNonNullKeysUseInnerJoinRejoin() + { + // When all join keys are guaranteed non-null, the rejoin should use INNER join + String sql = "SELECT l.* FROM " + + "(SELECT * FROM lineitem WHERE orderkey IS NOT NULL AND partkey IS NOT NULL) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + + // The rejoin JoinNode should be INNER (payload is left side, so it's the top-most join + // that has a child AggregationNode on the right side) + boolean hasInnerRejoin = searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof JoinNode && ((JoinNode) n).getType() == JoinType.INNER) + .findAll().stream() + .anyMatch(n -> !searchFrom(((JoinNode) n).getRight()) + .where(c -> c instanceof AggregationNode) + .findAll().isEmpty()); + assertTrue(hasInnerRejoin, + "All non-null keys should produce INNER join rejoin"); + } + + @Test + public void testNullableKeysUseLeftJoinRejoin() + { + // When keys are nullable, the rejoin should remain LEFT join + String sql = "SELECT l.* FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + + // The rejoin JoinNode should be LEFT + boolean hasLeftRejoin = searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof JoinNode && ((JoinNode) n).getType() == JoinType.LEFT) + .findAll().stream() + .anyMatch(n -> !searchFrom(((JoinNode) n).getRight()) + .where(c -> c instanceof AggregationNode) + .findAll().isEmpty()); + assertTrue(hasLeftRejoin, + "Nullable keys should produce LEFT join rejoin"); + } + + @Test + public void testPartialNonNullKeysUseLeftJoinRejoin() + { + // When only some keys are non-null, the rejoin should remain LEFT join + String sql = "SELECT l.* FROM " + + "(SELECT * FROM lineitem WHERE orderkey IS NOT NULL) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + + // The rejoin should still be LEFT since partkey might be null + boolean hasLeftRejoin = searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof JoinNode && ((JoinNode) n).getType() == JoinType.LEFT) + .findAll().stream() + .anyMatch(n -> !searchFrom(((JoinNode) n).getRight()) + .where(c -> c instanceof AggregationNode) + .findAll().isEmpty()); + assertTrue(hasLeftRejoin, + "Partial non-null keys should produce LEFT join rejoin"); + } + + @Test + public void testNonDeterministicKeysUseLeftJoinRejoin() + { + // Even if keys are non-null, non-deterministic computed keys should use LEFT join + // because the cloned payload and distinct-keys subtrees would compute different values + String sql = "SELECT sub.* FROM " + + "(SELECT l.orderkey, l.partkey, CAST(random() * 100 AS BIGINT) + l.suppkey as rk " + + "FROM lineitem l WHERE orderkey IS NOT NULL AND partkey IS NOT NULL) sub " + + "LEFT JOIN orders o ON sub.orderkey = o.orderkey " + + "LEFT JOIN supplier s ON sub.rk = s.suppkey"; + + Plan optimizedPlan = plan(sql, OPTIMIZED, true, optimizedSession()); + + // If optimized, the rejoin should be LEFT (not INNER) due to non-deterministic key + boolean hasAggregation = !searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof AggregationNode) + .findAll().isEmpty(); + + if (hasAggregation) { + // Only check rejoin type if the optimizer actually fired + boolean hasInnerRejoin = searchFrom(optimizedPlan.getRoot()) + .where(n -> n instanceof JoinNode && ((JoinNode) n).getType() == JoinType.INNER) + .findAll().stream() + .anyMatch(n -> !searchFrom(((JoinNode) n).getRight()) + .where(c -> c instanceof AggregationNode) + .findAll().isEmpty()); + assertFalse(hasInnerRejoin, + "Non-deterministic key should prevent INNER join rejoin"); + } + } +} diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java index 0711400559056..3e6900c6db00d 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestDistributedQueries.java @@ -1470,6 +1470,152 @@ public void testPayloadJoinCorrectness() } } + @Test + public void testPayloadJoinSkipsNullChecksForNonNullKeys() + { + Session sessionNoOpt = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "false") + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false") + .build(); + + Session session = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "true") + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false") + .build(); + + // Queries with WHERE key IS NOT NULL on the base table — the rejoin + // predicate should use direct equality instead of IS_NULL + COALESCE + String[] queries = { + // Both join keys are non-null + "SELECT l.* FROM (select * from lineitem where orderkey IS NOT NULL AND partkey IS NOT NULL) l left join orders o on (l.orderkey = o.orderkey) left join part p on (l.partkey=p.partkey)", + // Only one join key is non-null + "SELECT l.* FROM (select * from lineitem where orderkey IS NOT NULL) l left join orders o on (l.orderkey = o.orderkey) left join part p on (l.partkey=p.partkey)", + // IS NOT NULL combined with other filter predicates + "SELECT l.* FROM (select * from lineitem where orderkey IS NOT NULL AND partkey IS NOT NULL AND quantity > 1) l left join orders o on (l.orderkey = o.orderkey) left join part p on (l.partkey=p.partkey)", + }; + + for (String query : queries) { + // Verify plan is optimized (structurally different from unoptimized) + MaterializedResult resultExplainQuery = computeActual(session, "EXPLAIN " + query); + MaterializedResult resultExplainQueryNoOpt = computeActual(sessionNoOpt, "EXPLAIN " + query); + String explainNoOpt = sanitizePlan((String) getOnlyElement(resultExplainQueryNoOpt.getOnlyColumnAsSet())); + String explainWithOpt = sanitizePlan((String) getOnlyElement(resultExplainQuery.getOnlyColumnAsSet())); + assertNotEquals(explainWithOpt, explainNoOpt, "Couldn't optimize query: " + query); + + // Verify correctness + assertQueryWithSameQueryRunner(session, query, sessionNoOpt); + } + } + + @Test + public void testPayloadJoinWithInterveningProjectionsAndCrossJoins() + { + Session sessionNoOpt = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "false") + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false") + .build(); + + Session session = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "true") + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false") + .build(); + + String[] queries = { + // Intervening identity projection from subquery wrapping inner LOJs + "SELECT sub.*, s.name as s_name FROM (SELECT l.orderkey, l.partkey, l.suppkey, o.orderstatus " + + "FROM lineitem l LEFT JOIN orders o ON l.orderkey = o.orderkey) sub " + + "LEFT JOIN supplier s ON sub.suppkey = s.suppkey", + // Cross join between LOJs: t LOJ r1 CROSS JOIN c LOJ r2 + "SELECT l.orderkey, l.partkey, o.orderstatus, n.name as nation_name, p.brand " + + "FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "CROSS JOIN (SELECT name FROM nation WHERE name = 'JAPAN') n " + + "LEFT JOIN part p ON l.partkey = p.partkey", + // Both: subquery with identity projection AND cross join + "SELECT sub.*, p.brand FROM (" + + "SELECT l.orderkey, l.partkey, o.orderstatus, n.name as nation_name " + + "FROM lineitem l LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "CROSS JOIN (SELECT name FROM nation WHERE name = 'JAPAN') n) sub " + + "LEFT JOIN part p ON sub.partkey = p.partkey", + // Intervening projection computes a field used as a later join key + "SELECT sub.*, s.name as s_name FROM (SELECT l.orderkey, l.partkey, l.suppkey + 0 as sk, o.orderstatus " + + "FROM lineitem l LEFT JOIN orders o ON l.orderkey = o.orderkey) sub " + + "LEFT JOIN supplier s ON sub.sk = s.suppkey", + }; + + for (String query : queries) { + // Verify plan is optimized + MaterializedResult resultExplainQuery = computeActual(session, "EXPLAIN " + query); + MaterializedResult resultExplainQueryNoOpt = computeActual(sessionNoOpt, "EXPLAIN " + query); + String explainNoOpt = sanitizePlan((String) getOnlyElement(resultExplainQueryNoOpt.getOnlyColumnAsSet())); + String explainWithOpt = sanitizePlan((String) getOnlyElement(resultExplainQuery.getOnlyColumnAsSet())); + assertNotEquals(explainWithOpt, explainNoOpt, "Couldn't optimize query: " + query); + + // Verify correctness + assertQueryWithSameQueryRunner(session, query, sessionNoOpt); + } + + // Queries where cross join columns are used as subsequent LOJ keys. + // The optimizer cannot handle these (keys not from the base table), + // but we verify correctness is preserved. + String[] crossJoinKeyQueries = { + // Cross join column used as a join key in a later LOJ + "SELECT l.orderkey, o.orderstatus, n.nationkey, s.name as s_name " + + "FROM lineitem l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "CROSS JOIN (SELECT nationkey FROM nation WHERE name = 'JAPAN') n " + + "LEFT JOIN supplier s ON n.nationkey = s.nationkey", + // Cross join column used as join key combined with identity projection + "SELECT sub.*, s.name as s_name FROM (" + + "SELECT l.orderkey, l.partkey, o.orderstatus, n.nationkey " + + "FROM lineitem l LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "CROSS JOIN (SELECT nationkey FROM nation WHERE name = 'JAPAN') n) sub " + + "LEFT JOIN supplier s ON sub.nationkey = s.nationkey", + }; + + for (String query : crossJoinKeyQueries) { + assertQueryWithSameQueryRunner(session, query, sessionNoOpt); + } + } + + @Test + public void testPayloadJoinInnerRejoinWithNonNullKeys() + { + Session sessionNoOpt = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "false") + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false") + .build(); + + Session session = Session.builder(getSession()) + .setSystemProperty(OPTIMIZE_PAYLOAD_JOINS, "true") + .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, "false") + .build(); + + // When all join keys are guaranteed non-null, the optimizer uses INNER join + // for the payload rejoin instead of LEFT join + String[] queries = { + // Both keys non-null — should use INNER rejoin + "SELECT l.* FROM (SELECT * FROM lineitem WHERE orderkey IS NOT NULL AND partkey IS NOT NULL) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey", + // All three keys non-null + "SELECT l.orderkey, l.partkey, l.suppkey, o.orderstatus, p.brand, s.name " + + "FROM (SELECT * FROM lineitem WHERE orderkey IS NOT NULL AND partkey IS NOT NULL AND suppkey IS NOT NULL) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey " + + "LEFT JOIN supplier s ON l.suppkey = s.suppkey", + // Non-null with additional filter predicates + "SELECT l.* FROM (SELECT * FROM lineitem WHERE orderkey IS NOT NULL AND partkey IS NOT NULL AND quantity > 1) l " + + "LEFT JOIN orders o ON l.orderkey = o.orderkey " + + "LEFT JOIN part p ON l.partkey = p.partkey", + }; + + for (String query : queries) { + // Verify correctness: results should match between optimized and non-optimized + assertQueryWithSameQueryRunner(session, query, sessionNoOpt); + } + } + private static List getPayloadQueries(String tableName) { String[] queries = {