diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java
index 57e0237ff2f6..1a6b1542e9ed 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/FixedSourcePartitionedScheduler.java
@@ -20,25 +20,25 @@
 import io.airlift.log.Logger;
 import io.trino.execution.RemoteTask;
 import io.trino.execution.TableExecuteContextManager;
-import io.trino.execution.scheduler.ScheduleResult.BlockedReason;
 import io.trino.metadata.InternalNode;
 import io.trino.metadata.Split;
 import io.trino.server.DynamicFilterService;
 import io.trino.split.SplitSource;
 import io.trino.sql.planner.plan.PlanNodeId;
 
+import java.util.ArrayDeque;
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Queue;
 import java.util.Set;
 import java.util.function.Supplier;
 
 import static com.google.common.base.Preconditions.checkArgument;
-import static com.google.common.base.Verify.verify;
-import static io.airlift.concurrent.MoreFutures.whenAnyComplete;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.util.concurrent.Futures.immediateFuture;
 import static io.trino.execution.scheduler.SourcePartitionedScheduler.newSourcePartitionedSchedulerAsSourceScheduler;
 import static java.util.Objects.requireNonNull;
 
@@ -49,7 +49,7 @@ public class FixedSourcePartitionedScheduler
 
     private final StageExecution stageExecution;
     private final List<InternalNode> nodes;
-    private final List<SourceScheduler> sourceSchedulers;
+    private final Queue<SourceScheduler> sourceSchedulers;
     private final PartitionIdAllocator partitionIdAllocator;
     private final Map<InternalNode, RemoteTask> scheduledTasks;
 
@@ -80,8 +80,6 @@ public FixedSourcePartitionedScheduler(
 
         ArrayList<SourceScheduler> sourceSchedulers = new ArrayList<>();
 
-        boolean firstPlanNode = true;
-
         partitionIdAllocator = new PartitionIdAllocator();
         scheduledTasks = new HashMap<>();
         for (PlanNodeId planNodeId : schedulingOrder) {
@@ -101,12 +99,8 @@ public FixedSourcePartitionedScheduler(
                     scheduledTasks);
 
             sourceSchedulers.add(sourceScheduler);
-
-            if (firstPlanNode) {
-                firstPlanNode = false;
-            }
         }
-        this.sourceSchedulers = sourceSchedulers;
+        this.sourceSchedulers = new ArrayDeque<>(sourceSchedulers);
     }
 
     @Override
@@ -126,37 +120,36 @@ public ScheduleResult schedule()
             newTasks = newTasksBuilder.build();
         }
 
-        boolean allBlocked = true;
-        List<ListenableFuture<?>> blocked = new ArrayList<>();
-        BlockedReason blockedReason = BlockedReason.NO_ACTIVE_DRIVER_GROUP;
-
+        ListenableFuture<?> blocked = immediateFuture(null);
+        ScheduleResult.BlockedReason blockedReason = null;
         int splitsScheduled = 0;
