Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Consumer;
import java.util.stream.Stream;

Expand Down Expand Up @@ -116,12 +118,15 @@ public class PipelinedStageExecution
private final Map<Integer, RemoteTask> tasks = new ConcurrentHashMap<>();

// current stage task tracking
@GuardedBy("this")
private final ReentrantReadWriteLock allTasksLock = new ReentrantReadWriteLock();
@GuardedBy("allTasksLock")
private final Set<TaskId> allTasks = new HashSet<>();
@GuardedBy("this")
private final Set<TaskId> finishedTasks = new HashSet<>();
@GuardedBy("this")
private final Set<TaskId> flushingTasks = new HashSet<>();

private final AtomicInteger finishedTasksCounter = new AtomicInteger();
private final AtomicInteger flushingOrFinishedTasksCounter = new AtomicInteger();
private volatile boolean flushingTaskWasObserved;
private final Set<TaskId> finishedTasks = ConcurrentHashMap.newKeySet();
Comment thread
radek-kondziolka marked this conversation as resolved.
private final Set<TaskId> flushingOrFinishedTasks = ConcurrentHashMap.newKeySet();

// source task tracking
@GuardedBy("this")
Expand Down Expand Up @@ -225,23 +230,37 @@ public synchronized void schedulingComplete()
return;
}

if (isFlushing()) {
stateMachine.transitionToFlushing();
try {
allTasksLock.readLock().lock();
if (isFlushing(flushingOrFinishedTasksCounter.get())) {
stateMachine.transitionToFlushing();
}
if (isFinished(finishedTasksCounter.get())) {
stateMachine.transitionToFinished();
}
}
if (finishedTasks.containsAll(allTasks)) {
stateMachine.transitionToFinished();
finally {
allTasksLock.readLock().unlock();
}

for (PlanNodeId partitionedSource : stage.getFragment().getPartitionedSources()) {
schedulingComplete(partitionedSource);
}
}

private synchronized boolean isFlushing()
@GuardedBy("allTasksLock")
private boolean isFlushing(long flushingOrFinishedTasksCounter)
{
// to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished.
return !flushingTasks.isEmpty()
&& allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId));
// allTasks is protected by allTasksLock so that number of total tasks was not changed
return flushingOrFinishedTasksCounter == allTasks.size() && flushingTaskWasObserved;
}

@GuardedBy("allTasksLock")
private boolean isFinished(long finishedTasksCounter)
{
// allTasks is protected by allTasksLock so that number of total tasks was not changed
return finishedTasksCounter == allTasks.size();
}

@Override
Expand Down Expand Up @@ -287,69 +306,74 @@ public synchronized Optional<RemoteTask> scheduleTask(
int partition,
Multimap<PlanNodeId, Split> initialSplits)
{
if (stateMachine.getState().isDone()) {
return Optional.empty();
}
try {
allTasksLock.writeLock().lock();
Comment thread
radek-kondziolka marked this conversation as resolved.

checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition);
if (stateMachine.getState().isDone()) {
return Optional.empty();
}

OutputBuffers outputBuffers = outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers();
checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition);

Optional<RemoteTask> optionalTask = stage.createTask(
node,
partition,
attempt,
bucketToPartition,
outputBuffers,
initialSplits,
ImmutableSet.of(),
Optional.empty());
OutputBuffers outputBuffers = outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers();

if (optionalTask.isEmpty()) {
return Optional.empty();
}
Optional<RemoteTask> optionalTask = stage.createTask(
node,
partition,
attempt,
bucketToPartition,
outputBuffers,
initialSplits,
ImmutableSet.of(),
Optional.empty());

RemoteTask task = optionalTask.get();
if (optionalTask.isEmpty()) {
return Optional.empty();
}

tasks.put(partition, task);
RemoteTask task = optionalTask.get();
checkArgument(task.getTaskStatus().getState() != TaskState.FINISHED && task.getTaskStatus().getState() != TaskState.FLUSHING);

ImmutableMultimap.Builder<PlanNodeId, Split> exchangeSplits = ImmutableMultimap.builder();
sourceTasks.forEach((fragmentId, sourceTask) -> {
TaskStatus status = sourceTask.getTaskStatus();
if (status.getState() != TaskState.FINISHED) {
PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId();
exchangeSplits.put(planNodeId, createExchangeSplit(sourceTask, task));
}
});
tasks.put(partition, task);

allTasks.add(task.getTaskId());
ImmutableMultimap.Builder<PlanNodeId, Split> exchangeSplits = ImmutableMultimap.builder();
sourceTasks.forEach((fragmentId, sourceTask) -> {
TaskStatus status = sourceTask.getTaskStatus();
if (status.getState() != TaskState.FINISHED) {
PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId();
exchangeSplits.put(planNodeId, createExchangeSplit(sourceTask, task));
}
});

