Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 58 additions & 31 deletions core/trino-main/src/main/java/io/trino/execution/SqlTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.joda.time.DateTime;

import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy;

import java.net.URI;
import java.util.List;
Expand Down Expand Up @@ -90,7 +91,9 @@ public class SqlTask
private final AtomicReference<DateTime> lastHeartbeat = new AtomicReference<>(DateTime.now());
private final AtomicLong taskStatusVersion = new AtomicLong(TaskStatus.STARTING_VERSION);
private final FutureStateChange<?> taskStatusVersionChange = new FutureStateChange<>();

// Must be synchronized when updating the current task holder reference, but not when only reading the current reference value
private final Object taskHolderLock = new Object();
@GuardedBy("taskHolderLock")
private final AtomicReference<TaskHolder> taskHolderReference = new AtomicReference<>(new TaskHolder());
private final AtomicBoolean needsPlan = new AtomicBoolean(true);
private final AtomicReference<String> traceToken = new AtomicReference<>();
Expand Down Expand Up @@ -167,19 +170,18 @@ private void initialize(Consumer<SqlTask> onDone, CounterStat failedTasks)
}

// store final task info
while (true) {
synchronized (taskHolderLock) {
TaskHolder taskHolder = taskHolderReference.get();
if (taskHolder.isFinished()) {
// another concurrent worker already set the final state
return;
}

if (taskHolderReference.compareAndSet(taskHolder, new TaskHolder(
TaskHolder newHolder = new TaskHolder(
createTaskInfo(taskHolder),
taskHolder.getIoStats(),
taskHolder.getDynamicFilterDomains()))) {
break;
}
taskHolder.getDynamicFilterDomains());
checkState(taskHolderReference.compareAndSet(taskHolder, newHolder), "unsynchronized concurrent task holder update");
}

// make sure buffers are cleaned up
Expand Down Expand Up @@ -433,44 +435,69 @@ public TaskInfo updateTask(
// a VALUES query).
outputBuffer.setOutputBuffers(outputBuffers);

// assure the task execution is only created once
SqlTaskExecution taskExecution;
synchronized (this) {
// is task already complete?
TaskHolder taskHolder = taskHolderReference.get();
if (taskHolder.isFinished()) {
return taskHolder.getFinalTaskInfo();
}
taskExecution = taskHolder.getTaskExecution();
if (taskExecution == null) {
checkState(fragment.isPresent(), "fragment must be present");
taskExecution = sqlTaskExecutionFactory.create(
session,
queryContext,
taskStateMachine,
outputBuffer,
fragment.get(),
this::notifyStatusChanged);
taskHolderReference.compareAndSet(taskHolder, new TaskHolder(taskExecution));
needsPlan.set(false);
taskExecution.start();
}
// is task already complete?
TaskHolder taskHolder = taskHolderReference.get();
if (taskHolder.isFinished()) {
return taskHolder.getFinalTaskInfo();
}

taskExecution.addSplitAssignments(splitAssignments);
taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains);
SqlTaskExecution taskExecution = taskHolder.getTaskExecution();
if (taskExecution == null) {
checkState(fragment.isPresent(), "fragment must be present");
taskExecution = tryCreateSqlTaskExecution(session, fragment.get());
}
// taskExecution can still be null if the creation was skipped
if (taskExecution != null) {
taskExecution.addSplitAssignments(splitAssignments);
taskExecution.getTaskContext().addDynamicFilter(dynamicFilterDomains);
}
}
catch (Error e) {
failed(e);
throw e;
}
catch (RuntimeException e) {
failed(e);
return failed(e);
}

return getTaskInfo();
}

/**
 * Returns the task's {@link SqlTaskExecution}, creating and starting it if none exists yet.
 * All work happens while holding {@code taskHolderLock}, so at most one caller creates the execution.
 *
 * @param session session passed to the execution factory
 * @param fragment plan fragment passed to the execution factory
 * @return the existing or newly created execution, or {@code null} when creation was skipped
 * because the current task holder is already finished or the task state machine is already done
 */
