diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java index e24c9dac47ce..5162e9d684ff 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/DeterminePartitionCount.java @@ -24,6 +24,7 @@ import io.trino.execution.warnings.WarningCollector; import io.trino.operator.RetryPolicy; import io.trino.sql.planner.PartitioningHandle; +import io.trino.sql.planner.PartitioningScheme; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.SystemPartitioningHandle; @@ -59,6 +60,7 @@ import static io.trino.SystemSessionProperties.isDeterminePartitionCountForWriteEnabled; import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar; import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith; import static java.lang.Double.isNaN; import static java.lang.Math.max; @@ -66,12 +68,12 @@ /** * This rule looks at the amount of data read and processed by the query to determine the value of partition count - * used for remote exchanges. It helps to increase the concurrency of the engine in the case of large cluster. + * used for remote partitioned exchanges. It helps to increase the concurrency of the engine in the case of large cluster. * This rule is also cautious about lack of or incorrect statistics therefore it skips for input multiplying nodes like * CROSS JOIN or UNNEST. * * E.g. 1: - * Given query: SELECT count(column_a) FROM table_with_stats_a + * Given query: SELECT count(column_a) FROM table_with_stats_a group by column_b * config: * MIN_INPUT_SIZE_PER_TASK: 500 MB * Input table data size: 1000 MB @@ -114,6 +116,11 @@ public PlanNode optimize( requireNonNull(types, "types is null"); requireNonNull(tableStatsProvider, "tableStatsProvider is null"); + // Skip partition count determination if no partitioned remote exchanges exist in the plan anyway + if (!isEligibleRemoteExchangePresent(plan)) { + return plan; + } + // Unless enabled, skip for write nodes since writing partitioned data with small amount of nodes could cause // memory related issues even when the amount of data is small. boolean isWriteQuery = PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(INSERT_NODES).matches(); @@ -308,6 +315,24 @@ private static double getSourceNodesOutputStats(PlanNode root, ToDoubleFunction< .sum(); } + private static boolean isEligibleRemoteExchangePresent(PlanNode root) + { + return PlanNodeSearcher.searchFrom(root) + .where(node -> node instanceof ExchangeNode exchangeNode && isEligibleRemoteExchange(exchangeNode)) + .matches(); + } + + private static boolean isEligibleRemoteExchange(ExchangeNode exchangeNode) + { + if (exchangeNode.getScope() != REMOTE || exchangeNode.getType() != REPARTITION) { + return false; + } + PartitioningHandle partitioningHandle = exchangeNode.getPartitioningScheme().getPartitioning().getHandle(); + return !partitioningHandle.isScaleWriters() + && !partitioningHandle.isSingleNode() + && partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle; + } + private static class Rewriter extends SimplePlanRewriter { @@ -321,20 +346,20 @@ private Rewriter(int partitionCount) @Override public PlanNode visitExchange(ExchangeNode node, RewriteContext context) { - PartitioningHandle handle = node.getPartitioningScheme().getPartitioning().getHandle(); - if (!(node.getScope() == REMOTE && handle.getConnectorHandle() instanceof SystemPartitioningHandle)) { - return node; - } - List sources = node.getSources().stream() .map(context::rewrite) .collect(toImmutableList()); + PartitioningScheme partitioningScheme = node.getPartitioningScheme(); + if (isEligibleRemoteExchange(node)) { + partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(partitionCount)); + } + return new ExchangeNode( node.getId(), node.getType(), node.getScope(), - node.getPartitioningScheme().withPartitionCount(Optional.of(partitionCount)), + partitioningScheme, sources, node.getInputs(), node.getOrderingScheme()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java index 4d72d4eba0da..5f9be0ec1889 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanMatchPattern.java @@ -625,6 +625,11 @@ public static PlanMatchPattern exchange(ExchangeNode.Scope scope, Optional partitionCount, PlanMatchPattern... sources) + { + return exchange(scope, Optional.of(type), Optional.empty(), ImmutableList.of(), ImmutableSet.of(), Optional.empty(), ImmutableList.of(), Optional.of(partitionCount), sources); + } + public static PlanMatchPattern exchange(ExchangeNode.Scope scope, PartitioningHandle partitioningHandle, Optional partitionCount, PlanMatchPattern... sources) { return exchange(scope, Optional.empty(), Optional.of(partitioningHandle), ImmutableList.of(), ImmutableSet.of(), Optional.empty(), ImmutableList.of(), Optional.of(partitionCount), sources); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java index 8da217c0695d..73bd8eee4879 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestDeterminePartitionCount.java @@ -42,6 +42,7 @@ import static io.trino.spi.statistics.TableStatistics.empty; import static io.trino.spi.type.VarcharType.VARCHAR; import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.output; @@ -49,6 +50,9 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL; import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE; +import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE; import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -101,17 +105,53 @@ protected LocalQueryRunner createLocalQueryRunner() } @Test - public void testPlanWhenTableStatisticsArePresent() + public void testSimpleSelect() { - @Language("SQL") String query = """ - SELECT count(column_a) FROM table_with_stats_a - """; + @Language("SQL") String query = "SELECT * FROM table_with_stats_a"; - // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 5 for remote exchanges + // DeterminePartitionCount optimizer rule should not fire since no remote exchanges are present assertDistributedPlan( query, Session.builder(getQueryRunner().getDefaultSession()) - .setSystemProperty(MAX_HASH_PARTITION_COUNT, "10") + .setSystemProperty(MAX_HASH_PARTITION_COUNT, "100") + .setSystemProperty(MIN_HASH_PARTITION_COUNT, "4") + .setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB") + .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400") + .build(), + output( + node(TableScanNode.class))); + } + + @Test + public void testSimpleFilter() + { + @Language("SQL") String query = "SELECT column_a FROM table_with_stats_a WHERE column_b IS NULL"; + + // DeterminePartitionCount optimizer rule should not fire since no remote exchanges are present + assertDistributedPlan( + query, + Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(MAX_HASH_PARTITION_COUNT, "100") + .setSystemProperty(MIN_HASH_PARTITION_COUNT, "4") + .setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB") + .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400") + .build(), + output( + project( + filter("column_b IS NULL", + tableScan("table_with_stats_a", ImmutableMap.of("column_a", "column_a", "column_b", "column_b")))))); + } + + @Test + public void testSimpleCount() + { + @Language("SQL") String query = "SELECT count(*) FROM table_with_stats_a"; + + // DeterminePartitionCount optimizer rule should not fire since no remote repartition exchanges are present + assertDistributedPlan( + query, + Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(MAX_HASH_PARTITION_COUNT, "100") .setSystemProperty(MIN_HASH_PARTITION_COUNT, "4") .setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB") .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400") @@ -119,11 +159,37 @@ SELECT count(column_a) FROM table_with_stats_a output( node(AggregationNode.class, exchange(LOCAL, - exchange(REMOTE, Optional.of(5), + exchange(REMOTE, GATHER, Optional.empty(), node(AggregationNode.class, node(TableScanNode.class))))))); } + @Test + public void testPlanWhenTableStatisticsArePresent() + { + @Language("SQL") String query = """ + SELECT count(column_a) FROM table_with_stats_a group by column_b + """; + + // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges + assertDistributedPlan( + query, + Session.builder(getQueryRunner().getDefaultSession()) + .setSystemProperty(MAX_HASH_PARTITION_COUNT, "20") + .setSystemProperty(MIN_HASH_PARTITION_COUNT, "4") + .setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB") + .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400") + .build(), + output( + project( + node(AggregationNode.class, + exchange(LOCAL, + exchange(REMOTE, REPARTITION, Optional.of(10), + node(AggregationNode.class, + project( + node(TableScanNode.class))))))))); + } + @Test public void testPlanWhenTableStatisticsAreAbsent() { @@ -184,7 +250,7 @@ public void testPlanWhenCrossJoinIsScalar() SELECT * FROM table_with_stats_a CROSS JOIN (select max(column_a) from table_with_stats_b) t(a) """; - // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 1 for remote exchanges + // DeterminePartitionCount optimizer rule should not fire since no remote repartitioning exchanges are present assertDistributedPlan( query, Session.builder(getQueryRunner().getDefaultSession()) @@ -197,10 +263,10 @@ public void testPlanWhenCrossJoinIsScalar() join(INNER, builder -> builder .right( exchange(LOCAL, - exchange(REMOTE, Optional.of(15), + exchange(REMOTE, REPLICATE, Optional.empty(), node(AggregationNode.class, exchange(LOCAL, - exchange(REMOTE, Optional.of(15), + exchange(REMOTE, GATHER, Optional.empty(), node(AggregationNode.class, node(TableScanNode.class)))))))) .left(node(TableScanNode.class))))); @@ -242,7 +308,7 @@ public void testPlanWhenJoinNodeOutputIsBiggerThanRowsScanned() SELECT a.column_a FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_a = b.column_a """; - // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 20 for remote exchanges + // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges assertDistributedPlan( query, Session.builder(getQueryRunner().getDefaultSession()) @@ -336,7 +402,7 @@ public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput() FROM table_with_stats_b """; - // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges + // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 20 for remote exchanges assertDistributedPlan( query, Session.builder(getQueryRunner().getDefaultSession()) @@ -346,17 +412,16 @@ public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput() .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400") .build(), output( - // partition count should be 15 with just join node but since we also have union, it should be 20 - exchange(REMOTE, Optional.of(20), + exchange(REMOTE, GATHER, join(INNER, builder -> builder .equiCriteria("column_a", "column_a_1") .right(exchange(LOCAL, // partition count should be 15 with just join node but since we also have union, it should be 20 - exchange(REMOTE, Optional.of(20), + exchange(REMOTE, REPARTITION, Optional.of(20), project( tableScan("table_with_stats_b", ImmutableMap.of("column_a_1", "column_a")))))) // partition count should be 15 with just join node but since we also have union, it should be 20 - .left(exchange(REMOTE, Optional.of(20), + .left(exchange(REMOTE, REPARTITION, Optional.of(20), project( node(FilterNode.class, tableScan("table_with_stats_a", ImmutableMap.of("column_a", "column_a", "column_b_0", "column_b"))))))), @@ -367,7 +432,7 @@ public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput() public void testPlanWhenEstimatedPartitionCountBasedOnRowsIsMoreThanOutputSize() { @Language("SQL") String query = """ - SELECT count(column_a) FROM table_with_stats_a + SELECT count(column_a) FROM table_with_stats_a group by column_b """; // DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges @@ -381,10 +446,12 @@ SELECT count(column_a) FROM table_with_stats_a .setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "20") .build(), output( - node(AggregationNode.class, - exchange(LOCAL, - exchange(REMOTE, Optional.of(10), - node(AggregationNode.class, - node(TableScanNode.class))))))); + project( + node(AggregationNode.class, + exchange(LOCAL, + exchange(REMOTE, REPARTITION, Optional.of(10), + node(AggregationNode.class, + project( + node(TableScanNode.class))))))))); } } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java index bbcfcfebac95..2bc9aba57314 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeFileOperations.java @@ -132,13 +132,11 @@ public void testReadWholePartition() assertFileSystemAccesses( "SELECT count(*) FROM test_read_part_key WHERE key = 'p1'", ImmutableMultiset.builder() - .addCopies(new FileOperation(LAST_CHECKPOINT, "_last_checkpoint", INPUT_FILE_NEW_STREAM), 3) // TODO (https://github.com/trinodb/trino/issues/16782) should be checked once per query + .addCopies(new FileOperation(LAST_CHECKPOINT, "_last_checkpoint", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/16782) should be checked once per query .addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000000.json", INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000001.json", INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000002.json", INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000003.json", INPUT_FILE_NEW_STREAM), 3) // TODO (https://github.com/trinodb/trino/issues/16780) why is last transaction log accessed more times than others? - .addCopies(new FileOperation(TRINO_EXTENDED_STATS_JSON, "extended_stats.json", INPUT_FILE_EXISTS), 1) - .addCopies(new FileOperation(TRINO_EXTENDED_STATS_JSON, "extended_stats.json", INPUT_FILE_NEW_STREAM), 1) + .addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000003.json", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/16780) why is last transaction log accessed more times than others? .build()); // Read partition and synthetic columns diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/thrift/TestHiveMetastoreAccessOperations.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/thrift/TestHiveMetastoreAccessOperations.java index 508407ff331f..55e7c3061ef4 100644 --- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/thrift/TestHiveMetastoreAccessOperations.java +++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/metastore/thrift/TestHiveMetastoreAccessOperations.java @@ -99,7 +99,6 @@ public void testSelect() assertMetastoreInvocations("SELECT * FROM test_select_from", ImmutableMultiset.builder() .add(GET_TABLE) - .add(GET_TABLE_STATISTICS) .build()); } @@ -113,7 +112,6 @@ public void testSelectPartitionedTable() .addCopies(GET_TABLE, 2) .add(GET_PARTITION_NAMES_BY_FILTER) .add(GET_PARTITIONS_BY_NAMES) - .add(GET_PARTITION_STATISTICS) .build()); assertUpdate("INSERT INTO test_select_partition SELECT 2 AS data, 20 AS part", 1); @@ -122,7 +120,6 @@ public void testSelectPartitionedTable() .addCopies(GET_TABLE, 2) .add(GET_PARTITION_NAMES_BY_FILTER) .add(GET_PARTITIONS_BY_NAMES) - .add(GET_PARTITION_STATISTICS) .build()); // Specify a specific partition @@ -142,7 +139,6 @@ public void testSelectWithFilter() assertMetastoreInvocations("SELECT * FROM test_select_from_where WHERE age = 2", ImmutableMultiset.builder() .add(GET_TABLE) - .add(GET_TABLE_STATISTICS) .build()); } @@ -155,7 +151,6 @@ public void testSelectFromView() assertMetastoreInvocations("SELECT * FROM test_select_view_view", ImmutableMultiset.builder() .addCopies(GET_TABLE, 2) - .add(GET_TABLE_STATISTICS) .build()); } @@ -168,7 +163,6 @@ public void testSelectFromViewWithFilter() assertMetastoreInvocations("SELECT * FROM test_select_view_where_view WHERE age = 2", ImmutableMultiset.builder() .addCopies(GET_TABLE, 2) - .add(GET_TABLE_STATISTICS) .build()); } @@ -241,7 +235,6 @@ public void testAnalyze() assertMetastoreInvocations("ANALYZE test_analyze", ImmutableMultiset.builder() .add(GET_TABLE) - .add(GET_TABLE_STATISTICS) .add(UPDATE_TABLE_STATISTICS) .build()); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java index 8ef14b24945d..236d551f406d 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlJdbcConnectionCreation.java @@ -86,7 +86,7 @@ public Object[][] testCases() {"SELECT * FROM information_schema.schemata", 1, Optional.empty()}, {"SELECT * FROM information_schema.tables", 1, Optional.empty()}, {"SELECT * FROM information_schema.columns", 1, Optional.empty()}, - {"SELECT * FROM nation", 3, Optional.empty()}, + {"SELECT * FROM nation", 2, Optional.empty()}, {"SELECT * FROM TABLE (system.query(query => 'SELECT * FROM tpch.nation'))", 2, Optional.empty()}, {"CREATE TABLE copy_of_nation AS SELECT * FROM nation", 6, Optional.empty()}, {"INSERT INTO copy_of_nation SELECT * FROM nation", 6, Optional.empty()},