diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java index 7f119afc89cc..0e77b9f30d0e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedQueryScheduler.java @@ -68,11 +68,13 @@ import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.SplitSourceFactory; import io.trino.sql.planner.SubPlan; +import io.trino.sql.planner.optimizations.PlanNodeSearcher; import io.trino.sql.planner.plan.PlanFragmentId; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.RemoteSourceNode; import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TableWriterNode; import io.trino.tracing.TrinoAttributes; import java.net.URI; @@ -112,6 +114,8 @@ import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; import static io.airlift.concurrent.MoreFutures.whenAnyComplete; import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.trino.SystemSessionProperties.getMaxHashPartitionCount; +import static io.trino.SystemSessionProperties.getMaxWriterTaskCount; import static io.trino.SystemSessionProperties.getQueryRetryAttempts; import static io.trino.SystemSessionProperties.getRetryDelayScaleFactor; import static io.trino.SystemSessionProperties.getRetryInitialDelay; @@ -875,6 +879,7 @@ public static DistributedStagesScheduler create( partitioning.partitionCount)); Map> bucketToPartitionMap = createBucketToPartitionMap( + queryStateMachine.getSession(), coordinatorStagesScheduler.getBucketToPartitionForStagesConsumedByCoordinator(), stageManager, partitioningCache); @@ -956,6 +961,7 @@ public static DistributedStagesScheduler create( } private static Map> createBucketToPartitionMap( + Session session, Map> bucketToPartitionForStagesConsumedByCoordinator, StageManager stageManager, Function partitioningCache) @@ -969,7 +975,7 @@ private static Map> createBucketToPartitionMap( partitioningCache, fragment.getRoot(), fragment.getRemoteSourceNodes(), - fragment.getPartitionCount()); + getFragmentMaxPartitionCount(session, fragment)); for (SqlStage childStage : stageManager.getChildren(stage.getStageId())) { result.put(childStage.getFragment().getId(), bucketToPartition); } @@ -982,7 +988,7 @@ private static Optional getBucketToPartition( Function partitioningCache, PlanNode fragmentRoot, List remoteSourceNodes, - Optional partitionCount) + int partitionCount) { if (partitioningHandle.equals(SOURCE_DISTRIBUTION) || partitioningHandle.equals(SCALED_WRITER_ROUND_ROBIN_DISTRIBUTION)) { return Optional.of(new int[1]); @@ -1050,7 +1056,7 @@ private static StageScheduler createStageScheduler( Span stageSpan = stageExecution.getStageSpan(); PlanFragment fragment = stageExecution.getFragment(); PartitioningHandle partitioningHandle = fragment.getPartitioning(); - Optional partitionCount = fragment.getPartitionCount(); + int partitionCount = getFragmentMaxPartitionCount(session, fragment); Map splitSources = splitSourceFactory.createSplitSources(session, stageSpan, fragment); if (!splitSources.isEmpty()) { queryStateMachine.addStateChangeListener(new StateChangeListener<>() @@ -1119,7 +1125,6 @@ public void stateChanged(QueryState newState) .collect(toImmutableList()); Supplier> writerTasksProvider = stageExecution::getTaskStatuses; - checkState(partitionCount.isPresent(), "Partition count cannot be empty when scale writers is used"); ScaledWriterScheduler scheduler = new ScaledWriterScheduler( stageExecution, sourceTasksProvider, @@ -1127,7 +1132,7 @@ public void stateChanged(QueryState newState) nodeScheduler.createNodeSelector(session, Optional.empty()), executor, getWriterScalingMinDataProcessed(session), - partitionCount.get()); + partitionCount); whenAllStages(childStageExecutions, StageExecution.State::isDone) .addListener(scheduler::finish, directExecutor()); @@ -1153,7 +1158,7 @@ public void stateChanged(QueryState newState) List stageNodeList; if (fragment.getRemoteSourceNodes().stream().allMatch(node -> node.getExchangeType() == REPLICATE)) { // no remote source - bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle); + bucketNodeMap = nodePartitioningManager.getBucketNodeMap(session, partitioningHandle, partitionCount); stageNodeList = new ArrayList<>(nodeScheduler.createNodeSelector(session, catalogHandle).allNodes()); Collections.shuffle(stageNodeList); } @@ -1176,6 +1181,13 @@ public void stateChanged(QueryState newState) tableExecuteContextManager); } + private static int getFragmentMaxPartitionCount(Session session, PlanFragment fragment) + { + return fragment.getPartitionCount().orElseGet(() -> PlanNodeSearcher.searchFrom(fragment.getRoot()) + .whereIsInstanceOfAny(TableWriterNode.class) + .matches() ? getMaxWriterTaskCount(session) : getMaxHashPartitionCount(session)); + } + private static void closeSplitSources(Collection splitSources) { for (SplitSource source : splitSources) { @@ -1576,12 +1588,12 @@ public Optional getFailedStageId() } } - private record PartitioningKey(PartitioningHandle handle, Optional partitionCount) + private record PartitioningKey(PartitioningHandle handle, int partitionCount) { - public PartitioningKey(PartitioningHandle handle, Optional partitionCount) + public PartitioningKey(PartitioningHandle handle, int partitionCount) { this.handle = requireNonNull(handle, "handle cannot be null"); - this.partitionCount = requireNonNull(partitionCount, "partitionCount cannot be null"); + this.partitionCount = partitionCount; } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java index 59b9e339f059..beb895a2728c 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/LocalExchange.java @@ -287,9 +287,9 @@ public static int getBucketCount(Session session, NodePartitioningManager nodePa { if (partitioning.getConnectorHandle() instanceof MergePartitioningHandle) { // TODO: can we always use this code path? - return nodePartitioningManager.getNodePartitioningMap(session, partitioning).getBucketToPartition().length; + return nodePartitioningManager.getNodePartitioningMap(session, partitioning, 1000).getBucketToPartition().length; } - return nodePartitioningManager.getBucketNodeMap(session, partitioning).getBucketCount(); + return nodePartitioningManager.getBucketCount(session, partitioning); } private static boolean isSystemPartitioning(PartitioningHandle partitioning) diff --git a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java index 631a1ea6b72b..efde572b5522 100644 --- a/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java +++ b/core/trino-main/src/main/java/io/trino/operator/output/SkewedPartitionRebalancer.java @@ -147,7 +147,7 @@ public static PartitionFunction createPartitionFunction( // compared to only a single hive bucket reaching the min limit. int bucketCount = (handle.getConnectorHandle() instanceof SystemPartitioningHandle) ? SCALE_WRITERS_PARTITION_COUNT - : nodePartitioningManager.getBucketNodeMap(session, handle).getBucketCount(); + : nodePartitioningManager.getBucketCount(session, handle); return nodePartitioningManager.getPartitionFunction( session, scheme, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java index c668e2bcee75..01e8df9c35e0 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/NodePartitioningManager.java @@ -51,7 +51,6 @@ import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.SystemSessionProperties.getFaultTolerantExecutionMaxPartitionCount; -import static io.trino.SystemSessionProperties.getMaxHashPartitionCount; import static io.trino.SystemSessionProperties.getRetryPolicy; import static io.trino.execution.TaskManagerConfig.MAX_WRITER_COUNT; import static io.trino.operator.exchange.LocalExchange.SCALE_WRITERS_MAX_PARTITIONS_PER_WRITER; @@ -136,12 +135,7 @@ public BucketFunction getBucketFunction(Session session, PartitioningHandle part return bucketFunction; } - public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle) - { - return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), Optional.empty()); - } - - public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, Optional partitionCount) + public NodePartitionMap getNodePartitioningMap(Session session, PartitioningHandle partitioningHandle, int partitionCount) { return getNodePartitioningMap(session, partitioningHandle, new HashMap<>(), new AtomicReference<>(), partitionCount); } @@ -155,7 +149,7 @@ private NodePartitionMap getNodePartitioningMap( PartitioningHandle partitioningHandle, Map> bucketToNodeCache, AtomicReference> systemPartitioningCache, - Optional partitionCount) + int partitionCount) { requireNonNull(session, "session is null"); requireNonNull(partitioningHandle, "partitioningHandle is null"); @@ -188,9 +182,10 @@ private NodePartitionMap getNodePartitioningMap( } else { CatalogHandle catalogHandle = requiredCatalogHandle(partitioningHandle); + List allNodes = getAllNodes(session, catalogHandle); bucketToNode = bucketToNodeCache.computeIfAbsent( connectorBucketNodeMap.getBucketCount(), - bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), getAllNodes(session, catalogHandle), bucketCount)); + bucketCount -> createArbitraryBucketToNode(connectorBucketNodeMap.getCacheKeyHint(), allNodes.subList(0, Math.min(allNodes.size(), partitionCount)), bucketCount)); } } @@ -215,7 +210,7 @@ private NodePartitionMap getNodePartitioningMap( return new NodePartitionMap(partitionToNode, bucketToPartition, getSplitToBucket(session, partitioningHandle, bucketToNode.size())); } - private List systemBucketToNode(Session session, PartitioningHandle partitioningHandle, AtomicReference> nodesCache, Optional partitionCount) + private List systemBucketToNode(Session session, PartitioningHandle partitioningHandle, AtomicReference> nodesCache, int partitionCount) { SystemPartitioning partitioning = ((SystemPartitioningHandle) partitioningHandle.getConnectorHandle()).getPartitioning(); @@ -227,7 +222,7 @@ private List systemBucketToNode(Session session, PartitioningHandl case FIXED -> { List value = nodesCache.get(); if (value == null) { - value = nodeSelector.selectRandomNodes(partitionCount.orElse(getMaxHashPartitionCount(session))); + value = nodeSelector.selectRandomNodes(partitionCount); nodesCache.set(value); } yield value; @@ -238,7 +233,13 @@ private List systemBucketToNode(Session session, PartitioningHandl return nodes; } - public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle) + public int getBucketCount(Session session, PartitioningHandle partitioningHandle) + { + // we don't care about partition count at all, just bucket count + return getBucketNodeMap(session, partitioningHandle, 1000).getBucketCount(); + } + + public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partitioningHandle, int partitionCount) { Optional bucketNodeMap = getConnectorBucketNodeMap(session, partitioningHandle); int bucketCount = bucketNodeMap.map(ConnectorBucketNodeMap::getBucketCount).orElseGet(() -> getDefaultBucketCount(session, partitioningHandle)); @@ -250,6 +251,7 @@ public BucketNodeMap getBucketNodeMap(Session session, PartitioningHandle partit long seed = bucketNodeMap.map(ConnectorBucketNodeMap::getCacheKeyHint).orElse(ThreadLocalRandom.current().nextLong()); List nodes = getAllNodes(session, requiredCatalogHandle(partitioningHandle)); + nodes = nodes.subList(0, Math.min(nodes.size(), partitionCount)); return new BucketNodeMap(splitToBucket, createArbitraryBucketToNode(seed, nodes, bucketCount)); } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java index 1df245797cee..6e48361f33f3 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java @@ -23,6 +23,7 @@ import io.airlift.units.DataSize; import io.airlift.units.Duration; import io.trino.Session; +import io.trino.execution.StageInfo; import io.trino.filesystem.FileIterator; import io.trino.filesystem.Location; import io.trino.filesystem.TrinoFileSystem; @@ -43,10 +44,12 @@ import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.Plan; +import io.trino.sql.planner.optimizations.PlanNodeSearcher; import io.trino.sql.planner.plan.ExchangeNode; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.OutputNode; import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.TableWriterNode; import io.trino.sql.planner.plan.ValuesNode; import io.trino.testing.BaseConnectorTest; import io.trino.testing.DistributedQueryRunner; @@ -100,6 +103,7 @@ import java.util.stream.Stream; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; @@ -109,9 +113,12 @@ import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.trino.SystemSessionProperties.DETERMINE_PARTITION_COUNT_FOR_WRITE_ENABLED; import static io.trino.SystemSessionProperties.ENABLE_DYNAMIC_FILTERING; +import static io.trino.SystemSessionProperties.MAX_HASH_PARTITION_COUNT; +import static io.trino.SystemSessionProperties.MAX_WRITER_TASK_COUNT; import static io.trino.SystemSessionProperties.SCALE_WRITERS; import static io.trino.SystemSessionProperties.TASK_MAX_WRITER_COUNT; import static io.trino.SystemSessionProperties.TASK_MIN_WRITER_COUNT; +import static io.trino.SystemSessionProperties.TASK_SCALE_WRITERS_ENABLED; import static io.trino.SystemSessionProperties.USE_PREFERRED_WRITE_PARTITIONING; import static io.trino.plugin.iceberg.IcebergFileFormat.AVRO; import static io.trino.plugin.iceberg.IcebergFileFormat.ORC; @@ -1440,11 +1447,11 @@ public void testSortByAllTypes() // Insert "large" number of rows, supposedly topping over iceberg.writer-sort-buffer-size so that temporary files are utilized by the sorting writer. assertUpdate( """ - INSERT INTO %s - SELECT v.* - FROM (VALUES %s, %s, %s) v - CROSS JOIN UNNEST (sequence(1, 10_000)) a(i) - """.formatted(tableName, values, highValues, lowValues), 30000); + INSERT INTO %s + SELECT v.* + FROM (VALUES %s, %s, %s) v + CROSS JOIN UNNEST (sequence(1, 10_000)) a(i) + """.formatted(tableName, values, highValues, lowValues), 30000); assertUpdate("DROP TABLE " + tableName); } @@ -5292,6 +5299,40 @@ public void testProjectionPushdownOnPartitionedTableWithComments() assertUpdate("DROP TABLE IF EXISTS test_projection_pushdown_comments"); } + @Test + public void testMaxWriterTaskCount() + { + int workerCount = getQueryRunner().getNodeCount(); + checkState(workerCount > 1, "testMaxWriterTaskCount requires multiple workers"); + + assertUpdate("CREATE TABLE test_max_writer_task_count_insert (id BIGINT) WITH (partitioning = ARRAY['id'])"); + + Session session = Session.builder(getSession()) + // disable writer scaling for the test + .setSystemProperty(SCALE_WRITERS, "false") + .setSystemProperty(TASK_SCALE_WRITERS_ENABLED, "false") + // limit number of writer tasks to 1 + .setSystemProperty(MAX_WRITER_TASK_COUNT, "1") + .setSystemProperty(MAX_HASH_PARTITION_COUNT, Integer.toString(workerCount)) + .build(); + QueryId id = getDistributedQueryRunner() + .executeWithPlan(session, """ + INSERT INTO test_max_writer_task_count_insert + SELECT * FROM TABLE(sequence(start => 0, stop => 100, step => 1)) + """) + .queryId(); + StageInfo writerStage = getDistributedQueryRunner().getCoordinator() + .getFullQueryInfo(id) + .getOutputStage() + .orElseThrow() + .getSubStages() + .getFirst(); + assertThat(PlanNodeSearcher.searchFrom(writerStage.getPlan().getRoot()).whereIsInstanceOfAny(TableWriterNode.class).matches()).isTrue(); + assertThat(writerStage.getTasks().size()).isEqualTo(1); + + assertUpdate("DROP TABLE IF EXISTS test_max_writer_task_count_insert"); + } + @Test public void testOptimize() throws Exception @@ -8187,35 +8228,35 @@ public void testDynamicFilterWithExplicitPartitionFilter() TestTable dimensionTable = newTrinoTable("dimension_table", "(date date, following_holiday boolean, year int)")) { assertUpdate( """ - INSERT INTO %s - VALUES - (DATE '2023-01-01' , false, 2023), - (DATE '2023-01-02' , true, 2023), - (DATE '2023-01-03' , false, 2023)""".formatted(dimensionTable.getName()), 3); + INSERT INTO %s + VALUES + (DATE '2023-01-01' , false, 2023), + (DATE '2023-01-02' , true, 2023), + (DATE '2023-01-03' , false, 2023)""".formatted(dimensionTable.getName()), 3); assertUpdate( """ - INSERT INTO %s - VALUES - (DATE '2023-01-02' , '#2023#1', DECIMAL '122.12'), - (DATE '2023-01-02' , '#2023#2', DECIMAL '124.12'), - (DATE '2023-01-02' , '#2023#3', DECIMAL '99.99'), - (DATE '2023-01-02' , '#2023#4', DECIMAL '95.12'), - (DATE '2023-01-03' , '#2023#5', DECIMAL '199.12'), - (DATE '2023-01-04' , '#2023#6', DECIMAL '99.55'), - (DATE '2023-01-05' , '#2023#7', DECIMAL '50.11'), - (DATE '2023-01-05' , '#2023#8', DECIMAL '60.20'), - (DATE '2023-01-05' , '#2023#9', DECIMAL '70.75'), - (DATE '2023-01-05' , '#2023#10', DECIMAL '80.12')""".formatted(salesTable.getName()), 10); + INSERT INTO %s + VALUES + (DATE '2023-01-02' , '#2023#1', DECIMAL '122.12'), + (DATE '2023-01-02' , '#2023#2', DECIMAL '124.12'), + (DATE '2023-01-02' , '#2023#3', DECIMAL '99.99'), + (DATE '2023-01-02' , '#2023#4', DECIMAL '95.12'), + (DATE '2023-01-03' , '#2023#5', DECIMAL '199.12'), + (DATE '2023-01-04' , '#2023#6', DECIMAL '99.55'), + (DATE '2023-01-05' , '#2023#7', DECIMAL '50.11'), + (DATE '2023-01-05' , '#2023#8', DECIMAL '60.20'), + (DATE '2023-01-05' , '#2023#9', DECIMAL '70.75'), + (DATE '2023-01-05' , '#2023#10', DECIMAL '80.12')""".formatted(salesTable.getName()), 10); String selectQuery = """ - SELECT receipt_id - FROM %s s - JOIN %s d - ON s.date = d.date - WHERE - d.following_holiday = true AND - d.date BETWEEN DATE '2023-01-01' AND DATE '2024-01-01'""".formatted(salesTable.getName(), dimensionTable.getName()); + SELECT receipt_id + FROM %s s + JOIN %s d + ON s.date = d.date + WHERE + d.following_holiday = true AND + d.date BETWEEN DATE '2023-01-01' AND DATE '2024-01-01'""".formatted(salesTable.getName(), dimensionTable.getName()); MaterializedResultWithPlan result = getDistributedQueryRunner().executeWithPlan( Session.builder(getSession()) .setCatalogSessionProperty(catalog, DYNAMIC_FILTERING_WAIT_TIMEOUT, "10s") @@ -8619,7 +8660,8 @@ public void testSetIllegalExtraPropertyKey() } } - @Test // regression test for https://github.com/trinodb/trino/issues/22922 + // regression test for https://github.com/trinodb/trino/issues/22922 + @Test void testArrayElementChange() { try (TestTable table = newTrinoTable( diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergParquetFaultTolerantExecutionConnectorTest.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergParquetFaultTolerantExecutionConnectorTest.java index c4a6e5a38f6e..c7b1a17de184 100644 --- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergParquetFaultTolerantExecutionConnectorTest.java +++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/iceberg/TestIcebergParquetFaultTolerantExecutionConnectorTest.java @@ -78,6 +78,13 @@ public void testStatsBasedRepartitionDataOnInsert() abort("We always get 3 partitions with FTE"); } + @Test + @Override + public void testMaxWriterTaskCount() + { + abort("Max writer task count is not supported with FTE"); + } + @Override protected boolean isFileSorted(String path, String sortColumnName) {