task.addSplits(exchangeSplits.build());
completeSources.forEach(task::noMoreSplits);
allTasks.add(task.getTaskId());

task.addStateChangeListener(this::updateTaskStatus);
task.addSplits(exchangeSplits.build());
completeSources.forEach(task::noMoreSplits);

task.start();
task.addStateChangeListener(this::updateTaskStatus);

taskLifecycleListener.taskCreated(stage.getFragment().getId(), task);
task.start();

// update output buffers
OutputBufferId outputBufferId = new OutputBufferId(task.getTaskId().getPartitionId());
updateSourceTasksOutputBuffers(outputBufferManager -> outputBufferManager.addOutputBuffer(outputBufferId));
taskLifecycleListener.taskCreated(stage.getFragment().getId(), task);

return Optional.of(task);
// update output buffers
OutputBufferId outputBufferId = new OutputBufferId(task.getTaskId().getPartitionId());
updateSourceTasksOutputBuffers(outputBufferManager -> outputBufferManager.addOutputBuffer(outputBufferId));

return Optional.of(task);
}
finally {
allTasksLock.writeLock().unlock();
}
}

private synchronized void updateTaskStatus(TaskStatus taskStatus)
private synchronized void updateNotSuccessfulTaskStatus(TaskStatus taskStatus)
{
State stageState = stateMachine.getState();
if (stageState.isDone()) {
return;
}

TaskState taskState = taskStatus.getState();

switch (taskState) {
switch (taskStatus.getState()) {
case FAILED:
RuntimeException failure = taskStatus.getFailures().stream()
.findFirst()
Expand All @@ -366,25 +390,72 @@ private synchronized void updateTaskStatus(TaskStatus taskStatus)
// A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
break;
case FLUSHING:
flushingTasks.add(taskStatus.getTaskId());
break;
case FINISHED:
finishedTasks.add(taskStatus.getTaskId());
Comment thread
radek-kondziolka marked this conversation as resolved.
flushingTasks.remove(taskStatus.getTaskId());
break;
default:
break;
Comment thread
radek-kondziolka marked this conversation as resolved.
}
}

private void updateTaskStatus(TaskStatus taskStatus)
{
State stageState = stateMachine.getState();
Copy link
Copy Markdown
Member

@sopel39 sopel39 Sep 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this can be simplified without introducing too much complex multi-threaded code. A lot of the code in updateTaskStatus doesn't need to be synchronized:

private void updateTaskStatus(TaskStatus taskStatus) {
        State stageState = stateMachine.getState();
        if (stageState.isDone()) {
            return;
        }

        TaskState taskState = taskStatus.getState();
        switch (taskState) {
            case FAILED:
                RuntimeException failure = taskStatus.getFailures().stream()
                        .findFirst()
                        .map(this::rewriteTransportFailure)
                        .map(ExecutionFailureInfo::toException)
                        .orElse(new TrinoException(GENERIC_INTERNAL_ERROR, "A task failed for an unknown reason"));
                fail(failure);
                break;
            case CANCELED:
                // A task should only be in the canceled state if the STAGE is cancelled
                fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the CANCELED state but stage is " + stageState));
                break;
            case ABORTED:
                // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED)
                fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState));
                break;
            case FLUSHING:
                addFlushingTask(taskStatus.getTaskId());
                break;
            case FINISHED:
                addFinishedTask(taskStatus.getTaskId());
                break;
            default:
        }

        if (stageState == SCHEDULED || stageState == RUNNING || stageState == FLUSHING) {
            if (taskState == TaskState.RUNNING) {
                stateMachine.transitionToRunning();
            }
            if (isFlushing()) {
                stateMachine.transitionToFlushing();
            }
            if (isAllTaskFinished()) {
                stateMachine.transitionToFinished();
            }
        }
}

private synchronized void addFlushingTask(TaskId taskId) {
        flushingTasks.add(taskStatus.getTaskId());
}

private synchronized void addFinishedTask(TaskId taskId) {
        finishedTasks.add(taskStatus.getTaskId());
        flushingTasks.remove(taskStatus.getTaskId());
}

private synchronized boolean isAllTaskFinished() {
        return finishedTasks.containsAll(allTasks);
}

Then in subsequent commit I would probably make flushingTasks, finishedTasks lock free, e.g: use ConcurrentHashMap.newKeySet(); (still @Guarded(this) if method touches both finishedTasks and flushingTasks at same time) and:

private void addFlushingTask(TaskId taskId) {
        flushingTasks.add(taskStatus.getTaskId());
}