@Nullable
private SqlTaskExecution tryCreateSqlTaskExecution(Session session, PlanFragment fragment)
{
synchronized (taskHolderLock) {
// Recheck holder for task execution after acquiring the lock
TaskHolder taskHolder = taskHolderReference.get();
if (taskHolder.isFinished()) {
return null;
}
// Another caller may have created the execution before this one acquired the lock; reuse it
SqlTaskExecution execution = taskHolder.getTaskExecution();
if (execution != null) {
return execution;
}

// Don't create a new execution if the task is already done
if (taskStateMachine.getState().isDone()) {
return null;
}

execution = sqlTaskExecutionFactory.create(
session,
queryContext,
taskStateMachine,
outputBuffer,
fragment,
this::notifyStatusChanged);
// The flag is cleared before the execution is started and before it is published in the holder
needsPlan.set(false);
execution.start();
// this must happen after execution.start(), otherwise it could become visible to a
// concurrent update without being fully initialized.
// The compareAndSet is expected to always succeed here because holder updates are made under
// taskHolderLock; a failure is reported through checkState with the message below
checkState(taskHolderReference.compareAndSet(taskHolder, new TaskHolder(execution)), "unsynchronized concurrent task holder update");
return execution;
}
}

public ListenableFuture<BufferResult> getTaskResults(PipelinedOutputBuffers.OutputBufferId bufferId, long startingSequenceId, DataSize maxSize)
{
requireNonNull(bufferId, "bufferId is null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -163,17 +163,22 @@ public SqlTaskExecution(
else {
taskHandle = null;
}

outputBuffer.addStateChangeListener(new CheckTaskCompletionOnBufferFinish(SqlTaskExecution.this));
}
}

