diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index 3d2819708ff6..c8e290eecc7e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -448,13 +448,12 @@ public TaskInfo updateTask( this::notifyStatusChanged); taskHolderReference.compareAndSet(taskHolder, new TaskHolder(taskExecution)); needsPlan.set(false); + taskExecution.start(); } } - if (taskExecution != null) { - taskExecution.addSplitAssignments(splitAssignments); - taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains); - } + taskExecution.addSplitAssignments(splitAssignments); + taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains); } catch (Error e) { failed(e); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 33f2bf84b6d0..0ed02836060e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -42,7 +41,6 @@ import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.NotThreadSafe; -import javax.annotation.concurrent.ThreadSafe; import java.lang.ref.WeakReference; import java.util.ArrayList; @@ -58,6 +56,9 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import static com.google.common.base.MoreObjects.toStringHelper; @@ -65,7 +66,7 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Iterables.concat; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static io.trino.SystemSessionProperties.getInitialSplitsPerNode; import static io.trino.SystemSessionProperties.getMaxDriversPerTask; @@ -113,34 +114,10 @@ public class SqlTaskExecution @GuardedBy("this") private final Map pendingSplitsByPlanNode; - private final Status status; + // number of created Drivers that haven't yet finished + private final AtomicLong remainingDrivers = new AtomicLong(); - static SqlTaskExecution createSqlTaskExecution( - TaskStateMachine taskStateMachine, - TaskContext taskContext, - OutputBuffer outputBuffer, - LocalExecutionPlan localExecutionPlan, - TaskExecutor taskExecutor, - Executor notificationExecutor, - SplitMonitor queryMonitor) - { - SqlTaskExecution task = new SqlTaskExecution( - taskStateMachine, - taskContext, - outputBuffer, - localExecutionPlan, - taskExecutor, - queryMonitor, - notificationExecutor); - try (SetThreadName ignored = new SetThreadName("Task-%s", task.getTaskId())) { - // The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object. - // The call back is accessed from another thread, so this code cannot be placed in the constructor. - task.scheduleDriversForTaskLifeCycle(); - return task; - } - } - - private SqlTaskExecution( + public SqlTaskExecution( TaskStateMachine taskStateMachine, TaskContext taskContext, OutputBuffer outputBuffer, @@ -178,10 +155,6 @@ private SqlTaskExecution( this.pendingSplitsByPlanNode = this.driverRunnerFactoriesWithSplitLifeCycle.keySet().stream() .collect(toImmutableMap(identity(), ignore -> new PendingSplitsForPlanNode())); - this.status = new Status( - localExecutionPlan.getDriverFactories().stream() - .map(DriverFactory::getPipelineId) - .collect(toImmutableSet())); sourceStartOrder = localExecutionPlan.getPartitionedSourceOrder(); checkArgument(this.driverRunnerFactoriesWithSplitLifeCycle.keySet().equals(partitionedSources), @@ -199,6 +172,15 @@ private SqlTaskExecution( } } + public void start() + { + try (SetThreadName ignored = new SetThreadName("Task-%s", getTaskId())) { + // The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object. + // The call back is accessed from another thread, so this code cannot be placed in the constructor. + scheduleDriversForTaskLifeCycle(); + } + } + // this is a separate method to ensure that the `this` reference is not leaked during construction private static TaskHandle createTaskHandle( TaskStateMachine taskStateMachine, @@ -296,11 +278,6 @@ private synchronized Map updateSplitAssignments(Lis } } - for (DriverSplitRunnerFactory driverSplitRunnerFactory : - Iterables.concat(driverRunnerFactoriesWithSplitLifeCycle.values(), driverRunnerFactoriesWithTaskLifeCycle)) { - driverSplitRunnerFactory.closeDriverFactoryIfFullyCreated(); - } - // update maxAcknowledgedSplit maxAcknowledgedSplit = splitAssignments.stream() .flatMap(source -> source.getSplits().stream()) @@ -392,6 +369,7 @@ private void scheduleDriversForTaskLifeCycle() driverRunnerFactory.noMoreDriverRunner(); verify(driverRunnerFactory.isNoMoreDriverRunner()); } + checkTaskCompletion(); } private synchronized void enqueueDriverSplitRunner(boolean forceRunSplit, List runners) @@ -406,7 +384,7 @@ private synchronized void enqueueDriverSplitRunner(boolean forceRunSplit, List() { @@ -415,7 +393,7 @@ public void onSuccess(Object result) { try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { // record driver is finished - status.decrementRemainingDriver(); + remainingDrivers.decrementAndGet(); checkTaskCompletion(); @@ -430,7 +408,7 @@ public void onFailure(Throwable cause) taskStateMachine.failed(cause); // record driver is finished - status.decrementRemainingDriver(); + remainingDrivers.decrementAndGet(); // fire failed event with cause splitMonitor.splitFailedEvent(taskId, getDriverStats(), cause); @@ -477,14 +455,14 @@ private synchronized void checkTaskCompletion() return; } - // are there more partition splits expected? - for (DriverSplitRunnerFactory driverSplitRunnerFactory : driverRunnerFactoriesWithSplitLifeCycle.values()) { - if (!driverSplitRunnerFactory.isNoMoreDriverRunner()) { + // are there more drivers expected? + for (DriverSplitRunnerFactory driverSplitRunnerFactory : concat(driverRunnerFactoriesWithTaskLifeCycle, driverRunnerFactoriesWithSplitLifeCycle.values())) { + if (!driverSplitRunnerFactory.isNoMoreDrivers()) { return; } } // do we still have running tasks? - if (status.getRemainingDriver() != 0) { + if (remainingDrivers.get() != 0) { return; } @@ -520,7 +498,7 @@ public String toString() { return toStringHelper(this) .add("taskId", taskId) - .add("remainingDrivers", status.getRemainingDriver()) + .add("remainingDrivers", remainingDrivers.get()) .add("unpartitionedSplitAssignments", unpartitionedSplitAssignments) .toString(); } @@ -595,7 +573,11 @@ private class DriverSplitRunnerFactory { private final DriverFactory driverFactory; private final PipelineContext pipelineContext; - private boolean closed; + + // number of created DriverSplitRunners that haven't created underlying Driver + private final AtomicInteger pendingCreations = new AtomicInteger(); + // true if no more DriverSplitRunners will be created + private final AtomicBoolean noMoreDriverRunner = new AtomicBoolean(); private DriverSplitRunnerFactory(DriverFactory driverFactory, boolean partitioned) { @@ -607,7 +589,8 @@ private DriverSplitRunnerFactory(DriverFactory driverFactory, boolean partitione // The former will take two arguments, and the latter will take one. This will simplify the signature quite a bit. public DriverSplitRunner createDriverRunner(@Nullable ScheduledSplit partitionedSplit) { - status.incrementPendingCreation(pipelineContext.getPipelineId()); + checkState(!noMoreDriverRunner.get(), "noMoreDriverRunner is set"); + pendingCreations.incrementAndGet(); // create driver context immediately so the driver existence is recorded in the stats // the number of drivers is used to balance work across nodes long splitWeight = partitionedSplit == null ? 0 : partitionedSplit.getSplit().getSplitWeight().getRawValue(); @@ -637,7 +620,7 @@ public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit } } - status.decrementPendingCreation(pipelineContext.getPipelineId()); + pendingCreations.decrementAndGet(); closeDriverFactoryIfFullyCreated(); return driver; @@ -645,25 +628,28 @@ public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit public void noMoreDriverRunner() { - status.setNoMoreDriverRunner(pipelineContext.getPipelineId()); + noMoreDriverRunner.set(true); closeDriverFactoryIfFullyCreated(); } public boolean isNoMoreDriverRunner() { - return status.isNoMoreDriverRunners(pipelineContext.getPipelineId()); + return noMoreDriverRunner.get(); } public void closeDriverFactoryIfFullyCreated() { - if (closed) { + if (driverFactory.isNoMoreDrivers()) { return; } - if (!isNoMoreDriverRunner() || status.getPendingCreation(pipelineContext.getPipelineId()) != 0) { - return; + if (isNoMoreDriverRunner() && pendingCreations.get() == 0) { + driverFactory.noMoreDrivers(); } - driverFactory.noMoreDrivers(); - closed = true; + } + + public boolean isNoMoreDrivers() + { + return driverFactory.isNoMoreDrivers(); } public OptionalInt getDriverInstances() @@ -780,94 +766,4 @@ public void stateChanged(BufferState newState) } } } - - @ThreadSafe - private static class Status - { - // no more driver runner: true if no more DriverSplitRunners will be created. - // pending creation: number of created DriverSplitRunners that haven't created underlying Driver. - // remaining driver: number of created Drivers that haven't yet finished. - - @GuardedBy("this") - private final int pipelineWithTaskLifeCycleCount; - - // For these 3 perX fields, they are populated lazily. If enumeration operations on the - // map can lead to side effects, no new entries can be created after such enumeration has - // happened. Otherwise, the order of entry creation and the enumeration operation will - // lead to different outcome. - @GuardedBy("this") - private final Map perPipeline; - @GuardedBy("this") - int pipelinesWithNoMoreDriverRunners; - - @GuardedBy("this") - private int overallRemainingDriver; - - public Status(Set pipelineIds) - { - int pipelineWithTaskLifeCycleCount = 0; - ImmutableMap.Builder perPipeline = ImmutableMap.builder(); - for (int pipelineId : pipelineIds) { - perPipeline.put(pipelineId, new PerPipelineStatus()); - pipelineWithTaskLifeCycleCount++; - } - this.pipelineWithTaskLifeCycleCount = pipelineWithTaskLifeCycleCount; - this.perPipeline = perPipeline.buildOrThrow(); - } - - public synchronized void setNoMoreDriverRunner(int pipelineId) - { - per(pipelineId).noMoreDriverRunners = true; - pipelinesWithNoMoreDriverRunners++; - } - - public synchronized void incrementPendingCreation(int pipelineId) - { - per(pipelineId).pendingCreation++; - } - - public synchronized void decrementPendingCreation(int pipelineId) - { - per(pipelineId).pendingCreation--; - } - - public synchronized void incrementRemainingDriver() - { - checkState(!(pipelinesWithNoMoreDriverRunners == pipelineWithTaskLifeCycleCount), "Cannot increment remainingDriver. NoMoreSplits is set."); - overallRemainingDriver++; - } - - public synchronized void decrementRemainingDriver() - { - checkState(overallRemainingDriver > 0, "Cannot decrement remainingDriver. Value is 0."); - overallRemainingDriver--; - } - - public synchronized int getPendingCreation(int pipelineId) - { - return per(pipelineId).pendingCreation; - } - - public synchronized int getRemainingDriver() - { - return overallRemainingDriver; - } - - public synchronized boolean isNoMoreDriverRunners(int pipelineId) - { - return per(pipelineId).noMoreDriverRunners; - } - - @GuardedBy("this") - private PerPipelineStatus per(int pipelineId) - { - return perPipeline.get(pipelineId); - } - } - - private static class PerPipelineStatus - { - int pendingCreation; - boolean noMoreDriverRunners; - } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java index 735f9319ba7e..e501684f0ebf 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecutionFactory.java @@ -28,7 +28,6 @@ import java.util.concurrent.Executor; import static com.google.common.base.Throwables.throwIfUnchecked; -import static io.trino.execution.SqlTaskExecution.createSqlTaskExecution; import static java.util.Objects.requireNonNull; public class SqlTaskExecutionFactory @@ -91,13 +90,13 @@ public SqlTaskExecution create( throw new RuntimeException(e); } } - return createSqlTaskExecution( + return new SqlTaskExecution( taskStateMachine, taskContext, outputBuffer, localExecutionPlan, taskExecutor, - taskNotificationExecutor, - splitMonitor); + splitMonitor, + taskNotificationExecutor); } } diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java index 812e5535f325..ba86b1998612 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java @@ -16,6 +16,8 @@ import com.google.common.collect.ImmutableList; import io.trino.sql.planner.plan.PlanNodeId; +import javax.annotation.concurrent.GuardedBy; + import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -34,7 +36,8 @@ public class DriverFactory private final Optional sourceId; private final OptionalInt driverInstances; - private boolean closed; + @GuardedBy("this") + private boolean noMoreDrivers; public DriverFactory(int pipelineId, boolean inputDriver, boolean outputDriver, List operatorFactories, OptionalInt driverInstances) { @@ -91,7 +94,7 @@ public List getOperatorFactories() public synchronized Driver createDriver(DriverContext driverContext) { - checkState(!closed, "DriverFactory is already closed"); + checkState(!noMoreDrivers, "noMoreDrivers is already set"); requireNonNull(driverContext, "driverContext is null"); ImmutableList.Builder operators = ImmutableList.builder(); for (OperatorFactory operatorFactory : operatorFactories) { @@ -103,12 +106,17 @@ public synchronized Driver createDriver(DriverContext driverContext) public synchronized void noMoreDrivers() { - if (closed) { + if (noMoreDrivers) { return; } - closed = true; + noMoreDrivers = true; for (OperatorFactory operatorFactory : operatorFactories) { operatorFactory.noMoreOperators(); } } + + public synchronized boolean isNoMoreDrivers() + { + return noMoreDrivers; + } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index 43bdf21962be..f4d769c4e011 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -135,14 +135,15 @@ public void testSimple() OptionalInt.empty())), ImmutableList.of(TABLE_SCAN_NODE_ID)); TaskContext taskContext = newTestingTaskContext(taskNotificationExecutor, driverYieldExecutor, taskStateMachine); - SqlTaskExecution sqlTaskExecution = SqlTaskExecution.createSqlTaskExecution( + SqlTaskExecution sqlTaskExecution = new SqlTaskExecution( taskStateMachine, taskContext, outputBuffer, localExecutionPlan, taskExecutor, - taskNotificationExecutor, - createTestSplitMonitor()); + createTestSplitMonitor(), + taskNotificationExecutor); + sqlTaskExecution.start(); // // test body