diff --git a/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java b/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java index 418a1c772c8d..18d64a6d37cb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/RemoteTask.java @@ -62,7 +62,10 @@ public interface RemoteTask PartitionedSplitsInfo getPartitionedSplitsInfo(); - void fail(Throwable cause); + /** + * Fails task from the coordinator perspective immediately, without waiting for acknowledgement from the remote task + */ + void failLocallyImmediately(Throwable cause); /** * Fails task remotely; only transitions to failed state when we receive confirmation that remote operation is completed diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index 482560f93ef9..9e646931bb34 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -68,9 +68,15 @@ import static io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTERS_VERSION; import static io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTER_DOMAINS; import static io.trino.execution.TaskState.ABORTED; +import static io.trino.execution.TaskState.ABORTING; +import static io.trino.execution.TaskState.CANCELED; +import static io.trino.execution.TaskState.CANCELING; import static io.trino.execution.TaskState.FAILED; +import static io.trino.execution.TaskState.FAILING; +import static io.trino.execution.TaskState.FINISHED; import static io.trino.execution.TaskState.RUNNING; import static io.trino.util.Failures.toFailures; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -147,7 +153,7 @@ private SqlTask( // Pass a memory context supplier instead of a memory context to the 
output buffer, // because we haven't created the task context that holds the memory context yet. () -> queryContext.getTaskContextByTaskId(taskId).localMemoryContext(), - () -> notifyStatusChanged(), + this::notifyStatusChanged, exchangeManagerRegistry); taskStateMachine = new TaskStateMachine(taskId, taskNotificationExecutor); } @@ -157,54 +163,66 @@ private void initialize(Consumer onDone, CounterStat failedTasks) { requireNonNull(onDone, "onDone is null"); requireNonNull(failedTasks, "failedTasks is null"); + + AtomicBoolean outputBufferCleanedUp = new AtomicBoolean(); taskStateMachine.addStateChangeListener(newState -> { - if (!newState.isDone()) { - if (newState != RUNNING) { - // notify that task state changed (apart from initial RUNNING state notification) - notifyStatusChanged(); + if (newState.isTerminatingOrDone()) { + if (newState.isTerminating()) { + // This section must be synchronized to lock out any threads that might be attempting to create a SqlTaskExecution + synchronized (taskHolderLock) { + // If a SqlTaskExecution exists, it decides when termination is complete. 
Otherwise, we can mark termination completed immediately + if (taskHolderReference.get().getTaskExecution() == null) { + taskStateMachine.terminationComplete(); + } + } } - return; - } - - // Update failed tasks counter - if (newState == FAILED) { - failedTasks.update(1); - } - - // store final task info - synchronized (taskHolderLock) { - TaskHolder taskHolder = taskHolderReference.get(); - if (taskHolder.isFinished()) { - // another concurrent worker already set the final state - return; + else if (newState.isDone()) { + // Update failed tasks counter + if (newState == FAILED) { + failedTasks.update(1); + } + // store final task info + boolean finished = false; + synchronized (taskHolderLock) { + TaskHolder taskHolder = taskHolderReference.get(); + if (!taskHolder.isFinished()) { + TaskHolder newHolder = new TaskHolder( + createTaskInfo(taskHolder), + taskHolder.getIoStats(), + taskHolder.getDynamicFilterDomains()); + checkState(taskHolderReference.compareAndSet(taskHolder, newHolder), "unsynchronized concurrent task holder update"); + finished = true; + } + } + // Successfully set the final task info, cleanup the output buffer and call the completion handler + if (finished) { + try { + onDone.accept(this); + } + catch (Exception e) { + log.warn(e, "Error running task cleanup callback %s", SqlTask.this.taskId); + } + } + } + // make sure buffers are cleaned up + if (outputBufferCleanedUp.compareAndSet(false, true)) { + switch (newState) { + // don't close buffers for a failed query + // closed buffers signal to upstream tasks that everything finished cleanly + case FAILED, FAILING, ABORTED, ABORTING -> + outputBuffer.abort(); + case FINISHED, CANCELED, CANCELING -> + outputBuffer.destroy(); + default -> + throw new IllegalStateException(format("Invalid state for output buffer destruction: %s", newState)); + } } - - TaskHolder newHolder = new TaskHolder( - createTaskInfo(taskHolder), - taskHolder.getIoStats(), - taskHolder.getDynamicFilterDomains()); - 
checkState(taskHolderReference.compareAndSet(taskHolder, newHolder), "unsynchronized concurrent task holder update"); - } - - // make sure buffers are cleaned up - if (newState == FAILED || newState == ABORTED) { - // don't close buffers for a failed query - // closed buffers signal to upstream tasks that everything finished cleanly - outputBuffer.abort(); - } - else { - outputBuffer.destroy(); } - try { - onDone.accept(this); - } - catch (Exception e) { - log.warn(e, "Error running task cleanup callback %s", SqlTask.this.taskId); + // notify that task state changed (apart from initial RUNNING state notification) + if (newState != RUNNING) { + notifyStatusChanged(); } - - // notify that task is finished - notifyStatusChanged(); }); } @@ -283,7 +301,7 @@ private TaskStatus createTaskStatus(TaskHolder taskHolder) TaskState state = taskStateMachine.getState(); List failures = ImmutableList.of(); - if (state == FAILED) { + if (state == FAILED || state == FAILING) { failures = toFailures(taskStateMachine.getFailureCauses()); } @@ -490,8 +508,8 @@ private SqlTaskExecution tryCreateSqlTaskExecution(Session session, PlanFragment return execution; } - // Don't create a new execution if the task is already done - if (taskStateMachine.getState().isDone()) { + // Don't create SqlTaskExecution once termination has started + if (taskStateMachine.getState().isTerminatingOrDone()) { return null; } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index 11992e742b0a..ea79a82a5322 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -64,8 +64,8 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableMap.toImmutableMap; -import static 
com.google.common.collect.Iterables.concat; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.trino.SystemSessionProperties.getInitialSplitsPerNode; import static io.trino.SystemSessionProperties.getMaxDriversPerTask; import static io.trino.SystemSessionProperties.getSplitConcurrencyAdjustmentInterval; @@ -90,10 +90,12 @@ public class SqlTaskExecution private final Executor notificationExecutor; private final SplitMonitor splitMonitor; + private final DriverAndTaskTerminationTracker driverAndTaskTerminationTracker; private final Map driverRunnerFactoriesWithSplitLifeCycle; private final List driverRunnerFactoriesWithTaskLifeCycle; private final Map driverRunnerFactoriesWithRemoteSource; + private final List allDriverRunnerFactories; @GuardedBy("this") private final Map maxAcknowledgedSplitByPlanNode = new HashMap<>(); @@ -127,14 +129,16 @@ public SqlTaskExecution( this.notificationExecutor = requireNonNull(notificationExecutor, "notificationExecutor is null"); this.splitMonitor = requireNonNull(splitMonitor, "splitMonitor is null"); + this.driverAndTaskTerminationTracker = new DriverAndTaskTerminationTracker(taskStateMachine); try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { + List driverFactories = localExecutionPlan.getDriverFactories(); // index driver factories Set partitionedSources = ImmutableSet.copyOf(localExecutionPlan.getPartitionedSourceOrder()); ImmutableMap.Builder driverRunnerFactoriesWithSplitLifeCycle = ImmutableMap.builder(); ImmutableList.Builder driverRunnerFactoriesWithTaskLifeCycle = ImmutableList.builder(); ImmutableMap.Builder driverRunnerFactoriesWithRemoteSource = ImmutableMap.builder(); - for (DriverFactory driverFactory : localExecutionPlan.getDriverFactories()) { + for (DriverFactory driverFactory : driverFactories) { Optional sourceId = driverFactory.getSourceId(); if (sourceId.isPresent() && 
partitionedSources.contains(sourceId.get())) { driverRunnerFactoriesWithSplitLifeCycle.put(sourceId.get(), new DriverSplitRunnerFactory(driverFactory, true)); @@ -148,6 +152,10 @@ public SqlTaskExecution( this.driverRunnerFactoriesWithSplitLifeCycle = driverRunnerFactoriesWithSplitLifeCycle.buildOrThrow(); this.driverRunnerFactoriesWithTaskLifeCycle = driverRunnerFactoriesWithTaskLifeCycle.build(); this.driverRunnerFactoriesWithRemoteSource = driverRunnerFactoriesWithRemoteSource.buildOrThrow(); + this.allDriverRunnerFactories = ImmutableList.builderWithExpectedSize(driverFactories.size()) + .addAll(this.driverRunnerFactoriesWithTaskLifeCycle) + .addAll(this.driverRunnerFactoriesWithSplitLifeCycle.values()) + .build(); this.pendingSplitsByPlanNode = this.driverRunnerFactoriesWithSplitLifeCycle.keySet().stream() .collect(toImmutableMap(identity(), ignore -> new PendingSplitsForPlanNode())); @@ -157,28 +165,32 @@ public SqlTaskExecution( "Fragment is partitioned, but not all partitioned drivers were found"); // don't register the task if it is already completed (most likely failed during planning above) - if (!taskStateMachine.getState().isDone()) { - taskHandle = createTaskHandle(taskStateMachine, taskContext, outputBuffer, localExecutionPlan, taskExecutor); + if (taskStateMachine.getState().isTerminatingOrDone()) { + taskHandle = null; + driverFactories.forEach(DriverFactory::noMoreDrivers); } else { - taskHandle = null; + taskHandle = createTaskHandle(taskStateMachine, taskContext, outputBuffer, driverFactories, taskExecutor, driverAndTaskTerminationTracker); } } } - public void start() + // this must be synchronized to prevent a concurrent call to checkTaskCompletion() from proceeding before all task lifecycle drivers are created + public synchronized void start() { try (SetThreadName ignored = new SetThreadName("Task-%s", getTaskId())) { - // Task handle was not created because the task is already done, nothing to do - if (taskHandle == null) { - return; + // 
Signal immediate termination complete if task termination has started + if (taskStateMachine.getState().isTerminating()) { + taskStateMachine.terminationComplete(); + } + else if (taskHandle != null) { + // The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object. + // The callback is accessed from another thread, so this code cannot be placed in the constructor. This must also happen before outputBuffer + // callbacks are registered to prevent a task completion check before task lifecycle splits are created + scheduleDriversForTaskLifeCycle(); + // Output buffer state change listener callback must not run in the constructor to avoid leaking a reference to "this" across to another thread + outputBuffer.addStateChangeListener(new CheckTaskCompletionOnBufferFinish(SqlTaskExecution.this)); + } - // The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object. - // The call back is accessed from another thread, so this code cannot be placed in the constructor. 
This must also happen before outputBuffer - // callbacks are registered to prevent a task completion check before task lifecycle splits are created - scheduleDriversForTaskLifeCycle(); - // Output buffer state change listener callback must not run in the constructor to avoid leaking a reference to "this" across to another thread - outputBuffer.addStateChangeListener(new CheckTaskCompletionOnBufferFinish(SqlTaskExecution.this)); } } @@ -187,8 +199,9 @@ private static TaskHandle createTaskHandle( TaskStateMachine taskStateMachine, TaskContext taskContext, OutputBuffer outputBuffer, - LocalExecutionPlan localExecutionPlan, - TaskExecutor taskExecutor) + List driverFactories, + TaskExecutor taskExecutor, + DriverAndTaskTerminationTracker driverAndTaskTerminationTracker) { TaskHandle taskHandle = taskExecutor.addTask( taskStateMachine.getTaskId(), @@ -197,10 +210,16 @@ private static TaskHandle createTaskHandle( getSplitConcurrencyAdjustmentInterval(taskContext.getSession()), getMaxDriversPerTask(taskContext.getSession())); taskStateMachine.addStateChangeListener(state -> { - if (state.isDone()) { - taskExecutor.removeTask(taskHandle); - for (DriverFactory factory : localExecutionPlan.getDriverFactories()) { - factory.noMoreDrivers(); + if (state.isTerminatingOrDone()) { + if (!taskHandle.isDestroyed()) { + taskExecutor.removeTask(taskHandle); + for (DriverFactory factory : driverFactories) { + factory.noMoreDrivers(); + } + } + // Need to re-check the live driver count since termination may have occurred without any running + if (state.isTerminating()) { + driverAndTaskTerminationTracker.checkTaskTermination(); } } }); @@ -222,6 +241,11 @@ public void addSplitAssignments(List splitAssignments) requireNonNull(splitAssignments, "splitAssignments is null"); checkState(!Thread.holdsLock(this), "Cannot add split assignments while holding a lock on the %s", getClass().getSimpleName()); + // Avoid accepting new splits once the task is terminating or done + if 
(taskStateMachine.getState().isTerminatingOrDone()) { + return; + } + try (SetThreadName ignored = new SetThreadName("Task-%s", taskId)) { // update our record of split assignments and schedule drivers for new partitioned splits Set updatedUnpartitionedSources = updateSplitAssignments(splitAssignments); @@ -436,21 +460,29 @@ public synchronized Set getNoMoreSplits() private synchronized void checkTaskCompletion() { - if (taskStateMachine.getState().isDone()) { + TaskState taskState = taskStateMachine.getState(); + if (taskState.isDone()) { return; } - // are there more drivers expected? - for (DriverSplitRunnerFactory driverSplitRunnerFactory : concat(driverRunnerFactoriesWithTaskLifeCycle, driverRunnerFactoriesWithSplitLifeCycle.values())) { - if (!driverSplitRunnerFactory.isNoMoreDrivers()) { - return; - } + // have all drivers finished terminating? + if (taskState.isTerminating()) { + driverAndTaskTerminationTracker.checkTaskTermination(); + return; } + // do we still have running tasks? if (remainingSplitRunners.get() != 0) { return; } + // are there more drivers expected? 
+ for (DriverSplitRunnerFactory driverSplitRunnerFactory : allDriverRunnerFactories) { + if (!driverSplitRunnerFactory.isNoMoreDrivers()) { + return; + } + } + // no more output will be created outputBuffer.setNoMorePages(); @@ -484,6 +516,7 @@ public String toString() return toStringHelper(this) .add("taskId", taskId) .add("remainingSplitRunners", remainingSplitRunners.get()) + .add("liveCreatedDrivers", driverAndTaskTerminationTracker.getLiveCreatedDrivers()) .toString(); } @@ -587,10 +620,37 @@ public DriverSplitRunner createDriverRunner(@Nullable ScheduledSplit partitioned return new DriverSplitRunner(this, driverContext, partitionedSplit); } + /** + * @return the created {@link Driver}, or null if the driver factory is already closed because the task is terminating + */ + @Nullable public Driver createDriver(DriverContext driverContext, @Nullable ScheduledSplit partitionedSplit) { - Driver driver = driverFactory.createDriver(driverContext); + // Attempt to increment the driver count eagerly, but skip driver creation if the task is already terminating or done + if (!driverAndTaskTerminationTracker.tryCreateNewDriver()) { + return null; + } + Driver driver; + try { + driver = driverFactory.createDriver(driverContext); + } + catch (Throwable t) { + try { + // driverFactory is already closed, ignore the exception and return null, but don't swallow fatal errors + if (t instanceof Exception && driverFactory.isNoMoreDrivers()) { + return null; + } + // this exception is unexpected if driverFactory has not been closed, so rethrow it + throw t; + } + finally { + // decrement the live driver count since driver creation failed + driverAndTaskTerminationTracker.driverDestroyed(); + } + } + // register driver destroyed listener to detect when termination completes + driver.getDestroyedFuture().addListener(driverAndTaskTerminationTracker::driverDestroyed, directExecutor()); try { if (partitionedSplit != null) { // TableScanOperator requires partitioned split to be added 
before the first call to process @@ -765,6 +825,11 @@ public ListenableFuture processFor(Duration duration) if (this.driver == null) { this.driver = driverSplitRunnerFactory.createDriver(driverContext, partitionedSplit); + // Termination has begun, mark the runner as closed and return + if (this.driver == null) { + closed = true; + return immediateVoidFuture(); + } } driver = this.driver; @@ -816,4 +881,52 @@ public void stateChanged(BufferState newState) } } } + + private static final class DriverAndTaskTerminationTracker + { + private final TaskStateMachine taskStateMachine; + private final AtomicLong liveCreatedDrivers = new AtomicLong(); + + private DriverAndTaskTerminationTracker(TaskStateMachine taskStateMachine) + { + this.taskStateMachine = requireNonNull(taskStateMachine, "taskStateMachine is null"); + } + + public boolean tryCreateNewDriver() + { + // Eagerly increment the counter before checking the state machine + liveCreatedDrivers.incrementAndGet(); + // If termination has started already, we need to decrement the counter and check for termination complete + if (taskStateMachine.getState().isTerminatingOrDone()) { + driverDestroyed(); + return false; + } + return true; + } + + public void driverDestroyed() + { + if (liveCreatedDrivers.decrementAndGet() == 0) { + checkTaskTermination(); + } + } + + public long getLiveCreatedDrivers() + { + return liveCreatedDrivers.get(); + } + + public void checkTaskTermination() + { + if (taskStateMachine.getState().isTerminating()) { + long liveCreatedDrivers = this.liveCreatedDrivers.get(); + // Allow unexpectedly negative values to complete task termination to avoid having stuck tasks, but + // throw an exception afterwards to avoid masking bugs + if (liveCreatedDrivers <= 0) { + taskStateMachine.terminationComplete(); + checkState(liveCreatedDrivers == 0, "liveCreatedDrivers is negative: %s", liveCreatedDrivers); + } + } + } + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java 
b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index fc70fa468dc1..d0bb6beb8e5d 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -823,10 +823,10 @@ public String getTaskInstanceId() return task.getTaskInstanceId(); } - public boolean isTaskFailed() + public boolean isTaskFailedOrFailing() { return switch (task.getTaskState()) { - case ABORTED, FAILED -> true; + case ABORTED, ABORTING, FAILED, FAILING -> true; default -> false; }; } diff --git a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java index 55d085a2a7bd..6b5482330fe7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/StageStateMachine.java @@ -274,7 +274,9 @@ public BasicStageStats getBasicStageStats(Supplier> taskInfos TaskState taskState = taskInfo.getTaskStatus().getState(); TaskStats taskStats = taskInfo.getStats(); - if (taskState == TaskState.FAILED) { + boolean taskFailedOrFailing = taskState == TaskState.FAILED || taskState == TaskState.FAILING; + + if (taskFailedOrFailing) { failedTasks++; } @@ -284,7 +286,7 @@ public BasicStageStats getBasicStageStats(Supplier> taskInfos completedDrivers += taskStats.getCompletedDrivers(); cumulativeUserMemory += taskStats.getCumulativeUserMemory(); - if (taskState == TaskState.FAILED) { + if (taskFailedOrFailing) { failedCumulativeUserMemory += taskStats.getCumulativeUserMemory(); } @@ -295,7 +297,7 @@ public BasicStageStats getBasicStageStats(Supplier> taskInfos totalScheduledTime += taskStats.getTotalScheduledTime().roundTo(NANOSECONDS); totalCpuTime += taskStats.getTotalCpuTime().roundTo(NANOSECONDS); - if (taskState == TaskState.FAILED) { + if (taskFailedOrFailing) { failedScheduledTime += taskStats.getTotalScheduledTime().roundTo(NANOSECONDS); 
failedCpuTime += taskStats.getTotalCpuTime().roundTo(NANOSECONDS); } @@ -454,8 +456,9 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) else { runningTasks++; } + boolean taskFailedOrFailing = taskState == TaskState.FAILED || taskState == TaskState.FAILING; - if (taskState == TaskState.FAILED) { + if (taskFailedOrFailing) { failedTasks++; } @@ -468,14 +471,14 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) completedDrivers += taskStats.getCompletedDrivers(); cumulativeUserMemory += taskStats.getCumulativeUserMemory(); - if (taskState == TaskState.FAILED) { + if (taskFailedOrFailing) { failedCumulativeUserMemory += taskStats.getCumulativeUserMemory(); } totalScheduledTime += taskStats.getTotalScheduledTime().roundTo(NANOSECONDS); totalCpuTime += taskStats.getTotalCpuTime().roundTo(NANOSECONDS); totalBlockedTime += taskStats.getTotalBlockedTime().roundTo(NANOSECONDS); - if (taskState == TaskState.FAILED) { + if (taskFailedOrFailing) { failedScheduledTime += taskStats.getTotalScheduledTime().roundTo(NANOSECONDS); failedCpuTime += taskStats.getTotalCpuTime().roundTo(NANOSECONDS); } @@ -508,7 +511,7 @@ public StageInfo getStageInfo(Supplier> taskInfosSupplier) physicalWrittenDataSize += taskStats.getPhysicalWrittenDataSize().toBytes(); - if (taskState == TaskState.FAILED) { + if (taskFailedOrFailing) { failedPhysicalInputDataSize += taskStats.getPhysicalInputDataSize().toBytes(); failedPhysicalInputPositions += taskStats.getPhysicalInputPositions(); failedPhysicalInputReadTime += taskStats.getPhysicalInputReadTime().roundTo(NANOSECONDS); diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java index dc27dea56ca3..571078f272ad 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManagerConfig.java @@ -71,6 +71,7 @@ public class TaskManagerConfig private 
Duration statusRefreshMaxWait = new Duration(1, TimeUnit.SECONDS); private Duration infoUpdateInterval = new Duration(3, TimeUnit.SECONDS); + private Duration taskTerminationTimeout = new Duration(1, TimeUnit.MINUTES); private boolean interruptStuckSplitTasksEnabled = true; private Duration interruptStuckSplitTasksWarningThreshold = new Duration(10, TimeUnit.MINUTES); @@ -138,6 +139,21 @@ public TaskManagerConfig setInfoUpdateInterval(Duration infoUpdateInterval) return this; } + @MinDuration("1ms") + @NotNull + public Duration getTaskTerminationTimeout() + { + return taskTerminationTimeout; + } + + @Config("task.termination-timeout") + @ConfigDescription("Maximum duration to wait for a task to complete termination before failing the task on the coordinator") + public TaskManagerConfig setTaskTerminationTimeout(Duration taskTerminationTimeout) + { + this.taskTerminationTimeout = taskTerminationTimeout; + return this; + } + public boolean isPerOperatorCpuTimerEnabled() { return perOperatorCpuTimerEnabled; diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskState.java b/core/trino-main/src/main/java/io/trino/execution/TaskState.java index 5191205b865c..c630e87adac7 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskState.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskState.java @@ -25,42 +25,56 @@ public enum TaskState * be in the planned state until, the dependencies of the task * have begun producing output. */ - PLANNED(false), + PLANNED(false, false), /** * Task is running. */ - RUNNING(false), + RUNNING(false, false), /** * Task has finished executing and output is left to be consumed. * In this state, there will be no new drivers, the existing drivers have finished * and the output buffer of the task is at-least in a 'no-more-pages' state. */ - FLUSHING(false), + FLUSHING(false, false), /** * Task has finished executing and all output has been consumed. 
*/ - FINISHED(true), + FINISHED(true, false), + /** + * Task was canceled, but not all drivers have finished exiting + */ + CANCELING(false, true), /** * Task was canceled by a user. */ - CANCELED(true), + CANCELED(true, false), + /** + * Task was told to abort, but not all drivers have finished exiting + */ + ABORTING(false, true), /** * Task was aborted due to a failure in the query. The failure * was not in this task. */ - ABORTED(true), + ABORTED(true, false), + /** + * Task has been marked as failed, but not all drivers have finished exiting + */ + FAILING(false, true), /** * Task execution failed. */ - FAILED(true); + FAILED(true, false); public static final Set TERMINAL_TASK_STATES = Stream.of(TaskState.values()).filter(TaskState::isDone).collect(toImmutableSet()); private final boolean doneState; + private final boolean terminating; - TaskState(boolean doneState) + TaskState(boolean doneState, boolean terminating) { this.doneState = doneState; + this.terminating = terminating; } /** @@ -70,4 +84,14 @@ public boolean isDone() { return doneState; } + + public boolean isTerminating() + { + return terminating; + } + + public boolean isTerminatingOrDone() + { + return terminating | doneState; + } } diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java index 1e9b6e890dca..c41e6e054bfe 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskStateMachine.java @@ -32,7 +32,15 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; +import static io.trino.execution.TaskState.ABORTED; +import static io.trino.execution.TaskState.ABORTING; +import static 
io.trino.execution.TaskState.CANCELED; +import static io.trino.execution.TaskState.CANCELING; +import static io.trino.execution.TaskState.FAILED; +import static io.trino.execution.TaskState.FAILING; +import static io.trino.execution.TaskState.FINISHED; import static io.trino.execution.TaskState.FLUSHING; import static io.trino.execution.TaskState.RUNNING; import static io.trino.execution.TaskState.TERMINAL_TASK_STATES; @@ -59,7 +67,7 @@ public TaskStateMachine(TaskId taskId, Executor executor) { this.taskId = requireNonNull(taskId, "taskId is null"); this.executor = requireNonNull(executor, "executor is null"); - taskState = new StateMachine<>("task " + taskId, executor, TaskState.RUNNING, TERMINAL_TASK_STATES); + taskState = new StateMachine<>("task " + taskId, executor, RUNNING, TERMINAL_TASK_STATES); taskState.addStateChangeListener(newState -> log.debug("Task %s is %s", taskId, newState)); } @@ -103,31 +111,47 @@ public void transitionToFlushing() public void finished() { - transitionToDoneState(TaskState.FINISHED); + taskState.setIf(FINISHED, currentState -> !currentState.isTerminatingOrDone()); } public void cancel() { - transitionToDoneState(TaskState.CANCELED); + startTermination(CANCELING); } public void abort() { - transitionToDoneState(TaskState.ABORTED); + startTermination(ABORTING); } public void failed(Throwable cause) { failureCauses.add(cause); - transitionToDoneState(TaskState.FAILED); + startTermination(FAILING); } - private void transitionToDoneState(TaskState doneState) + public void terminationComplete() { - requireNonNull(doneState, "doneState is null"); - checkArgument(doneState.isDone(), "doneState %s is not a done state", doneState); + TaskState currentState = taskState.get(); + if (currentState.isDone()) { + return; // ignore redundant completion events + } + checkState(currentState.isTerminating(), "current state %s is not a terminating state", currentState); + TaskState newState = switch (currentState) { + case CANCELING -> CANCELED; + 
case ABORTING -> ABORTED; + case FAILING -> FAILED; + default -> throw new IllegalStateException("Unhandled terminating state: " + currentState); + }; + taskState.compareAndSet(currentState, newState); + } + + private void startTermination(TaskState terminatingState) + { + requireNonNull(terminatingState, "terminatingState is null"); + checkArgument(terminatingState.isTerminating(), "terminatingState %s is not a terminating state", terminatingState); - taskState.setIf(doneState, currentState -> !currentState.isDone()); + taskState.setIf(terminatingState, currentState -> !currentState.isTerminatingOrDone()); } /** diff --git a/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java b/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java index 6b02153d1f59..c501e223ccfc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java +++ b/core/trino-main/src/main/java/io/trino/execution/executor/TaskExecutor.java @@ -283,7 +283,10 @@ public synchronized TaskHandle addTask( public void removeTask(TaskHandle taskHandle) { try (SetThreadName ignored = new SetThreadName("Task-%s", taskHandle.getTaskId())) { - doRemoveTask(taskHandle); + // Skip additional scheduling if the task was already destroyed + if (!doRemoveTask(taskHandle)) { + return; + } } // replace blocked splits that were terminated @@ -293,13 +296,23 @@ public void removeTask(TaskHandle taskHandle) } } - private void doRemoveTask(TaskHandle taskHandle) + /** + * Returns true if the task handle was destroyed and removed splits as a result that may need to be replaced. Otherwise, + * if the {@link TaskHandle} was already destroyed or no splits were removed then this method returns false and no additional + * splits need to be scheduled. 
+ */ + private boolean doRemoveTask(TaskHandle taskHandle) { List splits; synchronized (this) { tasks.remove(taskHandle); - splits = taskHandle.destroy(); + // Task is already destroyed + if (taskHandle.isDestroyed()) { + return false; + } + + splits = taskHandle.destroy(); // stop tracking splits (especially blocked splits which may never unblock) allSplits.removeAll(splits); intermediateSplits.removeAll(splits); @@ -318,6 +331,7 @@ private void doRemoveTask(TaskHandle taskHandle) completedTasksPerLevel.incrementAndGet(computeLevel(threadUsageNanos)); log.debug("Task finished or failed %s", taskHandle.getTaskId()); + return !splits.isEmpty(); } public List> enqueueSplits(TaskHandle taskHandle, boolean intermediate, List taskSplits) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java index 7ab378272a49..2ba98b003ccc 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java @@ -78,6 +78,7 @@ import static io.trino.operator.ExchangeOperator.REMOTE_CATALOG_HANDLE; import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static io.trino.spi.StandardErrorCode.REMOTE_HOST_GONE; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; /** @@ -271,7 +272,7 @@ public synchronized void fail(Throwable failureCause) public synchronized void failTask(TaskId taskId, Throwable failureCause) { RemoteTask task = requireNonNull(tasks.get(taskId.getPartitionId()), () -> "task not found: " + taskId); - task.fail(failureCause); + task.failLocallyImmediately(failureCause); fail(failureCause); } @@ -343,21 +344,22 @@ private void updateTaskStatus(TaskStatus taskStatus) TaskState taskState = taskStatus.getState(); switch (taskState) { + case FAILING: case FAILED: 
RuntimeException failure = taskStatus.getFailures().stream() .findFirst() .map(this::rewriteTransportFailure) .map(ExecutionFailureInfo::toException) + // task is failed or failing, so we need to create a synthetic exception to fail the stage now .orElseGet(() -> new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason")); fail(failure); break; + case CANCELING: case CANCELED: - // A task should only be in the canceled state if the STAGE is cancelled - fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the CANCELED state but stage is " + stateMachine.getState())); - break; + case ABORTING: case ABORTED: - // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED) - fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stateMachine.getState())); + // A task should only be in the aborting, aborted, canceling, or canceled state if the STAGE is done (ABORTED or FAILED) + fail(new TrinoException(GENERIC_INTERNAL_ERROR, format("A task is in the %s state but stage is %s", taskState, stateMachine.getState()))); break; case FLUSHING: newFlushingOrFinishedTaskObserved = addFlushingTask(taskStatus.getTaskId()); diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java index 1c24a1ddd5d5..6684ad270b43 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/ScaledWriterScheduler.java @@ -125,7 +125,7 @@ private boolean isWeightedBufferFull() double totalOutputSize = 0.0; double overutilizedOutputSize = 0.0; for (TaskStatus task : sourceTasksProvider.get()) { - if (!task.getState().isDone()) { + if (!task.getState().isTerminatingOrDone()) { long outputDataSize = task.getOutputDataSize().toBytes(); totalOutputSize += outputDataSize; if 
(task.getOutputBufferStatus().isOverutilized()) { @@ -140,7 +140,7 @@ private boolean isWeightedBufferFull() private boolean isAverageBufferFull() { return sourceTasksProvider.get().stream() - .filter(task -> !task.getState().isDone()) + .filter(task -> !task.getState().isTerminatingOrDone()) .map(TaskStatus::getOutputBufferStatus) .map(OutputBufferStatus::isOverutilized) .mapToDouble(full -> full ? 1.0 : 0.0) diff --git a/core/trino-main/src/main/java/io/trino/operator/Driver.java b/core/trino-main/src/main/java/io/trino/operator/Driver.java index d9f99007671c..7944881d70fb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/Driver.java +++ b/core/trino-main/src/main/java/io/trino/operator/Driver.java @@ -84,8 +84,9 @@ public class Driver private final Map> revokingOperators = new HashMap<>(); private final AtomicReference state = new AtomicReference<>(State.ALIVE); + private final SettableFuture destroyedFuture = SettableFuture.create(); - private final DriverLock exclusiveLock = new DriverLock(); + private final DriverLock exclusiveLock = new DriverLock(state, destroyedFuture); @GuardedBy("exclusiveLock") private SplitAssignment currentSplitAssignment; @@ -94,7 +95,7 @@ public class Driver private enum State { - ALIVE, NEED_DESTRUCTION, DESTROYED + ALIVE, NEED_DESTRUCTION, DESTROYING, DESTROYED } public static Driver createDriver(DriverContext driverContext, List operators) @@ -158,6 +159,11 @@ public DriverContext getDriverContext() return driverContext; } + public ListenableFuture getDestroyedFuture() + { + return destroyedFuture; + } + public Optional getSourceId() { return sourceOperator.map(SourceOperator::getSourceId); @@ -184,20 +190,20 @@ public boolean isFinished() checkLockNotHeld("Cannot check finished status while holding the driver lock"); // if we can get the lock, attempt a clean shutdown; otherwise someone else will shutdown - Optional result = tryWithLockUninterruptibly(this::isFinishedInternal); - return result.orElseGet(() -> 
state.get() != State.ALIVE || driverContext.isDone()); + Optional result = tryWithLockUninterruptibly(this::isTerminatingOrDoneInternal); + return result.orElseGet(() -> state.get() != State.ALIVE || driverContext.isTerminatingOrDone()); } @GuardedBy("exclusiveLock") - private boolean isFinishedInternal() + private boolean isTerminatingOrDoneInternal() { - checkLockHeld("Lock must be held to call isFinishedInternal"); + checkLockHeld("Lock must be held to call isTerminatingOrDoneInternal"); - boolean finished = state.get() != State.ALIVE || driverContext.isDone() || activeOperators.isEmpty() || activeOperators.get(activeOperators.size() - 1).isFinished(); - if (finished) { + boolean terminatingOrDone = state.get() != State.ALIVE || activeOperators.isEmpty() || activeOperators.get(activeOperators.size() - 1).isFinished() || driverContext.isTerminatingOrDone(); + if (terminatingOrDone) { state.compareAndSet(State.ALIVE, State.NEED_DESTRUCTION); } - return finished; + return terminatingOrDone; } public void updateSplitAssignment(SplitAssignment splitAssignment) @@ -295,7 +301,7 @@ public ListenableFuture process(Duration maxRuntime, int maxIterations) try { long start = System.nanoTime(); int iterations = 0; - while (!isFinishedInternal()) { + while (!isTerminatingOrDoneInternal()) { ListenableFuture future = processInternal(operationTimer); iterations++; if (!future.isDone()) { @@ -381,7 +387,7 @@ private ListenableFuture processInternal(OperationTimer operationTimer) } boolean movedPage = false; - for (int i = 0; i < activeOperators.size() - 1 && !driverContext.isDone(); i++) { + for (int i = 0; i < activeOperators.size() - 1 && !driverContext.isTerminatingOrDone(); i++) { Operator current = activeOperators.get(i); Operator next = activeOperators.get(i + 1); @@ -473,7 +479,7 @@ private ListenableFuture processInternal(OperationTimer operationTimer) @GuardedBy("exclusiveLock") private void handleMemoryRevoke() { - for (int i = 0; i < activeOperators.size() && 
!driverContext.isDone(); i++) { + for (int i = 0; i < activeOperators.size() && !driverContext.isTerminatingOrDone(); i++) { Operator operator = activeOperators.get(i); if (revokingOperators.containsKey(operator)) { @@ -504,7 +510,7 @@ private void destroyIfNecessary() { checkLockHeld("Lock must be held to call destroyIfNecessary"); - if (!state.compareAndSet(State.NEED_DESTRUCTION, State.DESTROYED)) { + if (!state.compareAndSet(State.NEED_DESTRUCTION, State.DESTROYING)) { return; } @@ -528,6 +534,10 @@ private void destroyIfNecessary() "Error destroying driver for task %s", driverContext.getTaskId()); } + finally { + // Mark destruction as having completed after driverContext.finished() is complete + state.set(State.DESTROYED); + } if (inFlightException != null) { // this will always be an Error or Runtime @@ -752,6 +762,15 @@ private static class DriverLock { private final ReentrantLock lock = new ReentrantLock(); + private final AtomicReference state; + private final SettableFuture destroyedFuture; + + private DriverLock(AtomicReference state, SettableFuture destroyedFuture) + { + this.state = requireNonNull(state, "state is null"); + this.destroyedFuture = requireNonNull(destroyedFuture, "destroyedFuture is null"); + } + @GuardedBy("this") private Thread currentOwner; @GuardedBy("this") @@ -798,12 +817,19 @@ private synchronized void setOwner(boolean interruptionAllowed) // state to prevent further processing in the Driver. 
} - public synchronized void unlock() + public void unlock() { checkState(lock.isHeldByCurrentThread(), "Current thread does not hold lock"); - currentOwner = null; - currentOwnerInterruptionAllowed = false; + synchronized (this) { + currentOwner = null; + currentOwnerInterruptionAllowed = false; + } lock.unlock(); + // Set the destroyed signal after releasing the lock since callbacks are fired synchronously and + // otherwise could cause a deadlock + if (state.get() == State.DESTROYED) { + destroyedFuture.set(null); + } } public synchronized List getInterrupterStack() diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java index cd8ff2bcc31b..9c3b61aa6bfb 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverContext.java @@ -181,13 +181,14 @@ public void finished() public void failed(Throwable cause) { - pipelineContext.failed(cause); - finished.set(true); + if (finished.compareAndSet(false, true)) { + pipelineContext.driverFailed(cause); + } } - public boolean isDone() + public boolean isTerminatingOrDone() { - return finished.get() || pipelineContext.isDone(); + return finished.get() || pipelineContext.isTerminatingOrDone(); } public ListenableFuture reserveSpill(long bytes) diff --git a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java index 72adaf481c43..e637c8ef8727 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java +++ b/core/trino-main/src/main/java/io/trino/operator/DriverFactory.java @@ -37,8 +37,9 @@ public class DriverFactory private final Optional sourceId; private final OptionalInt driverInstances; + // must synchronize between createDriver() and noMoreDrivers(), but isNoMoreDrivers() is safe without synchronizing @GuardedBy("this") - private boolean 
noMoreDrivers; + private volatile boolean noMoreDrivers; public DriverFactory(int pipelineId, boolean inputDriver, boolean outputDriver, List operatorFactories, OptionalInt driverInstances) { @@ -93,16 +94,20 @@ public List getOperatorFactories() return operatorFactories; } - public synchronized Driver createDriver(DriverContext driverContext) + public Driver createDriver(DriverContext driverContext) { - checkState(!noMoreDrivers, "noMoreDrivers is already set"); requireNonNull(driverContext, "driverContext is null"); - List operators = new ArrayList<>(); + List operators = new ArrayList<>(operatorFactories.size()); try { - for (OperatorFactory operatorFactory : operatorFactories) { - Operator operator = operatorFactory.createOperator(driverContext); - operators.add(operator); + synchronized (this) { + // must check noMoreDrivers after acquiring the lock + checkState(!noMoreDrivers, "noMoreDrivers is already set"); + for (OperatorFactory operatorFactory : operatorFactories) { + Operator operator = operatorFactory.createOperator(driverContext); + operators.add(operator); + } } + // Driver creation can continue without holding the lock return Driver.createDriver(driverContext, operators); } catch (Throwable failure) { @@ -141,7 +146,9 @@ public synchronized void noMoreDrivers() } } - public synchronized boolean isNoMoreDrivers() + // no need to synchronize when just checking the boolean flag + @SuppressWarnings("GuardedBy") + public boolean isNoMoreDrivers() { return noMoreDrivers; } diff --git a/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java b/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java index 87025b246bda..364c030d3257 100644 --- a/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/PipelineContext.java @@ -242,14 +242,14 @@ public void start() taskContext.start(); } - public void failed(Throwable cause) + public void driverFailed(Throwable cause) { 
taskContext.failed(cause); } - public boolean isDone() + public boolean isTerminatingOrDone() { - return taskContext.isDone(); + return taskContext.isTerminatingOrDone(); } public synchronized ListenableFuture reserveSpill(long bytes) diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java index b8380ad23afa..dc693037b66a 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskContext.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskContext.java @@ -85,6 +85,7 @@ public class TaskContext private final AtomicReference executionStartTime = new AtomicReference<>(); private final AtomicReference lastExecutionStartTime = new AtomicReference<>(); + private final AtomicReference terminatingStartTime = new AtomicReference<>(); private final AtomicReference executionEndTime = new AtomicReference<>(); private final List pipelineContexts = new CopyOnWriteArrayList<>(); @@ -213,14 +214,19 @@ public void start() private void updateStatsIfDone(TaskState newState) { - if (newState.isDone()) { + if (newState.isTerminating()) { + terminatingStartTime.compareAndSet(null, DateTime.now()); + } + else if (newState.isDone()) { DateTime now = DateTime.now(); long majorGcCount = gcMonitor.getMajorGcCount(); long majorGcTime = gcMonitor.getMajorGcTime().roundTo(NANOSECONDS); + long nanoTimeNow = System.nanoTime(); + // before setting the end times, make sure a start has been recorded executionStartTime.compareAndSet(null, now); - startNanos.compareAndSet(0, System.nanoTime()); + startNanos.compareAndSet(0, nanoTimeNow); startFullGcCount.compareAndSet(-1, majorGcCount); startFullGcTimeNanos.compareAndSet(-1, majorGcTime); @@ -230,7 +236,7 @@ private void updateStatsIfDone(TaskState newState) // use compare and set from initial value to avoid overwriting if there // were a duplicate notification, which shouldn't happen executionEndTime.compareAndSet(null, now); - 
endNanos.compareAndSet(0, System.nanoTime()); + endNanos.compareAndSet(0, nanoTimeNow); endFullGcCount.compareAndSet(-1, majorGcCount); endFullGcTimeNanos.compareAndSet(-1, majorGcTime); } @@ -241,9 +247,9 @@ public void failed(Throwable cause) taskStateMachine.failed(cause); } - public boolean isDone() + public boolean isTerminatingOrDone() { - return taskStateMachine.getState().isDone(); + return taskStateMachine.getState().isTerminatingOrDone(); } public TaskState getState() @@ -558,6 +564,7 @@ public TaskStats getTaskStats() taskStateMachine.getCreatedTime(), executionStartTime.get(), lastExecutionStartTime.get(), + terminatingStartTime.get(), lastExecutionEndTime == 0 ? null : new DateTime(lastExecutionEndTime), executionEndTime.get(), elapsedTime.convertToMostSuccinctTimeUnit(), diff --git a/core/trino-main/src/main/java/io/trino/operator/TaskStats.java b/core/trino-main/src/main/java/io/trino/operator/TaskStats.java index be7bcb0067f4..492634e2daf5 100644 --- a/core/trino-main/src/main/java/io/trino/operator/TaskStats.java +++ b/core/trino-main/src/main/java/io/trino/operator/TaskStats.java @@ -36,6 +36,7 @@ public class TaskStats private final DateTime createTime; private final DateTime firstStartTime; private final DateTime lastStartTime; + private final DateTime terminatingStartTime; private final DateTime lastEndTime; private final DateTime endTime; @@ -97,6 +98,7 @@ public TaskStats(DateTime createTime, DateTime endTime) null, null, null, + null, endTime, new Duration(0, MILLISECONDS), new Duration(0, MILLISECONDS), @@ -143,6 +145,7 @@ public TaskStats( @JsonProperty("createTime") DateTime createTime, @JsonProperty("firstStartTime") DateTime firstStartTime, @JsonProperty("lastStartTime") DateTime lastStartTime, + @JsonProperty("terminatingStartTime") DateTime terminatingStartTime, @JsonProperty("lastEndTime") DateTime lastEndTime, @JsonProperty("endTime") DateTime endTime, @JsonProperty("elapsedTime") Duration elapsedTime, @@ -200,6 +203,7 @@ public 
TaskStats( this.createTime = requireNonNull(createTime, "createTime is null"); this.firstStartTime = firstStartTime; this.lastStartTime = lastStartTime; + this.terminatingStartTime = terminatingStartTime; this.lastEndTime = lastEndTime; this.endTime = endTime; this.elapsedTime = requireNonNull(elapsedTime, "elapsedTime is null"); @@ -293,6 +297,13 @@ public DateTime getLastStartTime() return lastStartTime; } + @Nullable + @JsonProperty + public DateTime getTerminatingStartTime() + { + return terminatingStartTime; + } + @Nullable @JsonProperty public DateTime getLastEndTime() @@ -541,6 +552,7 @@ public TaskStats summarize() createTime, firstStartTime, lastStartTime, + terminatingStartTime, lastEndTime, endTime, elapsedTime, @@ -589,6 +601,7 @@ public TaskStats summarizeFinal() createTime, firstStartTime, lastStartTime, + terminatingStartTime, lastEndTime, endTime, elapsedTime, diff --git a/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java b/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java index 5da4ad0ab3c6..974084c7aaab 100644 --- a/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java +++ b/core/trino-main/src/main/java/io/trino/server/HttpRemoteTaskFactory.java @@ -71,6 +71,7 @@ public class HttpRemoteTaskFactory private final Duration maxErrorDuration; private final Duration taskStatusRefreshMaxWait; private final Duration taskInfoUpdateInterval; + private final Duration taskTerminationTimeout; private final ExecutorService coreExecutor; private final Executor executor; private final ThreadPoolExecutorMBean executorMBean; @@ -103,6 +104,7 @@ public HttpRemoteTaskFactory( this.maxErrorDuration = config.getRemoteTaskMaxErrorDuration(); this.taskStatusRefreshMaxWait = taskConfig.getStatusRefreshMaxWait(); this.taskInfoUpdateInterval = taskConfig.getInfoUpdateInterval(); + this.taskTerminationTimeout = taskConfig.getTaskTerminationTimeout(); this.coreExecutor = 
newCachedThreadPool(daemonThreadsNamed("remote-task-callback-%s")); this.executor = new BoundedExecutor(coreExecutor, config.getRemoteTaskMaxCallbackThreads()); this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) coreExecutor); @@ -155,6 +157,7 @@ public RemoteTask createRemoteTask( maxErrorDuration, taskStatusRefreshMaxWait, taskInfoUpdateInterval, + taskTerminationTimeout, summarizeTaskInfo, taskStatusCodec, dynamicFilterDomainsCodec, diff --git a/core/trino-main/src/main/java/io/trino/server/TaskResource.java b/core/trino-main/src/main/java/io/trino/server/TaskResource.java index 5d586aa6475c..007f7a4909f8 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskResource.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskResource.java @@ -533,7 +533,7 @@ private static Response createBufferResultResponse(SqlTaskWithResults taskWithRe .header(TRINO_PAGE_NEXT_TOKEN, result.getNextToken()) .header(TRINO_BUFFER_COMPLETE, result.isBufferComplete()) // check for task failure after getting the result to ensure it's consistent with isBufferComplete() - .header(TRINO_TASK_FAILED, taskWithResults.isTaskFailed()) + .header(TRINO_TASK_FAILED, taskWithResults.isTaskFailedOrFailing()) .build(); } } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java index c897bbec7ac4..9e59b6bec98d 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/HttpRemoteTask.java @@ -36,6 +36,7 @@ import io.trino.Session; import io.trino.execution.DynamicFiltersCollector; import io.trino.execution.DynamicFiltersCollector.VersionedDynamicFilterDomains; +import io.trino.execution.ExecutionFailureInfo; import io.trino.execution.FutureStateChange; import io.trino.execution.NodeTaskMap.PartitionedSplitCountTracker; import 
io.trino.execution.PartitionedSplitsInfo; @@ -57,12 +58,12 @@ import io.trino.server.FailTaskRequest; import io.trino.server.TaskUpdateRequest; import io.trino.spi.SplitWeight; +import io.trino.spi.TrinoException; import io.trino.spi.TrinoTransportException; import io.trino.sql.planner.PlanFragment; import io.trino.sql.planner.plan.DynamicFilterId; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.util.Failures; import org.joda.time.DateTime; import javax.annotation.concurrent.GuardedBy; @@ -100,6 +101,7 @@ import static io.airlift.http.client.Request.Builder.prepareDelete; import static io.airlift.http.client.Request.Builder.preparePost; import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator; +import static io.airlift.units.Duration.nanosSince; import static io.trino.SystemSessionProperties.getMaxRemoteTaskRequestSize; import static io.trino.SystemSessionProperties.getMaxUnacknowledgedSplitsPerTask; import static io.trino.SystemSessionProperties.getRemoteTaskGuaranteedSplitsPerRequest; @@ -107,12 +109,14 @@ import static io.trino.SystemSessionProperties.isRemoteTaskAdaptiveUpdateRequestSizeEnabled; import static io.trino.execution.DynamicFiltersCollector.INITIAL_DYNAMIC_FILTERS_VERSION; import static io.trino.execution.TaskInfo.createInitialTask; -import static io.trino.execution.TaskState.ABORTED; import static io.trino.execution.TaskState.FAILED; import static io.trino.execution.TaskStatus.failWith; import static io.trino.server.remotetask.RequestErrorTracker.logError; +import static io.trino.spi.HostAddress.fromUri; +import static io.trino.spi.StandardErrorCode.REMOTE_TASK_ERROR; import static io.trino.util.Failures.toFailure; import static java.lang.Math.addExact; +import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.NANOSECONDS; @@ -137,6 +141,7 @@ public final class HttpRemoteTask private final 
DynamicFiltersCollector outboundDynamicFiltersCollector; // The version of dynamic filters that has been successfully sent to the worker private final AtomicLong sentDynamicFiltersVersion = new AtomicLong(INITIAL_DYNAMIC_FILTERS_VERSION); + private final AtomicLong terminationStartedNanos = new AtomicLong(); private final AtomicReference> currentRequest = new AtomicReference<>(); @@ -167,6 +172,7 @@ public final class HttpRemoteTask private final Executor executor; private final ScheduledExecutorService errorScheduledExecutor; private final Duration maxErrorDuration; + private final Duration taskTerminationTimeout; private final JsonCodec taskInfoCodec; private final JsonCodec taskUpdateRequestCodec; @@ -180,7 +186,8 @@ public final class HttpRemoteTask private final PartitionedSplitCountTracker partitionedSplitCountTracker; private final AtomicBoolean started = new AtomicBoolean(false); - private final AtomicBoolean aborting = new AtomicBoolean(false); + private final AtomicBoolean terminating = new AtomicBoolean(false); + private final AtomicBoolean cleanedUp = new AtomicBoolean(false); private final int guaranteedSplitsPerRequest; private final long maxRequestSizeInBytes; @@ -202,6 +209,7 @@ public HttpRemoteTask( Duration maxErrorDuration, Duration taskStatusRefreshMaxWait, Duration taskInfoUpdateInterval, + Duration taskTerminationTimeout, boolean summarizeTaskInfo, JsonCodec taskStatusCodec, JsonCodec dynamicFilterDomainsCodec, @@ -240,6 +248,7 @@ public HttpRemoteTask( this.executor = executor; this.errorScheduledExecutor = errorScheduledExecutor; this.maxErrorDuration = requireNonNull(maxErrorDuration, "maxErrorDuration is null"); + this.taskTerminationTimeout = requireNonNull(taskTerminationTimeout, "taskTerminationTimeout is null"); this.summarizeTaskInfo = summarizeTaskInfo; this.taskInfoCodec = taskInfoCodec; this.taskUpdateRequestCodec = taskUpdateRequestCodec; @@ -291,7 +300,7 @@ public HttpRemoteTask( TaskInfo initialTask = createInitialTask(taskId, 
location, nodeId, pipelinedBufferStates, new TaskStats(DateTime.now(), null)); this.dynamicFiltersFetcher = new DynamicFiltersFetcher( - this::fail, + this::fatalUnacknowledgedFailure, taskId, location, taskStatusRefreshMaxWait, @@ -304,7 +313,7 @@ public HttpRemoteTask( dynamicFilterService); this.taskStatusFetcher = new ContinuousTaskStatusFetcher( - this::fail, + this::fatalUnacknowledgedFailure, initialTask.getTaskStatus(), taskStatusRefreshMaxWait, taskStatusCodec, @@ -316,7 +325,7 @@ public HttpRemoteTask( stats); this.taskInfoFetcher = new TaskInfoFetcher( - this::fail, + this::fatalUnacknowledgedFailure, taskStatusFetcher, initialTask, httpClient, @@ -332,8 +341,9 @@ public HttpRemoteTask( taskStatusFetcher.addStateChangeListener(newStatus -> { TaskState state = newStatus.getState(); - if (state.isDone()) { - cleanUpTask(); + // cleanup when done or partially cleanup when terminating begins + if (state.isTerminatingOrDone()) { + cleanUpTask(state); } else { partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); @@ -397,7 +407,7 @@ public synchronized void addSplits(Multimap splitsBySource) requireNonNull(splitsBySource, "splitsBySource is null"); // only add pending split if not done - if (getTaskStatus().getState().isDone()) { + if (getTaskStatus().getState().isTerminatingOrDone() || terminating.get()) { return; } @@ -435,7 +445,7 @@ public synchronized void addSplits(Multimap splitsBySource) @Override public synchronized void noMoreSplits(PlanNodeId sourceId) { - if (noMoreSplits.containsKey(sourceId)) { + if (noMoreSplits.containsKey(sourceId) || terminating.get()) { return; } @@ -446,7 +456,7 @@ public synchronized void noMoreSplits(PlanNodeId sourceId) @Override public void setOutputBuffers(OutputBuffers newOutputBuffers) { - if (getTaskStatus().getState().isDone()) { + if (getTaskStatus().getState().isTerminatingOrDone() || terminating.get()) { return; } @@ -470,6 +480,10 @@ public PartitionedSplitsInfo 
getPartitionedSplitsInfo() if (taskStatus.getState().isDone()) { return PartitionedSplitsInfo.forZeroSplits(); } + // Do not consider queued or unacknowledged splits if the task is in the process of terminating + if (taskStatus.getState().isTerminating()) { + return PartitionedSplitsInfo.forSplitCountAndWeightSum(taskStatus.getRunningPartitionedDrivers(), taskStatus.getRunningPartitionedSplitsWeight()); + } PartitionedSplitsInfo unacknowledgedSplitsInfo = getUnacknowledgedPartitionedSplitsInfo(); int count = unacknowledgedSplitsInfo.getCount() + taskStatus.getQueuedPartitionedDrivers() + taskStatus.getRunningPartitionedDrivers(); long weight = unacknowledgedSplitsInfo.getWeightSum() + taskStatus.getQueuedPartitionedSplitsWeight() + taskStatus.getRunningPartitionedSplitsWeight(); @@ -488,7 +502,7 @@ public PartitionedSplitsInfo getUnacknowledgedPartitionedSplitsInfo() public PartitionedSplitsInfo getQueuedPartitionedSplitsInfo() { TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { + if (taskStatus.getState().isTerminatingOrDone()) { return PartitionedSplitsInfo.forZeroSplits(); } PartitionedSplitsInfo unacknowledgedSplitsInfo = getUnacknowledgedPartitionedSplitsInfo(); @@ -518,7 +532,7 @@ private int getPendingSourceSplitCount() private long getQueuedPartitionedSplitsWeight() { TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { + if (taskStatus.getState().isTerminatingOrDone()) { return 0; } return getPendingSourceSplitsWeight() + taskStatus.getQueuedPartitionedSplitsWeight(); @@ -670,8 +684,8 @@ boolean adjustSplitBatchSize(List splitAssignments, long reques private void sendUpdate() { TaskStatus taskStatus = getTaskStatus(); - // don't update if the task is already finished - if (taskStatus.getState().isDone()) { + // don't update if the task is already finishing or finished, or if we have sent a termination command + if (taskStatus.getState().isTerminatingOrDone() || terminating.get()) { return; } 
checkState(started.get()); @@ -768,29 +782,50 @@ private synchronized SplitAssignment getSplitAssignment(PlanNodeId planNodeId, i } @Override - public synchronized void cancel() + public void abort() { - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - return; + // Only trigger abort commands if we aren't already canceling or failing the task remotely + if (!terminating.compareAndSet(false, true)) { + return; + } + + synchronized (this) { + try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { + if (!getTaskStatus().getState().isTerminatingOrDone()) { + scheduleAsyncCleanupRequest("abort", true); + } } + } + } + + @Override + public void cancel() + { + // Only cancel the task if we aren't already attempting to abort or fail the task remotely + if (!terminating.compareAndSet(false, true)) { + return; + } - // send cancel to task and ignore response - scheduleAsyncCleanupRequest(new Backoff(maxErrorDuration), "cancel", false); + synchronized (this) { + try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { + TaskStatus taskStatus = getTaskStatus(); + if (!taskStatus.getState().isTerminatingOrDone()) { + scheduleAsyncCleanupRequest("cancel", false); + } + } } } - private void cleanUpTask() + private void cleanUpTask(TaskState taskState) { - checkState(getTaskStatus().getState().isDone(), "attempt to clean up a task that is not done yet"); + checkState(taskState.isTerminatingOrDone(), "attempt to clean up a task that is not terminating or done: %s", taskState); // clear pending splits to free memory synchronized (this) { pendingSplits.clear(); pendingSourceSplitCount = 0; pendingSourceSplitsWeight = 0; - partitionedSplitCountTracker.setPartitionedSplits(PartitionedSplitsInfo.forZeroSplits()); + partitionedSplitCountTracker.setPartitionedSplits(getPartitionedSplitsInfo()); splitQueueHasSpace = true; 
whenSplitQueueHasSpace.complete(null, executor); } @@ -798,54 +833,61 @@ private void cleanUpTask() // clear pending outbound dynamic filters to free memory outboundDynamicFiltersCollector.acknowledge(Long.MAX_VALUE); - // cancel pending request - Future request = currentRequest.getAndSet(null); - if (request != null) { - request.cancel(true); - } - - taskStatusFetcher.stop(); - - // The remote task is likely to get a delete from the PageBufferClient first. - // We send an additional delete anyway to get the final TaskInfo - scheduleAsyncCleanupRequest(new Backoff(maxErrorDuration), "cleanup", true); - } + // only when termination is complete do we shut down status fetching + if (taskState.isDone()) { + // stop continuously fetching task status + taskStatusFetcher.stop(); + // cancel pending request + Future request = currentRequest.getAndSet(null); + if (request != null) { + request.cancel(true); + } - @Override - public synchronized void abort() - { - if (getTaskStatus().getState().isDone()) { - return; + // The remote task is likely to get a delete from the PageBufferClient first. 
+ // We send an additional delete anyway to get the final TaskInfo + scheduleAsyncCleanupRequest("cleanup", true); } - - TaskStatus status = failWith(getTaskStatus(), ABORTED, ImmutableList.of()); - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - taskStatusFetcher.updateTaskStatus(status); - // send abort to task - scheduleAsyncCleanupRequest(new Backoff(maxErrorDuration), "abort", true); + else { + // check for termination timeout + long terminationStartedNanos = this.terminationStartedNanos.get(); + if (terminationStartedNanos == 0) { + long currentTimeNanos = System.nanoTime(); + // If reported time is exactly 0, increase it by 1 nanosecond so that "0" can be treated as "not set" + if (currentTimeNanos == 0) { + currentTimeNanos = 1; + } + this.terminationStartedNanos.compareAndSet(0, currentTimeNanos); + } + else { + Duration terminatingTime = nanosSince(terminationStartedNanos); + if (terminatingTime.compareTo(taskTerminationTimeout) >= 0) { + // timeout and force cleanup locally + fatalUnacknowledgedFailure(new TrinoException(REMOTE_TASK_ERROR, format("Task %s failed to terminate after %s, last known state: %s", taskId, taskTerminationTimeout, taskState))); + } + } } } - private void scheduleAsyncCleanupRequest(Backoff cleanupBackoff, String action, boolean abort) + private void scheduleAsyncCleanupRequest(String action, boolean abort) { - scheduleAsyncCleanupRequest(cleanupBackoff, action, () -> buildDeleteTaskRequest(abort)); + scheduleAsyncCleanupRequest(action, () -> buildDeleteTaskRequest(abort)); } - private void scheduleAsyncCleanupRequest(Backoff cleanupBackoff, String action, FailTaskRequest failTaskRequest) + private void scheduleAsyncCleanupRequest(String action, FailTaskRequest failTaskRequest) { - scheduleAsyncCleanupRequest(cleanupBackoff, action, () -> buildFailTaskRequest(failTaskRequest)); + scheduleAsyncCleanupRequest(action, () -> buildFailTaskRequest(failTaskRequest)); } - private void 
scheduleAsyncCleanupRequest(Backoff cleanupBackoff, String action, Supplier remoteRequestSupplier) + private void scheduleAsyncCleanupRequest(String action, Supplier remoteRequestSupplier) { - if (!aborting.compareAndSet(false, true)) { - // Do not initiate another round of cleanup requests if one had been initiated. - // Otherwise, we can get into an asynchronous recursion here. For example, when aborting a task after REMOTE_TASK_MISMATCH. + // Only allow a single final cleanup request once the task status sees a final state + TaskState taskStatusState = getTaskStatus().getState(); + if (taskStatusState.isDone() && !cleanedUp.compareAndSet(false, true)) { return; } Request request = remoteRequestSupplier.get(); - doScheduleAsyncCleanupRequest(cleanupBackoff, request, action); + doScheduleAsyncCleanupRequest(new Backoff(maxErrorDuration), request, action); } private Request buildDeleteTaskRequest(boolean abort) @@ -878,30 +920,35 @@ public void onSuccess(JsonResponse result) updateTaskInfo(result.getValue()); } finally { - if (!getTaskInfo().getTaskStatus().getState().isDone()) { - cleanUpLocally(); + // if cleanup operation has not at least started task termination, mark the task failed + TaskState taskState = getTaskInfo().getTaskStatus().getState(); + if (!taskState.isTerminatingOrDone()) { + fatalUnacknowledgedFailure(new TrinoTransportException(REMOTE_TASK_ERROR, fromUri(request.getUri()), format("Unable to %s task at %s, last known state was: %s", action, request.getUri(), taskState))); } } } @Override + @SuppressWarnings("FormatStringAnnotation") // we manipulate the format string and there's no way to make Error Prone accept the result public void onFailure(Throwable t) { - if (t instanceof RejectedExecutionException && httpClient.isClosed()) { - logError(t, "Unable to %s task at %s. 
HTTP client is closed.", action, request.getUri()); - cleanUpLocally(); + // final task info has been received, no need to resend the request + if (getTaskInfo().getTaskStatus().getState().isDone()) { return; } - // record failure - if (cleanupBackoff.failure()) { - logError(t, "Unable to %s task at %s. Back off depleted.", action, request.getUri()); - cleanUpLocally(); + if (t instanceof RejectedExecutionException && httpClient.isClosed()) { + String message = format("Unable to %s task at %s. HTTP client is closed.", action, request.getUri()); + logError(t, message); + fatalUnacknowledgedFailure(new TrinoTransportException(REMOTE_TASK_ERROR, fromUri(request.getUri()), message)); return; } - // final task info is set - if (taskInfoFetcher.getTaskInfo().getTaskStatus().getState().isDone()) { + // record failure + if (cleanupBackoff.failure()) { + String message = format("Unable to %s task at %s. Back off depleted.", action, request.getUri()); + logError(t, message); + fatalUnacknowledgedFailure(new TrinoTransportException(REMOTE_TASK_ERROR, fromUri(request.getUri()), message)); return; } @@ -920,25 +967,35 @@ public void onFailure(Throwable t) /** * Move the task directly to the failed state if there was a failure in this task */ - @Override - public synchronized void fail(Throwable cause) + private synchronized void fatalUnacknowledgedFailure(Throwable cause) { - TaskStatus taskStatus = getTaskStatus(); - if (!taskStatus.getState().isDone()) { - log.debug(cause, "Remote task %s failed with %s", taskStatus.getSelf(), cause); - } - - TaskStatus status = failWith(getTaskStatus(), FAILED, ImmutableList.of(toFailure(cause))); - taskStatusFetcher.updateTaskStatus(status); - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - if (cause instanceof TrinoTransportException) { - // task is unreachable - cleanUpLocally(); - } - else { - // send abort to task - scheduleAsyncCleanupRequest(new Backoff(maxErrorDuration), "abort", true); + TaskStatus 
taskStatus = getTaskStatus(); + if (!taskStatus.getState().isDone()) { + // Update the taskInfo with the new taskStatus. + + // Generally, we send a cleanup request to the worker, and update the TaskInfo on + // the coordinator based on what we fetched from the worker. If we somehow cannot + // get the cleanup request to the worker, the TaskInfo that we fetch for the worker + // likely will not say the task is done however many times we try. In this case, + // we have to set the local query info directly so that we stop trying to fetch + // updated TaskInfo from the worker. This way, the task on the worker eventually + // expires due to lack of activity. + List failures = ImmutableList.builderWithExpectedSize(taskStatus.getFailures().size() + 1) + .add(toFailure(cause)) + .addAll(taskStatus.getFailures()) + .build(); + taskStatus = failWith(taskStatus, FAILED, failures); + if (cause instanceof TrinoTransportException) { + // Since this TaskInfo is updated in the client without having received them from the + // worker, the stats may not reflect the actual final stats had we been able to reach the worker + // to get them. + updateTaskInfo(getTaskInfo().withTaskStatus(taskStatus)); + } + else { + // Let the status callbacks trigger the cleanup command remotely after switching states + taskStatusFetcher.updateTaskStatus(taskStatus); + } } } } @@ -947,36 +1004,43 @@ public synchronized void fail(Throwable cause) * Trigger remote task failure. Task status will be updated only when request sent to remote node returns. 
*/ @Override - public synchronized void failRemotely(Throwable cause) + public void failRemotely(Throwable cause) { - try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { - TaskStatus taskStatus = getTaskStatus(); - if (taskStatus.getState().isDone()) { - return; + if (!terminating.compareAndSet(false, true)) { + return; + } + + synchronized (this) { + try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { + TaskStatus taskStatus = getTaskStatus(); + if (!taskStatus.getState().isTerminatingOrDone()) { + log.debug(cause, "Remote task %s failed with %s", taskStatus.getSelf(), cause); + scheduleAsyncCleanupRequest("fail", new FailTaskRequest(toFailure(cause))); + } } - scheduleAsyncCleanupRequest(new Backoff(maxErrorDuration), "fail", new FailTaskRequest(Failures.toFailure(cause))); } } - private void cleanUpLocally() + @Override + public void failLocallyImmediately(Throwable cause) { - // Update the taskInfo with the new taskStatus. - - // Generally, we send a cleanup request to the worker, and update the TaskInfo on - // the coordinator based on what we fetched from the worker. If we somehow cannot - // get the cleanup request to the worker, the TaskInfo that we fetch for the worker - // likely will not say the task is done however many times we try. In this case, - // we have to set the local query info directly so that we stop trying to fetch - // updated TaskInfo from the worker. This way, the task on the worker eventually - // expires due to lack of activity. - - // This is required because the query state machine depends on TaskInfo (instead of task status) - // to transition its own state. - // TODO: Update the query state machine and stage state machine to depend on TaskStatus instead - - // Since this TaskInfo is updated in the client the "complete" flag will not be set, - // indicating that the stats may not reflect the final stats on the worker. 
- updateTaskInfo(getTaskInfo().withTaskStatus(getTaskStatus())); + requireNonNull(cause, "cause is null"); + // Prevent concurrent abort commands after this point + terminating.set(true); + synchronized (this) { + try (SetThreadName ignored = new SetThreadName("HttpRemoteTask-%s", taskId)) { + TaskStatus taskStatus = getTaskStatus(); + if (!taskStatus.getState().isDone()) { + // Record and force the task into a failed state immediately without waiting for the task to respond. A final cleanup + // command will be sent to the task, but will not await the response + List failures = ImmutableList.builderWithExpectedSize(taskStatus.getFailures().size() + 1) + .add(toFailure(cause)) + .addAll(taskStatus.getFailures()) + .build(); + taskStatusFetcher.updateTaskStatus(failWith(taskStatus, FAILED, failures)); + } + } + } } private HttpUriBuilder getHttpUriBuilder(TaskStatus taskStatus) @@ -1049,11 +1113,11 @@ public void failed(Throwable cause) scheduleUpdate(); } catch (Error e) { - fail(e); + fatalUnacknowledgedFailure(e); throw e; } catch (RuntimeException e) { - fail(e); + fatalUnacknowledgedFailure(e); } } } @@ -1062,13 +1126,13 @@ public void failed(Throwable cause) public void fatal(Throwable cause) { try (SetThreadName ignored = new SetThreadName("UpdateResponseHandler-%s", taskId)) { - fail(cause); + fatalUnacknowledgedFailure(cause); } } private void updateStats() { - Duration requestRoundTrip = Duration.nanosSince(currentRequestStartNanos); + Duration requestRoundTrip = nanosSince(currentRequestStartNanos); stats.updateRoundTripMillis(requestRoundTrip.toMillis()); } } diff --git a/core/trino-main/src/main/java/io/trino/server/remotetask/RequestErrorTracker.java b/core/trino-main/src/main/java/io/trino/server/remotetask/RequestErrorTracker.java index 5f49981c4e61..d8cb573c3a94 100644 --- a/core/trino-main/src/main/java/io/trino/server/remotetask/RequestErrorTracker.java +++ b/core/trino-main/src/main/java/io/trino/server/remotetask/RequestErrorTracker.java @@ 
-13,10 +13,8 @@ */ package io.trino.server.remotetask; -import com.google.common.collect.ObjectArrays; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; -import com.google.errorprone.annotations.FormatMethod; import io.airlift.event.client.ServiceUnavailableException; import io.airlift.log.Logger; import io.airlift.units.Duration; @@ -139,15 +137,14 @@ public void requestFailed(Throwable reason) } } - @FormatMethod @SuppressWarnings("FormatStringAnnotation") // we manipulate the format string and there's no way to make Error Prone accept the result - static void logError(Throwable t, String format, Object... args) + static void logError(Throwable t, String message) { if (isExpectedError(t)) { - log.error(format + ": %s", ObjectArrays.concat(args, t)); + log.error("%s: %s", message, t); } else { - log.error(t, format, args); + log.error(t, message); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java index 986422e7d8d9..c8f97cfe6fc1 100644 --- a/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/MockRemoteTaskFactory.java @@ -40,7 +40,6 @@ import io.trino.metadata.Split; import io.trino.operator.TaskContext; import io.trino.operator.TaskStats; -import io.trino.spi.SplitWeight; import io.trino.spiller.SpillSpaceTracker; import io.trino.sql.planner.Partitioning; import io.trino.sql.planner.PartitioningScheme; @@ -58,7 +57,6 @@ import java.net.URI; import java.util.ArrayList; -import java.util.Collection; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -238,35 +236,8 @@ public String getNodeId() @Override public TaskInfo getTaskInfo() { - TaskState state = taskStateMachine.getState(); - List failures = ImmutableList.of(); - if (state == TaskState.FAILED) { - failures = 
toFailures(taskStateMachine.getFailureCauses()); - } - return new TaskInfo( - new TaskStatus( - taskStateMachine.getTaskId(), - TASK_INSTANCE_ID, - nextTaskInfoVersion.getAndIncrement(), - state, - location, - nodeId, - failures, - 0, - 0, - outputBuffer.getStatus(), - DataSize.ofBytes(0), - DataSize.ofBytes(0), - Optional.empty(), - DataSize.ofBytes(0), - DataSize.ofBytes(0), - DataSize.ofBytes(0), - 0, - new Duration(0, MILLISECONDS), - INITIAL_DYNAMIC_FILTERS_VERSION, - 0L, - 0L), + getTaskStatus(), DateTime.now(), outputBuffer.getInfo(), ImmutableSet.of(), @@ -276,18 +247,24 @@ public TaskInfo getTaskInfo() } @Override - public TaskStatus getTaskStatus() + public synchronized TaskStatus getTaskStatus() { + TaskState state = taskStateMachine.getState(); + List failures = ImmutableList.of(); + if (state == TaskState.FAILED || state == TaskState.FAILING) { + failures = toFailures(taskStateMachine.getFailureCauses()); + } + TaskStats stats = taskContext.getTaskStats(); PartitionedSplitsInfo combinedSplitsInfo = getPartitionedSplitsInfo(); PartitionedSplitsInfo queuedSplitsInfo = getQueuedPartitionedSplitsInfo(); return new TaskStatus(taskStateMachine.getTaskId(), TASK_INSTANCE_ID, nextTaskInfoVersion.get(), - taskStateMachine.getState(), + state, location, nodeId, - ImmutableList.of(), + failures, queuedSplitsInfo.getCount(), combinedSplitsInfo.getCount() - queuedSplitsInfo.getCount(), outputBuffer.getStatus(), @@ -306,6 +283,9 @@ public TaskStatus getTaskStatus() private synchronized void updateSplitQueueSpace() { + if (runningDrivers == 0 && taskStateMachine.getState().isTerminating()) { + taskStateMachine.terminationComplete(); + } if (unacknowledgedSplits < maxUnacknowledgedSplits && getQueuedPartitionedSplitsInfo().getWeightSum() < 900L) { if (!whenSplitQueueHasSpace.isDone()) { whenSplitQueueHasSpace.set(null); @@ -356,8 +336,9 @@ public synchronized void setUnacknowledgedSplits(int unacknowledgedSplits) public synchronized void startSplits(int maxRunning) { 
- runningDrivers = splits.size(); - runningDrivers = Math.min(runningDrivers, maxRunning); + if (!taskStateMachine.getState().isTerminatingOrDone()) { + runningDrivers = Math.min(splits.size(), maxRunning); + } updateSplitQueueSpace(); } @@ -365,7 +346,10 @@ public synchronized void startSplits(int maxRunning) public void start() { taskStateMachine.addStateChangeListener(newValue -> { - if (newValue.isDone()) { + if (newValue.isTerminating()) { + updateSplitQueueSpace(); // potentially finish termination if runningDrivers is zero + } + else if (newValue.isDone()) { clearSplits(); } }); @@ -441,41 +425,49 @@ public void abort() } @Override - public void fail(Throwable cause) + public void failRemotely(Throwable cause) { taskStateMachine.failed(cause); clearSplits(); } @Override - public void failRemotely(Throwable cause) + public void failLocallyImmediately(Throwable cause) { taskStateMachine.failed(cause); clearSplits(); } @Override - public PartitionedSplitsInfo getPartitionedSplitsInfo() + public synchronized PartitionedSplitsInfo getPartitionedSplitsInfo() { if (taskStateMachine.getState().isDone()) { return PartitionedSplitsInfo.forZeroSplits(); } - synchronized (this) { - int count = 0; - long weight = 0; - for (PlanNodeId tableScanPlanNodeId : fragment.getPartitionedSources()) { - Collection partitionedSplits = splits.get(tableScanPlanNodeId); - count += partitionedSplits.size(); - weight = addExact(weight, SplitWeight.rawValueSum(partitionedSplits, Split::getSplitWeight)); + // Queued splits are ignored once a task begins terminating, since they will never be started + boolean countQueued = !taskStateMachine.getState().isTerminating(); + // Let's consider the first drivers encountered to be "running" + int remainingRunning = runningDrivers; + int splitCount = 0; + long splitWeight = 0; + for (PlanNodeId tableScanPlanNodeId : fragment.getPartitionedSources()) { + for (Split split : splits.get(tableScanPlanNodeId)) { + if (countQueued || remainingRunning > 0) 
{ + if (remainingRunning > 0) { + remainingRunning--; + } + splitCount++; + splitWeight = addExact(splitWeight, split.getSplitWeight().getRawValue()); + } } - return PartitionedSplitsInfo.forSplitCountAndWeightSum(count, weight); } + return PartitionedSplitsInfo.forSplitCountAndWeightSum(splitCount, splitWeight); } @Override public synchronized PartitionedSplitsInfo getQueuedPartitionedSplitsInfo() { - if (taskStateMachine.getState().isDone()) { + if (taskStateMachine.getState().isTerminatingOrDone()) { return PartitionedSplitsInfo.forZeroSplits(); } // Let's consider the first drivers encountered to be "running" diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java index d2bc8c895139..17122a41da3b 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlStage.java @@ -14,8 +14,8 @@ package io.trino.execution; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.util.concurrent.SettableFuture; import io.trino.client.NodeVersion; @@ -23,6 +23,7 @@ import io.trino.execution.buffer.PipelinedOutputBuffers; import io.trino.execution.scheduler.SplitSchedulerStats; import io.trino.metadata.InternalNode; +import io.trino.metadata.Split; import io.trino.operator.RetryPolicy; import io.trino.spi.QueryId; import io.trino.spi.type.Type; @@ -34,13 +35,18 @@ import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; import io.trino.sql.planner.plan.RemoteSourceNode; +import io.trino.testing.TestingSplit; import io.trino.util.FinalizerService; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; import 
org.testng.annotations.Test; import java.net.URI; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; @@ -54,12 +60,15 @@ import static io.trino.sql.planner.SystemPartitioningHandle.SINGLE_DISTRIBUTION; import static io.trino.sql.planner.SystemPartitioningHandle.SOURCE_DISTRIBUTION; import static io.trino.sql.planner.plan.ExchangeNode.Type.REPARTITION; +import static io.trino.testing.TestingHandles.TEST_CATALOG_HANDLE; import static java.util.concurrent.Executors.newFixedThreadPool; import static java.util.concurrent.Executors.newScheduledThreadPool; import static java.util.concurrent.TimeUnit.MINUTES; +import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertFalse; import static org.testng.Assert.assertSame; import static org.testng.Assert.assertTrue; +import static org.testng.Assert.fail; public class TestSqlStage { @@ -115,9 +124,13 @@ private void testFinalStageInfoInternal() stage.addFinalStageInfoListener(finalStageInfo::set); // in a background thread add a ton of tasks - CountDownLatch latch = new CountDownLatch(1000); + CompletableFuture stopped = new CompletableFuture<>(); + CountDownLatch countDownLatch = new CountDownLatch(1000); + List createdTasks = Collections.synchronizedList(new ArrayList<>(2000)); Future addTasksTask = executor.submit(() -> { try { + PlanNodeId planNodeId = stage.getFragment().getPartitionedSources().get(0); + ImmutableListMultimap initialSplits = ImmutableListMultimap.of(planNodeId, new Split(TEST_CATALOG_HANDLE, new TestingSplit(true, ImmutableList.of()))); for (int i = 0; i < 1_000_000; i++) { if (Thread.interrupted()) { return; @@ -127,38 +140,77 @@ private void testFinalStageInfoInternal() URI.create("http://10.0.0." 
+ (i / 10_000) + ":" + (i % 10_000)), NodeVersion.UNKNOWN, false); - stage.createTask( + Optional created = stage.createTask( node, i, 0, Optional.empty(), PipelinedOutputBuffers.createInitial(ARBITRARY), - ImmutableMultimap.of(), + initialSplits, ImmutableSet.of(), Optional.empty()); - latch.countDown(); + if (created.isPresent()) { + if (created.get() instanceof MockRemoteTaskFactory.MockRemoteTask mockTask) { + mockTask.start(); + mockTask.startSplits(1); + createdTasks.add(mockTask); + countDownLatch.countDown(); + } + else { + fail("Expected an instance of MockRemoteTask"); + } + } } } finally { - while (latch.getCount() > 0) { - latch.countDown(); + while (countDownLatch.getCount() > 0) { + countDownLatch.countDown(); } + stopped.complete(null); } }); - // wait for some tasks to be created, and then abort the query - latch.await(1, MINUTES); - assertFalse(stage.getStageInfo().getTasks().isEmpty()); + // wait for some tasks to be created, and then abort the stage + countDownLatch.await(); stage.finish(); + assertTrue(createdTasks.size() >= 1000); + + StageInfo stageInfo = stage.getStageInfo(); + // stage should not report final info because all tasks have a running driver, but + // all tasks should be cancelling + for (TaskInfo info : stageInfo.getTasks()) { + // Tasks can race with the stage finish operation and be cancelled fully before + // starting any splits running. 
These can report either cancelling or fully cancelled + // depending on the timing of TaskInfo being created + TaskState taskState = info.getTaskStatus().getState(); + int runningSplits = info.getTaskStatus().getRunningPartitionedDrivers(); + if (runningSplits == 0) { + assertTrue(taskState == TaskState.CANCELING || taskState == TaskState.CANCELED, "unexpected task state: " + taskState); + } + else { + assertEquals(taskState, TaskState.CANCELING); + assertTrue(runningSplits > 0, "must be running splits to not be already canceled"); + } + } + assertFalse(finalStageInfo.isDone()); + + // cancel the background thread adding tasks + addTasksTask.cancel(true); + // wait for the background thread to acknowledge having stopped new task creations + // so that we know that all tasks are present in the createdTasks list + stopped.join(); + + // finishing all running splits on the task should trigger termination complete + createdTasks.forEach(task -> { + task.clearSplits(); + assertEquals(task.getTaskStatus().getState(), TaskState.CANCELED); + }); // once the final stage info is available, verify that it is complete - StageInfo stageInfo = finalStageInfo.get(1, MINUTES); + stageInfo = finalStageInfo.get(1, MINUTES); assertFalse(stageInfo.getTasks().isEmpty()); assertTrue(stageInfo.isFinalStageInfo()); assertSame(stage.getStageInfo(), stageInfo); - - // cancel the background thread adding tasks - addTasksTask.cancel(true); } private static PlanFragment createExchangePlanFragment() diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index a09a6f040ff4..86917ef94f29 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import 
com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.stats.CounterStat; import io.airlift.stats.TestingGcMonitor; @@ -215,7 +216,15 @@ public void testCancel() assertNull(taskInfo.getStats().getEndTime()); taskInfo = sqlTask.cancel(); - assertEquals(taskInfo.getTaskStatus().getState(), TaskState.CANCELED); + // This call can race and report either cancelling or cancelled + assertTrue(taskInfo.getTaskStatus().getState().isTerminatingOrDone()); + // Task cancellation can race with output buffer state updates, but should transition to cancelled quickly + int attempts = 1; + while (!taskInfo.getTaskStatus().getState().isDone() && attempts < 3) { + taskInfo = Futures.getUnchecked(sqlTask.getTaskInfo(taskInfo.getTaskStatus().getVersion())); + attempts++; + } + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.CANCELED, "Failed to see CANCELED after " + attempts + " attempts"); assertNotNull(taskInfo.getStats().getEndTime()); taskInfo = sqlTask.getTaskInfo(); @@ -289,7 +298,7 @@ public void testBufferCloseOnCancel() assertFalse(bufferResult.isDone()); sqlTask.cancel(); - assertEquals(sqlTask.getTaskInfo().getTaskStatus().getState(), TaskState.CANCELED); + assertTrue(sqlTask.getTaskInfo().getTaskStatus().getState().isTerminatingOrDone()); // buffer future will complete, the event is async so wait a bit for event to propagate bufferResult.get(1, SECONDS); @@ -312,6 +321,12 @@ public void testBufferNotCloseOnFail() long taskStatusVersion = sqlTask.getTaskInfo().getTaskStatus().getVersion(); sqlTask.failed(new Exception("test")); + // This call can race and return either FAILED or FAILING + TaskInfo taskInfo = sqlTask.getTaskInfo(taskStatusVersion).get(); + assertTrue(taskInfo.getTaskStatus().getState().isTerminatingOrDone()); + + // This call should resolve to FAILED if the prior call did not + taskStatusVersion = taskInfo.getTaskStatus().getVersion(); 
assertEquals(sqlTask.getTaskInfo(taskStatusVersion).get().getTaskStatus().getState(), TaskState.FAILED); // buffer will not be closed by fail event. event is async so wait a bit for event to fire diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java index 2505a927daee..039872c61ea9 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java @@ -170,6 +170,7 @@ public void testSimpleQuery() @Test public void testCancel() + throws InterruptedException, ExecutionException, TimeoutException { try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) { TaskId taskId = TASK_ID; @@ -181,7 +182,7 @@ public void testCancel() assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); assertNull(taskInfo.getStats().getEndTime()); - taskInfo = sqlTaskManager.cancelTask(taskId); + taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.cancelTask(taskId)); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.CANCELED); assertNotNull(taskInfo.getStats().getEndTime()); @@ -193,6 +194,7 @@ public void testCancel() @Test public void testAbort() + throws InterruptedException, ExecutionException, TimeoutException { try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig())) { TaskId taskId = TASK_ID; @@ -204,7 +206,7 @@ public void testAbort() assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); assertNull(taskInfo.getStats().getEndTime()); - taskInfo = sqlTaskManager.abortTask(taskId); + taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.abortTask(taskId)); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.ABORTED); assertNotNull(taskInfo.getStats().getEndTime()); @@ -237,7 +239,7 @@ public void testAbortResults() @Test public void 
testRemoveOldTasks() - throws Exception + throws InterruptedException, ExecutionException, TimeoutException { try (SqlTaskManager sqlTaskManager = createSqlTaskManager(new TaskManagerConfig().setInfoMaxAge(new Duration(5, TimeUnit.MILLISECONDS)))) { TaskId taskId = TASK_ID; @@ -245,7 +247,7 @@ public void testRemoveOldTasks() TaskInfo taskInfo = createTask(sqlTaskManager, taskId, PipelinedOutputBuffers.createInitial(PARTITIONED).withBuffer(OUT, 0).withNoMoreBufferIds()); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.RUNNING); - taskInfo = sqlTaskManager.cancelTask(taskId); + taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, sqlTaskManager.cancelTask(taskId)); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.CANCELED); taskInfo = sqlTaskManager.getTaskInfo(taskId); @@ -291,7 +293,7 @@ public void testFailStuckSplitTasks() try (SqlTaskManager sqlTaskManager = createSqlTaskManager(taskManagerConfig, new NodeMemoryConfig(), taskExecutor, stackTraceElements -> true)) { sqlTaskManager.addStateChangeListener(TASK_ID, (state) -> { - if (state.isDone()) { + if (state.isTerminatingOrDone() && !taskHandle.isDestroyed()) { taskExecutor.removeTask(taskHandle); } }); @@ -300,8 +302,11 @@ public void testFailStuckSplitTasks() sqlTaskManager.failStuckSplitTasks(); mockSplitRunner.waitForFinish(); - assertEquals(sqlTaskManager.getAllTaskInfo().size(), 1); - assertEquals(sqlTaskManager.getAllTaskInfo().get(0).getTaskStatus().getState(), TaskState.FAILED); + List taskInfos = sqlTaskManager.getAllTaskInfo(); + assertEquals(taskInfos.size(), 1); + + TaskInfo taskInfo = pollTerminatingTaskInfoUntilDone(sqlTaskManager, taskInfos.get(0)); + assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FAILED); } } finally { @@ -429,6 +434,18 @@ private TaskInfo createTask(SqlTaskManager sqlTaskManager, TaskId taskId, Output ImmutableMap.of()); } + private static TaskInfo pollTerminatingTaskInfoUntilDone(SqlTaskManager taskManager, TaskInfo taskInfo) + 
throws InterruptedException, ExecutionException, TimeoutException + { + assertTrue(taskInfo.getTaskStatus().getState().isTerminatingOrDone()); + int attempts = 3; + while (attempts > 0 && taskInfo.getTaskStatus().getState().isTerminating()) { + taskInfo = taskManager.getTaskInfo(taskInfo.getTaskStatus().getTaskId(), taskInfo.getTaskStatus().getVersion()).get(5, SECONDS); + attempts--; + } + return taskInfo; + } + public static class MockDirectExchangeClientSupplier implements DirectExchangeClientSupplier { diff --git a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java index 7a04e0165390..6afeccbf15e0 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestTaskManagerConfig.java @@ -44,6 +44,7 @@ public void testDefaults() .setSplitConcurrencyAdjustmentInterval(new Duration(100, TimeUnit.MILLISECONDS)) .setStatusRefreshMaxWait(new Duration(1, TimeUnit.SECONDS)) .setInfoUpdateInterval(new Duration(3, TimeUnit.SECONDS)) + .setTaskTerminationTimeout(new Duration(1, TimeUnit.MINUTES)) .setPerOperatorCpuTimerEnabled(true) .setTaskCpuTimerEnabled(true) .setMaxWorkerThreads(Runtime.getRuntime().availableProcessors() * 2) @@ -88,6 +89,7 @@ public void testExplicitPropertyMappings() .put("task.split-concurrency-adjustment-interval", "1s") .put("task.status-refresh-max-wait", "2s") .put("task.info-update-interval", "2s") + .put("task.termination-timeout", "15s") .put("task.per-operator-cpu-timer-enabled", "false") .put("task.cpu-timer-enabled", "false") .put("task.max-index-memory", "512MB") @@ -127,6 +129,7 @@ public void testExplicitPropertyMappings() .setSplitConcurrencyAdjustmentInterval(new Duration(1, TimeUnit.SECONDS)) .setStatusRefreshMaxWait(new Duration(2, TimeUnit.SECONDS)) .setInfoUpdateInterval(new Duration(2, TimeUnit.SECONDS)) + .setTaskTerminationTimeout(new 
Duration(15, TimeUnit.SECONDS)) .setPerOperatorCpuTimerEnabled(false) .setTaskCpuTimerEnabled(false) .setMaxIndexMemoryUsage(DataSize.of(512, Unit.MEGABYTE)) diff --git a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java index 769b7ea0f34e..0423053a8492 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestingRemoteTaskFactory.java @@ -264,12 +264,14 @@ public ListenableFuture whenSplitQueueHasSpace(long weightThreshold) public void cancel() { taskStateMachine.cancel(); + taskStateMachine.terminationComplete(); } @Override public void abort() { taskStateMachine.abort(); + taskStateMachine.terminationComplete(); } @Override @@ -279,15 +281,17 @@ public PartitionedSplitsInfo getPartitionedSplitsInfo() } @Override - public void fail(Throwable cause) + public void failRemotely(Throwable cause) { taskStateMachine.failed(cause); + taskStateMachine.terminationComplete(); } @Override - public void failRemotely(Throwable cause) + public void failLocallyImmediately(Throwable cause) { taskStateMachine.failed(cause); + taskStateMachine.terminationComplete(); } @Override diff --git a/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java b/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java index 72cb6598b894..f9b7a802765f 100644 --- a/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java +++ b/core/trino-main/src/test/java/io/trino/memory/TestLeastWastedEffortTaskLowMemoryKiller.java @@ -261,6 +261,7 @@ private static TaskInfo buildTaskInfo(TaskId taskId, TaskState state, Duration s null, null, null, + null, new Duration(0, MILLISECONDS), new Duration(0, MILLISECONDS), 0, diff --git a/core/trino-main/src/test/java/io/trino/operator/TestDriver.java 
b/core/trino-main/src/test/java/io/trino/operator/TestDriver.java index d34a7be83d6d..9880887c871b 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestDriver.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestDriver.java @@ -129,7 +129,7 @@ public void testConcurrentClose() // let these threads race scheduledExecutor.submit(() -> driver.processForDuration(new Duration(1, TimeUnit.NANOSECONDS))); // don't want to call isFinishedInternal in processFor scheduledExecutor.submit(driver::close); - while (!driverContext.isDone()) { + while (!driverContext.isTerminatingOrDone()) { Uninterruptibles.sleepUninterruptibly(1, TimeUnit.MILLISECONDS); } } @@ -206,10 +206,13 @@ public void testBrokenOperatorCloseWhileProcessing() driver.close(); assertTrue(driver.isFinished()); + assertFalse(driver.getDestroyedFuture().isDone()); assertThatThrownBy(() -> driverProcessFor.get(1, TimeUnit.SECONDS)) .isInstanceOf(ExecutionException.class) .hasCause(new TrinoException(GENERIC_INTERNAL_ERROR, "Driver was interrupted")); + + assertTrue(driver.getDestroyedFuture().isDone()); } @Test @@ -230,10 +233,12 @@ public void testBrokenOperatorProcessWhileClosing() assertTrue(driver.processForDuration(new Duration(1, TimeUnit.MILLISECONDS)).isDone()); assertTrue(driver.isFinished()); + assertFalse(driver.getDestroyedFuture().isDone()); brokenOperator.unlock(); assertTrue(driverClose.get()); + assertTrue(driver.getDestroyedFuture().isDone()); } @Test @@ -309,17 +314,20 @@ public void testBrokenOperatorAddSource() driver.updateSplitAssignment(new SplitAssignment(sourceId, ImmutableSet.of(new ScheduledSplit(0, sourceId, newMockSplit())), true)); - assertFalse(driver.isFinished()); + assertFalse(driver.getDestroyedFuture().isDone()); // processFor always returns NOT_BLOCKED, because DriveLockResult was not acquired assertTrue(driver.processForDuration(new Duration(1, TimeUnit.SECONDS)).isDone()); assertFalse(driver.isFinished()); driver.close(); 
assertTrue(driver.isFinished()); + assertFalse(driver.getDestroyedFuture().isDone()); assertThatThrownBy(() -> driverProcessFor.get(1, TimeUnit.SECONDS)) .isInstanceOf(ExecutionException.class) .hasCause(new TrinoException(GENERIC_INTERNAL_ERROR, "Driver was interrupted")); + + assertTrue(driver.getDestroyedFuture().isDone()); } private static Split newMockSplit() diff --git a/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java b/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java index a77581ac7305..1399e0ba7161 100644 --- a/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java +++ b/core/trino-main/src/test/java/io/trino/operator/TestTaskStats.java @@ -34,6 +34,7 @@ public class TestTaskStats new DateTime(1), new DateTime(2), new DateTime(100), + new DateTime(102), new DateTime(101), new DateTime(3), new Duration(4, NANOSECONDS), @@ -103,6 +104,7 @@ public static void assertExpectedTaskStats(TaskStats actual) assertEquals(actual.getCreateTime(), new DateTime(1, UTC)); assertEquals(actual.getFirstStartTime(), new DateTime(2, UTC)); assertEquals(actual.getLastStartTime(), new DateTime(100, UTC)); + assertEquals(actual.getTerminatingStartTime(), new DateTime(102, UTC)); assertEquals(actual.getLastEndTime(), new DateTime(101, UTC)); assertEquals(actual.getEndTime(), new DateTime(3, UTC)); assertEquals(actual.getElapsedTime(), new Duration(4, NANOSECONDS)); diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java index 1ae9f4190fb6..c2e39c064c21 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/BaseClickHouseConnectorTest.java @@ -291,7 +291,7 @@ protected TestTable createTableWithDefaultColumns() public void testCharVarcharComparison() { 
assertThatThrownBy(super::testCharVarcharComparison) - .hasMessageContaining("For query: ") + .hasMessageContaining("For query") .hasMessageContaining("Actual rows") .hasMessageContaining("Expected rows"); diff --git a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java index 6a7ee6876235..68870bb00bc5 100644 --- a/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java +++ b/plugin/trino-kudu/src/test/java/io/trino/plugin/kudu/TestKuduConnectorTest.java @@ -802,7 +802,7 @@ public void testCharVarcharComparison() { // TODO https://github.com/trinodb/trino/issues/3597 Fix Kudu CREATE TABLE AS SELECT with char(n) type does not preserve trailing spaces assertThatThrownBy(super::testCharVarcharComparison) - .hasMessageContaining("For query: ") + .hasMessageContaining("For query") .hasMessageContaining("Actual rows") .hasMessageContaining("Expected rows"); diff --git a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java index 26ae59f549d2..9e54e5a13268 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/DistributedQueryRunner.java @@ -484,17 +484,7 @@ public MaterializedResult execute(@Language("SQL") String sql) @Override public MaterializedResult execute(Session session, @Language("SQL") String sql) { - lock.readLock().lock(); - try { - return trinoClient.execute(session, sql).getResult(); - } - catch (Throwable e) { - e.addSuppressed(new Exception("SQL: " + sql)); - throw e; - } - finally { - lock.readLock().unlock(); - } + return executeWithQueryId(session, sql).getResult(); } public MaterializedResultWithQueryId executeWithQueryId(Session session, @Language("SQL") String sql) @@ -504,6 +494,10 @@ public 
MaterializedResultWithQueryId executeWithQueryId(Session session, @Langua ResultWithQueryId result = trinoClient.execute(session, sql); return new MaterializedResultWithQueryId(result.getQueryId(), result.getResult()); } + catch (Throwable e) { + e.addSuppressed(new Exception("SQL: " + sql)); + throw e; + } finally { lock.readLock().unlock(); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/QueryAssertions.java b/testing/trino-testing/src/main/java/io/trino/testing/QueryAssertions.java index 8f2edae5ab0b..b1c8b31558df 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/QueryAssertions.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/QueryAssertions.java @@ -24,6 +24,7 @@ import io.trino.Session; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.QualifiedObjectName; +import io.trino.spi.QueryId; import io.trino.spi.TrinoException; import io.trino.sql.parser.ParsingException; import io.trino.sql.planner.Plan; @@ -58,6 +59,11 @@ private QueryAssertions() public static void assertUpdate(QueryRunner queryRunner, Session session, @Language("SQL") String sql, OptionalLong count, Optional> planAssertion) { + if (queryRunner instanceof DistributedQueryRunner distributedQueryRunner) { + assertDistributedUpdate(distributedQueryRunner, session, sql, count, planAssertion); + return; + } + long start = System.nanoTime(); MaterializedResult results; Plan queryPlan; @@ -94,6 +100,51 @@ else if (count.isPresent()) { } } + private static void assertDistributedUpdate(DistributedQueryRunner distributedQueryRunner, Session session, @Language("SQL") String sql, OptionalLong count, Optional> planAssertion) + { + long start = System.nanoTime(); + Plan queryPlan = null; + MaterializedResultWithQueryId resultWithQueryId = distributedQueryRunner.executeWithQueryId(session, sql); + QueryId queryId = resultWithQueryId.getQueryId(); + MaterializedResult results = resultWithQueryId.getResult().toTestTypes(); + if 
(planAssertion.isPresent()) { + try { + queryPlan = distributedQueryRunner.getQueryPlan(queryId); + } + catch (RuntimeException e) { + fail("Failed to get query plan for query " + queryId, e); + } + } + + Duration queryTime = nanosSince(start); + if (queryTime.compareTo(Duration.succinctDuration(1, SECONDS)) > 0) { + log.info("FINISHED query %s in Trino: %s", queryId, queryTime); + } + + if (planAssertion.isPresent()) { + try { + planAssertion.get().accept(queryPlan); + } + catch (Exception e) { + fail("Plan assertion failed for query " + queryId, e); + } + } + + if (results.getUpdateType().isEmpty()) { + fail("update type is not set for query " + queryId); + } + + if (results.getUpdateCount().isPresent()) { + if (count.isEmpty()) { + fail("expected no update count, but got " + results.getUpdateCount().getAsLong() + " for query " + queryId); + } + assertEquals(results.getUpdateCount().getAsLong(), count.getAsLong(), "update count for query " + queryId); + } + else if (count.isPresent()) { + fail("update count is not present for query " + queryId); + } + } + public static void assertQuery( QueryRunner actualQueryRunner, Session session, @@ -129,6 +180,11 @@ private static void assertQuery( boolean compareUpdate, Optional> planAssertion) { + if (actualQueryRunner instanceof DistributedQueryRunner distributedQueryRunner) { + assertDistributedQuery(distributedQueryRunner, session, actual, h2QueryRunner, expected, ensureOrdering, compareUpdate, planAssertion); + return; + } + long start = System.nanoTime(); MaterializedResult actualResults = null; Plan queryPlan = null; @@ -204,6 +260,87 @@ private static void assertQuery( } } + private static void assertDistributedQuery( + DistributedQueryRunner distributedQueryRunner, + Session session, + @Language("SQL") String actual, + H2QueryRunner h2QueryRunner, + @Language("SQL") String expected, + boolean ensureOrdering, + boolean compareUpdate, + Optional> planAssertion) + { + long start = System.nanoTime(); + QueryId queryId 
= null; + MaterializedResult actualResults = null; + try { + MaterializedResultWithQueryId resultWithQueryId = distributedQueryRunner.executeWithQueryId(session, actual); + queryId = resultWithQueryId.getQueryId(); + actualResults = resultWithQueryId.getResult().toTestTypes(); + } + catch (RuntimeException ex) { + fail("Execution of 'actual' query failed: " + actual, ex); + } + if (planAssertion.isPresent()) { + try { + planAssertion.get().accept(distributedQueryRunner.getQueryPlan(queryId)); + } + catch (Throwable t) { + t.addSuppressed(new Exception(format("SQL: %s [QueryId: %s]", actual, queryId))); + throw t; + } + } + Duration actualTime = nanosSince(start); + + long expectedStart = System.nanoTime(); + MaterializedResult expectedResults = null; + try { + expectedResults = h2QueryRunner.execute(session, expected, actualResults.getTypes()); + } + catch (RuntimeException ex) { + fail("Execution of 'expected' query failed: " + expected, ex); + } + Duration totalTime = nanosSince(start); + if (totalTime.compareTo(Duration.succinctDuration(1, SECONDS)) > 0) { + log.info("FINISHED in Trino: %s, H2: %s, total: %s", actualTime, nanosSince(expectedStart), totalTime); + } + + if (actualResults.getUpdateType().isPresent() || actualResults.getUpdateCount().isPresent()) { + if (actualResults.getUpdateType().isEmpty()) { + fail("update count present without update type for query " + queryId + ": \n" + actual); + } + if (!compareUpdate) { + fail("update type should not be present (use assertUpdate) for query " + queryId + ": \n" + actual); + } + } + + List actualRows = actualResults.getMaterializedRows(); + List expectedRows = expectedResults.getMaterializedRows(); + + if (compareUpdate) { + if (actualResults.getUpdateType().isEmpty()) { + fail("update type not present for query " + queryId + ": \n" + actual); + } + if (actualResults.getUpdateCount().isEmpty()) { + fail("update count not present for query " + queryId + ": \n" + actual); + } + assertEquals(actualRows.size(), 
1, "For query " + queryId + ": \n " + actual + "\n:"); + assertEquals(expectedRows.size(), 1, "For query " + queryId + ": \n " + actual + "\n:"); + MaterializedRow row = expectedRows.get(0); + assertEquals(row.getFieldCount(), 1, "For query " + queryId + ": \n " + actual + "\n:"); + assertEquals(row.getField(0), actualResults.getUpdateCount().getAsLong(), "For query " + queryId + ": \n " + actual + "\n:"); + } + + if (ensureOrdering) { + if (!actualRows.equals(expectedRows)) { + assertEquals(actualRows, expectedRows, "For query " + queryId + ": \n " + actual + "\n:"); + } + } + else { + assertEqualsIgnoreOrder(actualRows, expectedRows, "For query " + queryId + ": \n " + actual); + } + } + public static void assertQueryEventually( QueryRunner actualQueryRunner, Session session, @@ -274,6 +411,9 @@ protected static void assertQuerySucceeds(QueryRunner queryRunner, Session sessi try { queryRunner.execute(session, sql); } + catch (QueryFailedException e) { + fail(format("Expected query %s to succeed: %s", e.getQueryId(), sql), e); + } catch (RuntimeException e) { fail(format("Expected query to succeed: %s", sql), e); } @@ -287,8 +427,14 @@ protected static void assertQueryFailsEventually(QueryRunner queryRunner, Sessio protected static void assertQueryFails(QueryRunner queryRunner, Session session, @Language("SQL") String sql, @Language("RegExp") String expectedMessageRegExp) { try { - queryRunner.execute(session, sql); - fail(format("Expected query to fail: %s", sql)); + if (queryRunner instanceof DistributedQueryRunner distributedQueryRunner) { + MaterializedResultWithQueryId resultWithQueryId = distributedQueryRunner.executeWithQueryId(session, sql); + fail(format("Expected query to fail: %s [QueryId: %s]", sql, resultWithQueryId.getQueryId())); + } + else { + queryRunner.execute(session, sql); + fail(format("Expected query to fail: %s", sql)); + } } catch (RuntimeException exception) { exception.addSuppressed(new Exception("Query: " + sql)); @@ -300,13 +446,27 @@ 
protected static void assertQueryFails(QueryRunner queryRunner, Session session, protected static void assertQueryReturnsEmptyResult(QueryRunner queryRunner, Session session, @Language("SQL") String sql) { + QueryId queryId = null; try { - MaterializedResult results = queryRunner.execute(session, sql).toTestTypes(); + MaterializedResult results; + if (queryRunner instanceof DistributedQueryRunner distributedQueryRunner) { + MaterializedResultWithQueryId resultWithQueryId = distributedQueryRunner.executeWithQueryId(session, sql); + queryId = resultWithQueryId.getQueryId(); + results = resultWithQueryId.getResult().toTestTypes(); + } + else { + results = queryRunner.execute(session, sql).toTestTypes(); + } assertNotNull(results); assertEquals(results.getRowCount(), 0); } catch (RuntimeException ex) { - fail("Execution of query failed: " + sql, ex); + if (queryId == null) { + fail("Execution of query failed: " + sql, ex); + } + else { + fail(format("Execution of query failed: %s [QueryId: %s]", sql, queryId), ex); + } } }