diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java index 2aefb2f5e376..f496f3977b56 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/FeaturesConfig.java @@ -135,7 +135,7 @@ public class FeaturesConfig private DataSize filterAndProjectMinOutputPageSize = DataSize.of(500, KILOBYTE); private int filterAndProjectMinOutputPageRowCount = 256; private int maxGroupingSets = 2048; - private JoinPushdownMode joinPushdownMode = JoinPushdownMode.DISABLED; + private JoinPushdownMode joinPushdownMode = JoinPushdownMode.AUTOMATIC; public enum JoinReorderingStrategy { @@ -179,9 +179,12 @@ public enum JoinPushdownMode * Try to push all joins except cross-joins to connector. */ EAGER, - // TODO Add cost based logic to join pushdown - // AUTOMATIC, - /**/; + /** + * Decide automatically, based on table statistics, whether to push a join down to the connector. + * Do not push the join down in the absence of table statistics. 
+ */ + AUTOMATIC, + /**/ } public double getCpuCostWeight() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java index 19eae89f1479..5cb00e1da2ed 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import io.trino.Session; +import io.trino.cost.PlanNodeStatsEstimate; import io.trino.matching.Capture; import io.trino.matching.Captures; import io.trino.matching.Pattern; @@ -32,6 +33,7 @@ import io.trino.sql.ExpressionUtils; import io.trino.sql.analyzer.FeaturesConfig.JoinPushdownMode; import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.Rule; import io.trino.sql.planner.plan.JoinNode; import io.trino.sql.planner.plan.Patterns; @@ -62,6 +64,7 @@ import static io.trino.sql.planner.plan.Patterns.Join.left; import static io.trino.sql.planner.plan.Patterns.Join.right; import static io.trino.sql.planner.plan.Patterns.tableScan; +import static java.lang.Double.isNaN; import static java.util.Objects.requireNonNull; public class PushJoinIntoTableScan @@ -114,6 +117,10 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) return Result.empty(); } + if (skipJoinPushdownBasedOnCost(joinNode, context)) { + return Result.empty(); + } + Map leftAssignments = left.getAssignments().entrySet().stream() .collect(toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue)); @@ -162,6 +169,43 @@ public Result apply(JoinNode joinNode, Captures captures, Context context) return Result.ofPlanNode(new TableScanNode(joinNode.getId(), handle, joinNode.getOutputSymbols(), 
+ newAssignments.build(), newEnforcedConstraint, false)); } + /** Returns true when the join should not be pushed down. Applies only in AUTOMATIC mode (any other mode returns false): pushdown is skipped if the estimated output size of the left input, the right input or the join is NaN, or if the estimated output size of the join is greater than the estimated output sizes of both inputs combined. */ private boolean skipJoinPushdownBasedOnCost(JoinNode joinNode, Context context) + { + if (getJoinPushdownMode(context.getSession()) != JoinPushdownMode.AUTOMATIC) { + return false; + } + + TypeProvider types = context.getSymbolAllocator().getTypes(); + + // Check each estimate as soon as it is fetched and return early, to avoid unnecessary, costly work on the remaining ones + + PlanNodeStatsEstimate leftStats = context.getStatsProvider().getStats(joinNode.getLeft()); + double leftOutputSize = leftStats.getOutputSizeInBytes(joinNode.getLeft().getOutputSymbols(), types); + if (isNaN(leftOutputSize)) { + return true; + } + + PlanNodeStatsEstimate rightStats = context.getStatsProvider().getStats(joinNode.getRight()); + double rightOutputSize = rightStats.getOutputSizeInBytes(joinNode.getRight().getOutputSymbols(), types); + if (isNaN(rightOutputSize)) { + return true; + } + + PlanNodeStatsEstimate joinStats = context.getStatsProvider().getStats(joinNode); + double joinOutputSize = joinStats.getOutputSizeInBytes(joinNode.getOutputSymbols(), types); + if (isNaN(joinOutputSize)) { + return true; + } + + if (joinOutputSize > leftOutputSize + rightOutputSize) { + // This is a poor man's estimate of whether it makes more sense to perform the join in the source database or in Trino. + // The assumption here is that the cost of performing the join in the source database is less than or equal to the cost of the join in Trino. + // For the pessimistic case (both join costs equal) the tie is resolved on the cost of sending the data from the source database to Trino: a join output larger than both inputs combined is not pushed down. 
+ return true; + } + return false; + } + private TupleDomain deriveConstraint(TupleDomain sourceConstraint, Map columnMapping, boolean nullable) { TupleDomain constraint = sourceConstraint; diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java index 9802e77e1248..361337fa5dbb 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestPushJoinIntoTableScan.java @@ -20,6 +20,7 @@ import io.trino.connector.MockConnectorColumnHandle; import io.trino.connector.MockConnectorFactory; import io.trino.connector.MockConnectorTableHandle; +import io.trino.cost.PlanNodeStatsEstimate; import io.trino.metadata.TableHandle; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ColumnMetadata; @@ -34,8 +35,10 @@ import io.trino.spi.predicate.NullableValue; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.rule.test.RuleAssert; import io.trino.sql.planner.iterative.rule.test.RuleTester; import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.TableScanNode; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ComparisonExpression; @@ -47,6 +50,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.OptionalDouble; import java.util.function.Predicate; import static com.google.common.base.Predicates.equalTo; @@ -579,6 +583,164 @@ public void testPushJoinIntoTableRequiresFullColumnHandleMappingInResult() } } + /** Runs the rule with the join_pushdown session property set to AUTOMATIC and with stats for the left scan, the right scan and the join overridden from leftRows, righRows and joinRows (an empty value is passed as NaN); expects the pushed-down table scan when pushdownExpected is true and expects the rule not to fire otherwise. */ @Test(dataProvider = "testAutomaticJoinPushDownParams") + public void testAutomaticJoinPushDown(OptionalDouble leftRows, OptionalDouble righRows, OptionalDouble joinRows, boolean 
+ pushdownExpected) + { + Session pushdownAutomaticSession = Session.builder(MOCK_SESSION) + .setSystemProperty("join_pushdown", "AUTOMATIC") + .build(); + + try (RuleTester ruleTester = defaultRuleTester()) { + MockConnectorFactory connectorFactory = createMockConnectorFactory((session, applyJoinType, left, right, joinConditions, leftAssignments, rightAssignments) -> { + assertThat(((MockConnectorTableHandle) left).getTableName()).isEqualTo(TABLE_A_SCHEMA_TABLE_NAME); + assertThat(((MockConnectorTableHandle) right).getTableName()).isEqualTo(TABLE_B_SCHEMA_TABLE_NAME); + Assertions.assertThat(applyJoinType).isEqualTo(toSpiJoinType(JoinNode.Type.INNER)); + Assertions.assertThat(joinConditions).containsExactly(new JoinCondition(JoinCondition.Operator.EQUAL, COLUMN_A1_VARIABLE, COLUMN_B1_VARIABLE)); + + return Optional.of(new JoinApplicationResult<>( + JOIN_CONNECTOR_TABLE_HANDLE, + JOIN_TABLE_A_COLUMN_MAPPING, + JOIN_TABLE_B_COLUMN_MAPPING)); + }); + + ruleTester.getQueryRunner().createCatalog(MOCK_CATALOG, connectorFactory, ImmutableMap.of()); + + RuleAssert ruleAssert = ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + .overrideStats("left", new PlanNodeStatsEstimate(leftRows.orElse(Double.NaN), ImmutableMap.of())) + .overrideStats("right", new PlanNodeStatsEstimate(righRows.orElse(Double.NaN), ImmutableMap.of())) + .overrideStats("join", new PlanNodeStatsEstimate(joinRows.orElse(Double.NaN), ImmutableMap.of())) + .on(p -> { + Symbol columnA1Symbol = p.symbol(COLUMN_A1); + Symbol columnA2Symbol = p.symbol(COLUMN_A2); + Symbol columnB1Symbol = p.symbol(COLUMN_B1); + TableScanNode left = new TableScanNode( + new PlanNodeId("left"), + TABLE_A_HANDLE, + ImmutableList.of(columnA1Symbol, columnA2Symbol), + ImmutableMap.of( + columnA1Symbol, COLUMN_A1_HANDLE, + columnA2Symbol, COLUMN_A2_HANDLE), + TupleDomain.all(), + false); + + TableScanNode right = new TableScanNode( + new PlanNodeId("right"), + TABLE_B_HANDLE, 
+ ImmutableList.of(columnB1Symbol), + ImmutableMap.of(columnB1Symbol, COLUMN_B1_HANDLE), + TupleDomain.all(), + false); + + return join(new PlanNodeId("join"), JoinNode.Type.INNER, left, right, new JoinNode.EquiJoinClause(columnA1Symbol, columnB1Symbol)); + }) + .withSession(pushdownAutomaticSession); + + if (pushdownExpected) { + ruleAssert.matches(tableScan(JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.getTableName())); + } + else { + ruleAssert.doesNotFire(); + } + } + } + + /** Rows for testAutomaticJoinPushDown: the leftRows, righRows and joinRows values (empty is turned into NaN by the test) followed by whether pushdown is expected. */ @DataProvider + public static Object[][] testAutomaticJoinPushDownParams() + { + return new Object[][] { + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.of(133), true}, + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.of(134), false}, // just above output size boundary + {OptionalDouble.empty(), OptionalDouble.of(200), OptionalDouble.of(250), false}, + {OptionalDouble.of(100), OptionalDouble.empty(), OptionalDouble.of(250), false}, + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.empty(), false}, + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.of(301), false} + }; + } + + /** Builds the same plan and stats overrides as testAutomaticJoinPushDown but runs with MOCK_SESSION, and expects the pushed-down table scan for every row. NOTE(review): MOCK_SESSION is defined outside this diff; per the test name it presumably forces pushdown - confirm. */ @Test(dataProvider = "testJoinPushdownStatsIrrelevantIfPushdownForcedParams") + public void testJoinPushdownStatsIrrelevantIfPushdownForced(OptionalDouble leftRows, OptionalDouble righRows, OptionalDouble joinRows) + { + try (RuleTester ruleTester = defaultRuleTester()) { + MockConnectorFactory connectorFactory = createMockConnectorFactory((session, applyJoinType, left, right, joinConditions, leftAssignments, rightAssignments) -> { + assertThat(((MockConnectorTableHandle) left).getTableName()).isEqualTo(TABLE_A_SCHEMA_TABLE_NAME); + assertThat(((MockConnectorTableHandle) right).getTableName()).isEqualTo(TABLE_B_SCHEMA_TABLE_NAME); + Assertions.assertThat(applyJoinType).isEqualTo(toSpiJoinType(JoinNode.Type.INNER)); + Assertions.assertThat(joinConditions).containsExactly(new JoinCondition(JoinCondition.Operator.EQUAL, COLUMN_A1_VARIABLE, COLUMN_B1_VARIABLE)); + + return 
+ Optional.of(new JoinApplicationResult<>( + JOIN_CONNECTOR_TABLE_HANDLE, + JOIN_TABLE_A_COLUMN_MAPPING, + JOIN_TABLE_B_COLUMN_MAPPING)); + }); + + ruleTester.getQueryRunner().createCatalog(MOCK_CATALOG, connectorFactory, ImmutableMap.of()); + + ruleTester.assertThat(new PushJoinIntoTableScan(ruleTester.getMetadata())) + .overrideStats("left", new PlanNodeStatsEstimate(leftRows.orElse(Double.NaN), ImmutableMap.of())) + .overrideStats("right", new PlanNodeStatsEstimate(righRows.orElse(Double.NaN), ImmutableMap.of())) + .overrideStats("join", new PlanNodeStatsEstimate(joinRows.orElse(Double.NaN), ImmutableMap.of())) + .on(p -> { + Symbol columnA1Symbol = p.symbol(COLUMN_A1); + Symbol columnA2Symbol = p.symbol(COLUMN_A2); + Symbol columnB1Symbol = p.symbol(COLUMN_B1); + TableScanNode left = new TableScanNode( + new PlanNodeId("left"), + TABLE_A_HANDLE, + ImmutableList.of(columnA1Symbol, columnA2Symbol), + ImmutableMap.of( + columnA1Symbol, COLUMN_A1_HANDLE, + columnA2Symbol, COLUMN_A2_HANDLE), + TupleDomain.all(), + false); + + TableScanNode right = new TableScanNode( + new PlanNodeId("right"), + TABLE_B_HANDLE, + ImmutableList.of(columnB1Symbol), + ImmutableMap.of(columnB1Symbol, COLUMN_B1_HANDLE), + TupleDomain.all(), + false); + + return join(new PlanNodeId("join"), JoinNode.Type.INNER, left, right, new JoinNode.EquiJoinClause(columnA1Symbol, columnB1Symbol)); + }) + .withSession(MOCK_SESSION) + .matches(tableScan(JOIN_PUSHDOWN_SCHEMA_TABLE_NAME.getTableName())); + } + } + + /** Rows for testJoinPushdownStatsIrrelevantIfPushdownForced: the same leftRows, righRows and joinRows combinations as testAutomaticJoinPushDownParams, without an expected-outcome column. */ @DataProvider + public static Object[][] testJoinPushdownStatsIrrelevantIfPushdownForcedParams() + { + return new Object[][] { + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.of(133)}, + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.of(134)}, + {OptionalDouble.empty(), OptionalDouble.of(200), OptionalDouble.of(250)}, + {OptionalDouble.of(100), OptionalDouble.empty(), OptionalDouble.of(250)}, + {OptionalDouble.of(100), OptionalDouble.of(200), 
+ OptionalDouble.empty()}, + {OptionalDouble.of(100), OptionalDouble.of(200), OptionalDouble.of(301)} + }; + } + + /** Test helper: builds a JoinNode with the given id, join type and equi-join criteria over the two scans, passing the output symbols of both scans and leaving the remaining optional arguments empty. */ private JoinNode join(PlanNodeId planNodeId, JoinNode.Type joinType, TableScanNode left, TableScanNode right, JoinNode.EquiJoinClause... criteria) + { + return new JoinNode( + planNodeId, + joinType, + left, + right, + ImmutableList.copyOf(criteria), + left.getOutputSymbols(), + right.getOutputSymbols(), + false, + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()); + } + private static TableHandle createTableHandle(ConnectorTableHandle tableHandle) { return createTableHandle(tableHandle, MOCK_CATALOG);