public void start()
{
try (SetThreadName ignored = new SetThreadName("Task-%s", getTaskId())) {
// Task handle was not created because the task is already done, nothing to do
if (taskHandle == null) {
return;
}
// The scheduleDriversForTaskLifeCycle method calls enqueueDriverSplitRunner, which registers a callback with access to this object.
// The call back is accessed from another thread, so this code cannot be placed in the constructor.
// The call back is accessed from another thread, so this code cannot be placed in the constructor. This must also happen before outputBuffer
// callbacks are registered to prevent a task completion check before task lifecycle splits are created
scheduleDriversForTaskLifeCycle();
// Output buffer state change listener callback must not run in the constructor to avoid leaking a reference to "this" across to another thread
outputBuffer.addStateChangeListener(new CheckTaskCompletionOnBufferFinish(SqlTaskExecution.this));
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: move to a separate commit

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this is all related code. The movement to here alone would cause issues with scheduleDriversForTaskLifeCycle() not being called when taskHandle == null. This is "fine" because it doesn't break the task state machine model today, but causes problems in a world with terminating task states.

}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,17 +401,6 @@ public ListenableFuture<TaskInfo> getTaskInfo(TaskId taskId, long currentVersion
return sqlTask.getTaskInfo(currentVersion);
}

/**
 * Returns the unique instance id of the given task. Callers can compare the
 * id across requests to detect a task that was destroyed and recreated.
 * The lookup also records a heartbeat on the task.
 *
 * @param taskId id of the task to look up
 * @return the instance id reported by the task
 */
public String getTaskInstanceId(TaskId taskId)
{
    SqlTask task = tasks.getUnchecked(taskId);
    task.recordHeartbeat();
    return task.getTaskInstanceId();
}

/**
* Gets future status for the task after the state changes from
* {@code current state}. If the task has not been created yet, an
Expand Down Expand Up @@ -508,14 +497,15 @@ private TaskInfo doUpdateTask(
* NOTE: this design assumes that only tasks and buffers that will
* eventually exist are queried.
*/
public ListenableFuture<BufferResult> getTaskResults(TaskId taskId, PipelinedOutputBuffers.OutputBufferId bufferId, long startingSequenceId, DataSize maxSize)
public SqlTaskWithResults getTaskResults(TaskId taskId, PipelinedOutputBuffers.OutputBufferId bufferId, long startingSequenceId, DataSize maxSize)
{
requireNonNull(taskId, "taskId is null");
requireNonNull(bufferId, "bufferId is null");
checkArgument(startingSequenceId >= 0, "startingSequenceId is negative");
requireNonNull(maxSize, "maxSize is null");

return tasks.getUnchecked(taskId).getTaskResults(bufferId, startingSequenceId, maxSize);
SqlTask task = tasks.getUnchecked(taskId);
return new SqlTaskWithResults(task, task.getTaskResults(bufferId, startingSequenceId, maxSize));
}

/**
Expand Down Expand Up @@ -778,4 +768,39 @@ private void failStuckSplitTasks()
}
}
}

/**
 * Pairs a {@link SqlTask} with the future for a results read issued against it,
 * so a caller holding the future can also query and touch the task that produced it.
 */
public static final class SqlTaskWithResults
{
    private final SqlTask task;
    private final ListenableFuture<BufferResult> resultsFuture;

    /**
     * @param task the task the results were requested from; must not be null
     * @param resultsFuture the pending results of that request; must not be null
     */
    public SqlTaskWithResults(SqlTask task, ListenableFuture<BufferResult> resultsFuture)
    {
        this.task = requireNonNull(task, "task is null");
        this.resultsFuture = requireNonNull(resultsFuture, "resultsFuture is null");
    }

    /**
     * Returns the results future supplied at construction.
     */
    public ListenableFuture<BufferResult> getResultsFuture()
    {
        return resultsFuture;
    }

    /**
     * Returns the instance id of the underlying task.
     */
    public String getTaskInstanceId()
    {
        return task.getTaskInstanceId();
    }

    /**
     * Returns {@code true} when the task's current state is ABORTED or FAILED,
     * and {@code false} for every other state.
     */
    public boolean isTaskFailed()
    {
        switch (task.getTaskState()) {
            case ABORTED:
            case FAILED:
                return true;
            default:
                return false;
        }
    }

    /**
     * Records a heartbeat on the underlying task.
     */
    public void recordHeartbeat()
    {
        task.recordHeartbeat();
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,10 @@ public ListenableFuture<BufferResult> get(OutputBufferId bufferId, long token, D
if (outputBuffer == null) {
synchronized (this) {
if (delegate == null) {
if (stateMachine.getState() == FINISHED) {
return immediateFuture(emptyResults(taskInstanceId, 0, true));
if (stateMachine.getState().isTerminal()) {
// only mark the result as complete when the task FINISHED; for any other terminal state return an empty result that is not complete
boolean complete = stateMachine.getState() == FINISHED;
return immediateFuture(emptyResults(taskInstanceId, 0, complete));
}

PendingRead pendingRead = new PendingRead(bufferId, token, maxSize);
Expand Down Expand Up @@ -310,19 +312,31 @@ public void destroy()
@Override
public void abort()
{
List<PendingRead> pendingReads = ImmutableList.of();
OutputBuffer outputBuffer = delegate;
if (outputBuffer == null) {
synchronized (this) {
if (delegate == null) {
// ignore abort if the buffer is already in a terminal state.
stateMachine.abort();
if (!stateMachine.abort()) {
return;
}

// Do not free readers on fail
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The assumption is that sending emptyResults(taskInstanceId, 0, completed: false) is safe, even though the previous comment suggested otherwise, because TaskResource will respond similarly when the timeout threshold is reached. Sending these responses instead of letting them time out should propagate task failure information more quickly than allowing the timeout to expire.

return;
pendingReads = ImmutableList.copyOf(this.pendingReads);
this.pendingReads.clear();
}
outputBuffer = delegate;
}
}

// if there is no output buffer, send an empty result without buffer completed signaled
if (outputBuffer == null) {
for (PendingRead pendingRead : pendingReads) {
pendingRead.getFutureResult().set(emptyResults(taskInstanceId, 0, false));
}
return;
}

outputBuffer.abort();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ public synchronized void schedulingComplete(PlanNodeId partitionedSource)
/**
 * Cancels this stage and forwards the cancellation to its tasks only when this
 * call actually moved the stage to the canceled state.
 *
 * The transition is attempted exactly once. Calling it unconditionally and then
 * cancelling every task regardless of the outcome would send cancel commands to
 * tasks of a stage that had already failed, and would make the guarded branch
 * below redundant.
 */
@Override
public synchronized void cancel()
{
    // Only send tasks a cancel command if the stage is successfully cancelled and not already failed
    if (stateMachine.transitionToCanceled()) {
        getAllTasks().forEach(RemoteTask::cancel);
    }
}

@Override
Expand Down
Loading