private void addFinishedTask(TaskId taskId) {
        if (!finishedTasks.contains(taskId)) {
          synchronized(this) {
              // atomically move task to finished set.
              // nit: MAYBE it's not needed
              finishedTasks.add(taskStatus.getTaskId());
              flushingTasks.remove(taskStatus.getTaskId());
          }
        }
}

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is easier but it does not resolve the source problem. You are still locking this in every call of updateTaskStatus. I tried that approach and there was still a high contention on this's monitor.

Copy link
Copy Markdown
Member

@sopel39 sopel39 Sep 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can make addFlushingTask and addFinishedTask return boolean (true if element was added), e.g:

        boolean taskStateChanged = addFlushingTask(taskStatus.getTaskId());
        ...
        if (!stateChanged) {
            return;
        }
        if (stageState == SCHEDULED || stageState == RUNNING || stageState == FLUSHING) {
            if (taskState == TaskState.RUNNING) {
                stateMachine.transitionToRunning();
            }
            if (isFlushing()) {
                stateMachine.transitionToFlushing();
            }
            if (isAllTaskFinished()) {
                stateMachine.transitionToFinished();
            }
        }

Alternatively we could probably make isFlushing() and isAllTaskFinished() lock-free somehow

Copy link
Copy Markdown
Contributor Author

@radek-kondziolka radek-kondziolka Sep 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented the version where the updateTaskStatus was not called when taskState was not changed.
It helped a bit, but the lock contention was still too high to be accepted. (like totally 1 day).

Alternatively we could probably make isFlushing() and isAllTaskFinished() lock-free somehow

This what I did in that PR.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented the version where the updateTaskStatus was not called when taskState was not changed.
It helped a bit, but the lock contention was still too high to be accepted. (like totally 1 day).

It could be because the whole updateTaskStatus was synchronized. That seems like a waste. There is no reason why transotionToXX should be synchronized and they do some non-trivial stuff like firing executor task

Copy link
Copy Markdown
Contributor Author

@radek-kondziolka radek-kondziolka Sep 22, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it could be and could not be. It is hard to say. I did not check the version:
(1) decrease the number of calls updateTaskStatus & (2) change the scope of synchronized section
when I've checked that (1) does not help and (2) does not help (separately) I decided to make this method completly lock-free (lock-less)

TaskState taskState = taskStatus.getState();
if (stageState.isDone()) {
return;
}

if (taskState == TaskState.FAILED || taskState == TaskState.CANCELED || taskState == TaskState.ABORTED) {
updateNotSuccessfulTaskStatus(taskStatus);
return;
}

long flushingOrFinishedTasksCounter = 0;
long finishedTasksCounter = 0;

/*
* We are counting the unique number of flushing/finished tasks using the atomic counter flushingOrFinishedTasksCounter.
* When some thread detects that number of finished or flushing tasks is equal
* to allTasks.size() then it decides to move stateMachine's state to FLUSHING/FINISHED. To distinguish between
* FLUSHING or FINISHED we use `flushingTaskWasObserved` to mark that there was at least one flushing task.
*/
if (taskState == TaskState.FLUSHING) {
boolean addedNow = flushingOrFinishedTasks.add(taskStatus.getTaskId());
if (addedNow) {
flushingTaskWasObserved = true;
flushingOrFinishedTasksCounter = this.flushingOrFinishedTasksCounter.incrementAndGet();
}
}

/*
* We are counting the unique number of finished tasks using the atomic counter finishedTasksCounter.
* When some thread detects that number of finished tasks is equal to allTasks.size() then it decides
* to move stateMachine's state to FINISHED.
*/
else if (taskState == TaskState.FINISHED) {
boolean addedNow = flushingOrFinishedTasks.add(taskStatus.getTaskId());
Comment thread
radek-kondziolka marked this conversation as resolved.
if (addedNow) {
Comment thread
radek-kondziolka marked this conversation as resolved.
Outdated
flushingOrFinishedTasksCounter = this.flushingOrFinishedTasksCounter.incrementAndGet();
Comment thread
radek-kondziolka marked this conversation as resolved.
}
addedNow = this.finishedTasks.add(taskStatus.getTaskId());
if (addedNow) {
finishedTasksCounter = this.finishedTasksCounter.incrementAndGet();
}
}

if (stageState == SCHEDULED || stageState == RUNNING || stageState == FLUSHING) {
if (taskState == TaskState.RUNNING) {
stateMachine.transitionToRunning();
}
if (isFlushing()) {
stateMachine.transitionToFlushing();
try {
allTasksLock.readLock().lock();
if (isFlushing(flushingOrFinishedTasksCounter)) {
stateMachine.transitionToFlushing();
}
if (isFinished(finishedTasksCounter)) {
stateMachine.transitionToFinished();
}
}
if (finishedTasks.containsAll(allTasks)) {
stateMachine.transitionToFinished();
finally {
allTasksLock.readLock().unlock();
}
}
}
Expand Down