-        Iterator<SourceScheduler> schedulerIterator = sourceSchedulers.iterator();
-        while (schedulerIterator.hasNext()) {
-            SourceScheduler sourceScheduler = schedulerIterator.next();
-
-            ScheduleResult schedule = sourceScheduler.schedule();
+        while (!sourceSchedulers.isEmpty()) {
+            SourceScheduler scheduler = sourceSchedulers.peek();
+            ScheduleResult schedule = scheduler.schedule();
             splitsScheduled += schedule.getSplitsScheduled();
+            blocked = schedule.getBlocked();
+
             if (schedule.getBlockedReason().isPresent()) {
-                blocked.add(schedule.getBlocked());
-                blockedReason = blockedReason.combineWith(schedule.getBlockedReason().get());
+                blockedReason = schedule.getBlockedReason().get();
             }
             else {
-                verify(schedule.getBlocked().isDone(), "blockedReason not provided when scheduler is blocked");
-                allBlocked = false;
+                blockedReason = null;
             }
 
-            if (schedule.isFinished()) {
-                stageExecution.schedulingComplete(sourceScheduler.getPlanNodeId());
-                schedulerIterator.remove();
-                sourceScheduler.close();
+            // if the source is not done scheduling, stop scheduling for now
+            if (!blocked.isDone() || !schedule.isFinished()) {
+                break;
             }
+
+            stageExecution.schedulingComplete(scheduler.getPlanNodeId());
+            sourceSchedulers.remove().close();
         }
 
-        if (allBlocked) {
-            return new ScheduleResult(sourceSchedulers.isEmpty(), newTasks, whenAnyComplete(blocked), blockedReason, splitsScheduled);
+        if (blockedReason != null) {
+            return new ScheduleResult(sourceSchedulers.isEmpty(), newTasks, blocked, blockedReason, splitsScheduled);
         }
         else {
+            checkState(blocked.isDone(), "blockedReason not provided when scheduler is blocked");
             return new ScheduleResult(sourceSchedulers.isEmpty(), newTasks, splitsScheduled);
         }
     }
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScheduleResult.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScheduleResult.java
index d1d23f202699..8bdaf00ba40a 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScheduleResult.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScheduleResult.java
@@ -34,23 +34,6 @@ public enum BlockedReason
         WAITING_FOR_SOURCE,
         MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE,
         /**/;
-
-        public BlockedReason combineWith(BlockedReason other)
-        {
-            switch (this) {
-                case WRITER_SCALING:
-                    throw new IllegalArgumentException("cannot be combined");
-                case NO_ACTIVE_DRIVER_GROUP:
-                    return other;
-                case SPLIT_QUEUES_FULL:
-                    return other == SPLIT_QUEUES_FULL || other == NO_ACTIVE_DRIVER_GROUP ? SPLIT_QUEUES_FULL : MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE;
-                case WAITING_FOR_SOURCE:
-                    return other == WAITING_FOR_SOURCE || other == NO_ACTIVE_DRIVER_GROUP ? WAITING_FOR_SOURCE : MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE;
-                case MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE:
-                    return MIXED_SPLIT_QUEUES_FULL_AND_WAITING_FOR_SOURCE;
-            }
-            throw new IllegalArgumentException("Unknown blocked reason: " + other);
-        }
     }
 
     private final Set<RemoteTask> newTasks;
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestBucketedQueryWithManySplits.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestBucketedQueryWithManySplits.java
new file mode 100644
index 000000000000..01610e3d2da3
--- /dev/null
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestBucketedQueryWithManySplits.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.hive;
+
+import com.google.common.collect.ImmutableMap;
+import io.trino.testing.AbstractTestQueryFramework;
+import io.trino.testing.QueryRunner;
+import org.testng.annotations.Test;
+
+import static java.lang.String.format;
+
+public class TestBucketedQueryWithManySplits
+        extends AbstractTestQueryFramework
+{
+    @Override
+    protected QueryRunner createQueryRunner()
+            throws Exception
+    {
+        return HiveQueryRunner.builder()
+                .setNodeCount(1)
+                .setExtraProperties(ImmutableMap.of(
+                        "query.schedule-split-batch-size", "1",
+                        "node-scheduler.max-splits-per-node", "1",
+                        "node-scheduler.max-pending-splits-per-task", "1"))
+                .build();
+    }
+
+    @Test(timeOut = 120_000)
+    public void testBucketedQueryWithManySplits()
+    {
+        QueryRunner queryRunner = getQueryRunner();
+        queryRunner.execute("CREATE TABLE tbl_a (col bigint, bucket bigint) WITH (bucketed_by=array['bucket'], bucket_count=10)");
+        queryRunner.execute("CREATE TABLE tbl_b (col bigint, bucket bigint) WITH (bucketed_by=array['bucket'], bucket_count=10)");
+
+        for (int i = 0; i < 50; i++) {
+            queryRunner.execute(format("INSERT INTO tbl_a VALUES (%s, %s)", i, i));
+            queryRunner.execute(format("INSERT INTO tbl_b VALUES (%s, %s)", i, i));
+        }
+
+        // query should not deadlock
+        assertQuery("" +
+                "WITH test_data AS" +
+                " (SELECT bucket" +
+                " FROM" +
+                " (SELECT" +
+                " bucket" +
+                " FROM tbl_a" +
+                " UNION ALL" +
+                " SELECT" +
+                " bucket" +
+                " FROM tbl_b) " +
+                " GROUP BY bucket) " +
+                "SELECT COUNT(1) FROM test_data",
+                "VALUES 50");
+
+        assertUpdate("DROP TABLE tbl_a");
+        assertUpdate("DROP TABLE tbl_b");
+    }
+}