diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java index fd5807999b6f4..83df899380f50 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/PlannerUtils.java @@ -40,6 +40,7 @@ import com.facebook.presto.spi.plan.TableScanNode; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; +import com.facebook.presto.spi.relation.DeterminismEvaluator; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.SpecialFormExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; @@ -48,6 +49,7 @@ import com.facebook.presto.sql.planner.plan.ExchangeNode; import com.facebook.presto.sql.planner.planPrinter.PlanPrinter; import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.facebook.presto.sql.tree.ComparisonExpression; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.FunctionCall; @@ -472,6 +474,37 @@ node instanceof ProjectNode && isScanFilterProject(((ProjectNode) node).getSourc node instanceof FilterNode && isScanFilterProject(((FilterNode) node).getSource()); } + /** + * Returns true only if {@code node} is a chain of FilterNode and ProjectNode over a TableScanNode + * whose filter predicates and projection expressions are all deterministic; any other node type + * found in the chain yields false. Optimizations that clone the subtree (e.g., JoinPrefilter) rely + * on this check, because cloning a subtree with non-deterministic expressions (like rand()) + * produces different results from each clone, leading to incorrect query results. 
+ */ + public static boolean isDeterministicScanFilterProject(PlanNode node, FunctionAndTypeManager functionAndTypeManager) + { + DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager); /* one evaluator is reused for every node of the recursive walk */ + return isDeterministicPlanSubtree(node, determinismEvaluator); + } + + private static boolean isDeterministicPlanSubtree(PlanNode node, DeterminismEvaluator determinismEvaluator) + { /* Recursive helper: follows getSource() downward, checking each filter predicate and projection expression. */ + if (node instanceof TableScanNode) { + return true; /* the scan ends the chain; no expression is inspected for it */ + } + else if (node instanceof FilterNode) { + FilterNode filterNode = (FilterNode) node; + return determinismEvaluator.isDeterministic(filterNode.getPredicate()) + && isDeterministicPlanSubtree(filterNode.getSource(), determinismEvaluator); + } + else if (node instanceof ProjectNode) { + ProjectNode projectNode = (ProjectNode) node; + return projectNode.getAssignments().getExpressions().stream().allMatch(determinismEvaluator::isDeterministic) + && isDeterministicPlanSubtree(projectNode.getSource(), determinismEvaluator); + } + return false; /* any other node type is rejected rather than assumed deterministic */ + } + public static CallExpression equalityPredicate(FunctionResolution functionResolution, RowExpression leftExpr, RowExpression rightExpr) { return new CallExpression(EQUAL.name(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java index 1f3581d29233f..5846984155a5b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/JoinPrefilter.java @@ -48,6 +48,7 @@ import static com.facebook.presto.sql.planner.PlannerUtils.addProjections; import static com.facebook.presto.sql.planner.PlannerUtils.clonePlanNode; import static com.facebook.presto.sql.planner.PlannerUtils.getVariableHash; +import static com.facebook.presto.sql.planner.PlannerUtils.isDeterministicScanFilterProject; 
import static com.facebook.presto.sql.planner.PlannerUtils.isScanFilterProject; import static com.facebook.presto.sql.planner.PlannerUtils.projectExpressions; import static com.facebook.presto.sql.planner.PlannerUtils.restrictOutput; @@ -190,8 +191,14 @@ public PlanNode visitJoin(JoinNode node, RewriteContext context) PlanNode rewrittenRight = rewriteWith(this, right); List equiJoinClause = node.getCriteria(); - // We apply this for only left and inner join and the left side of the join is a simple scan - if ((node.getType() == LEFT || node.getType() == INNER) && isScanFilterProject(rewrittenLeft) && !node.getCriteria().isEmpty()) { + // We apply this only to LEFT and INNER joins whose left side is a simple scan-filter-project subtree. + // We also require that all expressions in the left subtree are deterministic, because + // cloning a subtree with non-deterministic expressions (like rand() from TABLESAMPLE BERNOULLI) + // would produce different results from each clone, leading to incorrect query results. 
+ if ((node.getType() == LEFT || node.getType() == INNER) + && isScanFilterProject(rewrittenLeft) + && isDeterministicScanFilterProject(rewrittenLeft, functionAndTypeManager) + && !node.getCriteria().isEmpty()) { List leftKeyList = equiJoinClause.stream().map(EquiJoinClause::getLeft).collect(toImmutableList()); List rightKeyList = equiJoinClause.stream().map(EquiJoinClause::getRight).collect(toImmutableList()); checkState(IntStream.range(0, leftKeyList.size()).boxed().allMatch(i -> leftKeyList.get(i).getType().equals(rightKeyList.get(i).getType()))); diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index 3a6c366e4e191..d1bf74a89a637 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -7799,7 +7799,7 @@ public void testJoinPrefilter() { // Orig String testQuery = "SELECT 1 from region join nation using(regionkey)"; - MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + MaterializedResult result = computeActual("explain(type logical) " + testQuery); assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); result = computeActual(testQuery); assertEquals(result.getRowCount(), 25); @@ -7808,7 +7808,7 @@ public void testJoinPrefilter() Session session = Session.builder(getSession()) .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) .build(); - result = computeActual(session, "explain(type distributed) " + testQuery); + result = computeActual(session, "explain(type logical) " + testQuery); assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); result = computeActual(session, testQuery); assertEquals(result.getRowCount(), 25); @@ -7817,7 +7817,7 @@ public void testJoinPrefilter() { // Orig 
@Language("SQL") String testQuery = "SELECT 1 from region r join nation n on cast(r.regionkey as varchar) = cast(n.regionkey as varchar)"; - MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + MaterializedResult result = computeActual("explain(type logical) " + testQuery); assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); result = computeActual(testQuery); assertEquals(result.getRowCount(), 25); @@ -7827,7 +7827,7 @@ public void testJoinPrefilter() .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) .setSystemProperty(REMOVE_REDUNDANT_CAST_TO_VARCHAR_IN_JOIN, String.valueOf(false)) .build(); - result = computeActual(session, "explain(type distributed) " + testQuery); + result = computeActual(session, "explain(type logical) " + testQuery); assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1); result = computeActual(session, testQuery); @@ -7837,7 +7837,7 @@ public void testJoinPrefilter() { // Orig String testQuery = "SELECT 1 from lineitem l join orders o on l.orderkey = o.orderkey and l.suppkey = o.custkey"; - MaterializedResult result = computeActual("explain(type distributed) " + testQuery); + MaterializedResult result = computeActual("explain(type logical) " + testQuery); assertEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); result = computeActual(testQuery); assertEquals(result.getRowCount(), 37); @@ -7846,7 +7846,7 @@ public void testJoinPrefilter() Session session = Session.builder(getSession()) .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) .build(); - result = computeActual(session, "explain(type distributed) " + testQuery); + result = computeActual(session, "explain(type logical) " + testQuery); assertNotEquals(((String) 
result.getMaterializedRows().get(0).getField(0)).indexOf("SemiJoin"), -1); assertNotEquals(((String) result.getMaterializedRows().get(0).getField(0)).indexOf("XX_HASH_64"), -1); result = computeActual(session, testQuery); @@ -7854,6 +7854,44 @@ public void testJoinPrefilter() } } + @Test + public void testJoinPrefilterSkippedForNonDeterministicExpressions() + { + // When the left side of a join contains non-deterministic expressions (e.g., TABLESAMPLE BERNOULLI + // which uses rand()), the JoinPrefilter optimizer should NOT clone the subtree, because each clone + // would produce a different random sample, effectively squaring the sampling rate. + Session session = Session.builder(getSession()) + .setSystemProperty(JOIN_PREFILTER_BUILD_SIDE, String.valueOf(true)) + .build(); + + // With TABLESAMPLE BERNOULLI (which introduces a rand() filter), the optimizer should + // skip prefiltering and NOT produce a SemiJoin node in the plan. + // We use 50% (not 100%) to avoid RemoveFullSample optimizing away the SampleNode + // before ImplementBernoulliSampleAsFilter converts it to a rand() filter. 
+ String testQuery = "SELECT orderkey from orders TABLESAMPLE BERNOULLI (50) join lineitem using(orderkey)"; + MaterializedResult result = computeActual(session, "explain(type logical) " + testQuery); + String plan = (String) result.getMaterializedRows().get(0).getField(0); + assertEquals(plan.indexOf("SemiJoin"), -1, + "JoinPrefilter should not produce SemiJoin when left side contains non-deterministic BERNOULLI sampling"); + + // Verify that a deterministic query with the same session setting still gets prefiltered + String deterministicQuery = "SELECT orderkey from orders join lineitem using(orderkey)"; + result = computeActual(session, "explain(type logical) " + deterministicQuery); + plan = (String) result.getMaterializedRows().get(0).getField(0); + assertNotEquals(plan.indexOf("SemiJoin"), -1, + "JoinPrefilter should produce SemiJoin for deterministic joins"); + + // Verify that TABLESAMPLE BERNOULLI on the RIGHT side still allows prefiltering, + // since the determinism guard only inspects the left scan-filter-project subtree. + // This documents the intended asymmetry: the left side is the one cloned to build the prefilter + // (seen as a SemiJoin in the plan), so only the left side needs to be deterministic. + String rightSideBernoulliQuery = "SELECT orderkey from orders join lineitem TABLESAMPLE BERNOULLI (50) using(orderkey)"; + result = computeActual(session, "explain(type logical) " + rightSideBernoulliQuery); + plan = (String) result.getMaterializedRows().get(0).getField(0); + assertNotEquals(plan.indexOf("SemiJoin"), -1, + "JoinPrefilter should still produce SemiJoin when only right side contains non-deterministic BERNOULLI sampling"); + } + @Test public void testRemoveCrossJoinWithSingleRowConstantInput() {