Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.trino.execution.warnings.WarningCollector;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.PartitioningHandle;
import io.trino.sql.planner.PartitioningScheme;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SystemPartitioningHandle;
Expand Down Expand Up @@ -59,19 +60,20 @@
import static io.trino.SystemSessionProperties.isDeterminePartitionCountForWriteEnabled;
import static io.trino.sql.planner.optimizations.QueryCardinalityUtil.isAtMostScalar;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.trino.sql.planner.plan.SimplePlanRewriter.rewriteWith;
import static java.lang.Double.isNaN;
import static java.lang.Math.max;
import static java.util.Objects.requireNonNull;

/**
* This rule looks at the amount of data read and processed by the query to determine the value of partition count
* used for remote exchanges. It helps to increase the concurrency of the engine in the case of large cluster.
* used for remote partitioned exchanges. It helps to increase the concurrency of the engine in the case of large cluster.
* This rule is also cautious about lack of or incorrect statistics therefore it skips for input multiplying nodes like
* CROSS JOIN or UNNEST.
*
* E.g. 1:
* Given query: SELECT count(column_a) FROM table_with_stats_a
* Given query: SELECT count(column_a) FROM table_with_stats_a group by column_b
* config:
* MIN_INPUT_SIZE_PER_TASK: 500 MB
* Input table data size: 1000 MB
Expand Down Expand Up @@ -114,6 +116,11 @@ public PlanNode optimize(
requireNonNull(types, "types is null");
requireNonNull(tableStatsProvider, "tableStatsProvider is null");

// Skip partition count determination if no partitioned remote exchanges exist in the plan anyway
if (!isEligibleRemoteExchangePresent(plan)) {
return plan;
}

// Unless enabled, skip for write nodes since writing partitioned data with small amount of nodes could cause
// memory related issues even when the amount of data is small.
boolean isWriteQuery = PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(INSERT_NODES).matches();
Expand Down Expand Up @@ -308,6 +315,24 @@ private static double getSourceNodesOutputStats(PlanNode root, ToDoubleFunction<
.sum();
}

private static boolean isEligibleRemoteExchangePresent(PlanNode root)
{
return PlanNodeSearcher.searchFrom(root)
.where(node -> node instanceof ExchangeNode exchangeNode && isEligibleRemoteExchange(exchangeNode))
.matches();
}

private static boolean isEligibleRemoteExchange(ExchangeNode exchangeNode)
{
if (exchangeNode.getScope() != REMOTE || exchangeNode.getType() != REPARTITION) {
return false;
}
PartitioningHandle partitioningHandle = exchangeNode.getPartitioningScheme().getPartitioning().getHandle();
return !partitioningHandle.isScaleWriters()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is probably enforced by PlanNodeSearcher.searchFrom(plan).whereIsInstanceOfAny(INSERT_NODES).matches(); check.

@gaurav8297 ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was until #17024 enabled automatic partition determination for write queries behind a session property (default: false).

&& !partitioningHandle.isSingleNode()
&& partitioningHandle.getConnectorHandle() instanceof SystemPartitioningHandle;
}

private static class Rewriter
extends SimplePlanRewriter<Void>
{
Expand All @@ -321,20 +346,20 @@ private Rewriter(int partitionCount)
@Override
public PlanNode visitExchange(ExchangeNode node, RewriteContext<Void> context)
{
PartitioningHandle handle = node.getPartitioningScheme().getPartitioning().getHandle();
if (!(node.getScope() == REMOTE && handle.getConnectorHandle() instanceof SystemPartitioningHandle)) {
return node;
}

List<PlanNode> sources = node.getSources().stream()
.map(context::rewrite)
.collect(toImmutableList());

PartitioningScheme partitioningScheme = node.getPartitioningScheme();
if (isEligibleRemoteExchange(node)) {
partitioningScheme = partitioningScheme.withPartitionCount(Optional.of(partitionCount));
}

return new ExchangeNode(
node.getId(),
node.getType(),
node.getScope(),
node.getPartitioningScheme().withPartitionCount(Optional.of(partitionCount)),
partitioningScheme,
sources,
node.getInputs(),
node.getOrderingScheme());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,11 @@ public static PlanMatchPattern exchange(ExchangeNode.Scope scope, Optional<Integ
return exchange(scope, Optional.empty(), Optional.empty(), ImmutableList.of(), ImmutableSet.of(), Optional.empty(), ImmutableList.of(), Optional.of(partitionCount), sources);
}

public static PlanMatchPattern exchange(ExchangeNode.Scope scope, ExchangeNode.Type type, Optional<Integer> partitionCount, PlanMatchPattern... sources)
{
return exchange(scope, Optional.of(type), Optional.empty(), ImmutableList.of(), ImmutableSet.of(), Optional.empty(), ImmutableList.of(), Optional.of(partitionCount), sources);
}

public static PlanMatchPattern exchange(ExchangeNode.Scope scope, PartitioningHandle partitioningHandle, Optional<Integer> partitionCount, PlanMatchPattern... sources)
{
return exchange(scope, Optional.empty(), Optional.of(partitioningHandle), ImmutableList.of(), ImmutableSet.of(), Optional.empty(), ImmutableList.of(), Optional.of(partitionCount), sources);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
import static io.trino.spi.statistics.TableStatistics.empty;
import static io.trino.spi.type.VarcharType.VARCHAR;
import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange;
import static io.trino.sql.planner.assertions.PlanMatchPattern.filter;
import static io.trino.sql.planner.assertions.PlanMatchPattern.join;
import static io.trino.sql.planner.assertions.PlanMatchPattern.node;
import static io.trino.sql.planner.assertions.PlanMatchPattern.output;
import static io.trino.sql.planner.assertions.PlanMatchPattern.project;
import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.LOCAL;
import static io.trino.sql.planner.plan.ExchangeNode.Scope.REMOTE;
import static io.trino.sql.planner.plan.ExchangeNode.Type.GATHER;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION;
import static io.trino.sql.planner.plan.ExchangeNode.Type.REPLICATE;
import static io.trino.sql.planner.plan.JoinNode.Type.INNER;
import static io.trino.testing.TestingSession.testSessionBuilder;

Expand Down Expand Up @@ -101,29 +105,91 @@ protected LocalQueryRunner createLocalQueryRunner()
}

@Test
public void testPlanWhenTableStatisticsArePresent()
public void testSimpleSelect()
{
@Language("SQL") String query = """
SELECT count(column_a) FROM table_with_stats_a
""";
@Language("SQL") String query = "SELECT * FROM table_with_stats_a";

// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 5 for remote exchanges
// DeterminePartitionCount optimizer rule should not fire since no remote exchanges are present
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "10")
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "100")
.setSystemProperty(MIN_HASH_PARTITION_COUNT, "4")
.setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
.build(),
output(
node(TableScanNode.class)));
}

@Test
public void testSimpleFilter()
{
@Language("SQL") String query = "SELECT column_a FROM table_with_stats_a WHERE column_b IS NULL";

// DeterminePartitionCount optimizer rule should not fire since no remote exchanges are present
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "100")
.setSystemProperty(MIN_HASH_PARTITION_COUNT, "4")
.setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
.build(),
output(
project(
filter("column_b IS NULL",
tableScan("table_with_stats_a", ImmutableMap.of("column_a", "column_a", "column_b", "column_b"))))));
}

@Test
public void testSimpleCount()
{
@Language("SQL") String query = "SELECT count(*) FROM table_with_stats_a";

// DeterminePartitionCount optimizer rule should not fire since no remote repartition exchanges are present
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "100")
.setSystemProperty(MIN_HASH_PARTITION_COUNT, "4")
.setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
.build(),
output(
node(AggregationNode.class,
exchange(LOCAL,
exchange(REMOTE, Optional.of(5),
exchange(REMOTE, GATHER, Optional.empty(),
node(AggregationNode.class,
node(TableScanNode.class)))))));
}

@Test
public void testPlanWhenTableStatisticsArePresent()
{
@Language("SQL") String query = """
SELECT count(column_a) FROM table_with_stats_a group by column_b
""";

// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(MAX_HASH_PARTITION_COUNT, "20")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess you could move it as default session setup for these tests

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's ok to explicitly configure the methods separately since many of the tests still use different values to assert different count values selected from the parameters used.

.setSystemProperty(MIN_HASH_PARTITION_COUNT, "4")
.setSystemProperty(MIN_INPUT_SIZE_PER_TASK, "20MB")
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
.build(),
output(
project(
node(AggregationNode.class,
exchange(LOCAL,
exchange(REMOTE, REPARTITION, Optional.of(10),
node(AggregationNode.class,
project(
node(TableScanNode.class)))))))));
}

@Test
public void testPlanWhenTableStatisticsAreAbsent()
{
Expand Down Expand Up @@ -184,7 +250,7 @@ public void testPlanWhenCrossJoinIsScalar()
SELECT * FROM table_with_stats_a CROSS JOIN (select max(column_a) from table_with_stats_b) t(a)
""";

// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 1 for remote exchanges
// DeterminePartitionCount optimizer rule should not fire since no remote repartitioning exchanges are present
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
Expand All @@ -197,10 +263,10 @@ public void testPlanWhenCrossJoinIsScalar()
join(INNER, builder -> builder
.right(
exchange(LOCAL,
exchange(REMOTE, Optional.of(15),
exchange(REMOTE, REPLICATE, Optional.empty(),
node(AggregationNode.class,
exchange(LOCAL,
exchange(REMOTE, Optional.of(15),
exchange(REMOTE, GATHER, Optional.empty(),
node(AggregationNode.class,
node(TableScanNode.class))))))))
.left(node(TableScanNode.class)))));
Expand Down Expand Up @@ -242,7 +308,7 @@ public void testPlanWhenJoinNodeOutputIsBiggerThanRowsScanned()
SELECT a.column_a FROM table_with_stats_a as a JOIN table_with_stats_b as b ON a.column_a = b.column_a
""";

// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 20 for remote exchanges
// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
Expand Down Expand Up @@ -336,7 +402,7 @@ public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput()
FROM table_with_stats_b
""";

// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges
// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 20 for remote exchanges
assertDistributedPlan(
query,
Session.builder(getQueryRunner().getDefaultSession())
Expand All @@ -346,17 +412,16 @@ public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput()
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "400")
.build(),
output(
// partition count should be 15 with just join node but since we also have union, it should be 20
exchange(REMOTE, Optional.of(20),
exchange(REMOTE, GATHER,
join(INNER, builder -> builder
.equiCriteria("column_a", "column_a_1")
.right(exchange(LOCAL,
// partition count should be 15 with just join node but since we also have union, it should be 20
exchange(REMOTE, Optional.of(20),
exchange(REMOTE, REPARTITION, Optional.of(20),
project(
tableScan("table_with_stats_b", ImmutableMap.of("column_a_1", "column_a"))))))
// partition count should be 15 with just join node but since we also have union, it should be 20
.left(exchange(REMOTE, Optional.of(20),
.left(exchange(REMOTE, REPARTITION, Optional.of(20),
project(
node(FilterNode.class,
tableScan("table_with_stats_a", ImmutableMap.of("column_a", "column_a", "column_b_0", "column_b"))))))),
Expand All @@ -367,7 +432,7 @@ public void testPlanWhenUnionNodeOutputIsBiggerThanJoinOutput()
public void testPlanWhenEstimatedPartitionCountBasedOnRowsIsMoreThanOutputSize()
{
@Language("SQL") String query = """
SELECT count(column_a) FROM table_with_stats_a
SELECT count(column_a) FROM table_with_stats_a group by column_b
""";

// DeterminePartitionCount optimizer rule should fire and set the partitionCount to 10 for remote exchanges
Expand All @@ -381,10 +446,12 @@ SELECT count(column_a) FROM table_with_stats_a
.setSystemProperty(MIN_INPUT_ROWS_PER_TASK, "20")
.build(),
output(
node(AggregationNode.class,
exchange(LOCAL,
exchange(REMOTE, Optional.of(10),
node(AggregationNode.class,
node(TableScanNode.class)))))));
project(
node(AggregationNode.class,
exchange(LOCAL,
exchange(REMOTE, REPARTITION, Optional.of(10),
node(AggregationNode.class,
project(
node(TableScanNode.class)))))))));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,11 @@ public void testReadWholePartition()
assertFileSystemAccesses(
"SELECT count(*) FROM test_read_part_key WHERE key = 'p1'",
ImmutableMultiset.<FileOperation>builder()
.addCopies(new FileOperation(LAST_CHECKPOINT, "_last_checkpoint", INPUT_FILE_NEW_STREAM), 3) // TODO (https://github.com/trinodb/trino/issues/16782) should be checked once per query
.addCopies(new FileOperation(LAST_CHECKPOINT, "_last_checkpoint", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/16782) should be checked once per query
.addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000000.json", INPUT_FILE_NEW_STREAM), 1)
.addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000001.json", INPUT_FILE_NEW_STREAM), 1)
.addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000002.json", INPUT_FILE_NEW_STREAM), 1)
.addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000003.json", INPUT_FILE_NEW_STREAM), 3) // TODO (https://github.com/trinodb/trino/issues/16780) why is last transaction log accessed more times than others?
.addCopies(new FileOperation(TRINO_EXTENDED_STATS_JSON, "extended_stats.json", INPUT_FILE_EXISTS), 1)
.addCopies(new FileOperation(TRINO_EXTENDED_STATS_JSON, "extended_stats.json", INPUT_FILE_NEW_STREAM), 1)
.addCopies(new FileOperation(TRANSACTION_LOG_JSON, "00000000000000000003.json", INPUT_FILE_NEW_STREAM), 2) // TODO (https://github.com/trinodb/trino/issues/16780) why is last transaction log accessed more times than others?
.build());

// Read partition and synthetic columns
Expand Down
Loading