From b09f551f193f46e6c70585ec25d97b4f37dd2872 Mon Sep 17 00:00:00 2001 From: rkondziolka Date: Wed, 7 Sep 2022 12:36:05 +0200 Subject: [PATCH] Decrease lock contention in PipelinedStageExecution Taking a monitor of io.trino.execution.scheduler.PipelinedStageExecution in the updateTaskStatus method causes high lock contention. Make this method lock-less. --- .../scheduler/PipelinedStageExecution.java | 201 ++++++++++++------ 1 file changed, 136 insertions(+), 65 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java index 703d904880eb..099c0ed9e3f0 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java @@ -55,7 +55,9 @@ import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Consumer; import java.util.stream.Stream; @@ -116,12 +118,15 @@ public class PipelinedStageExecution private final Map tasks = new ConcurrentHashMap<>(); // current stage task tracking - @GuardedBy("this") + private final ReentrantReadWriteLock allTasksLock = new ReentrantReadWriteLock(); + @GuardedBy("allTasksLock") private final Set allTasks = new HashSet<>(); - @GuardedBy("this") - private final Set finishedTasks = new HashSet<>(); - @GuardedBy("this") - private final Set flushingTasks = new HashSet<>(); + + private final AtomicInteger finishedTasksCounter = new AtomicInteger(); + private final AtomicInteger flushingOrFinishedTasksCounter = new AtomicInteger(); + private volatile boolean flushingTaskWasObserved; + private final Set finishedTasks = 
ConcurrentHashMap.newKeySet(); + private final Set flushingOrFinishedTasks = ConcurrentHashMap.newKeySet(); // source task tracking @GuardedBy("this") @@ -225,11 +230,17 @@ public synchronized void schedulingComplete() return; } - if (isFlushing()) { - stateMachine.transitionToFlushing(); + try { + allTasksLock.readLock().lock(); + if (isFlushing(flushingOrFinishedTasksCounter.get())) { + stateMachine.transitionToFlushing(); + } + if (isFinished(finishedTasksCounter.get())) { + stateMachine.transitionToFinished(); + } } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); + finally { + allTasksLock.readLock().unlock(); } for (PlanNodeId partitionedSource : stage.getFragment().getPartitionedSources()) { @@ -237,11 +248,19 @@ public synchronized void schedulingComplete() } } - private synchronized boolean isFlushing() + @GuardedBy("allTasksLock") + private boolean isFlushing(long flushingOrFinishedTasksCounter) { // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. 
- return !flushingTasks.isEmpty() - && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); + // allTasks is protected by allTasksLock so that number of total tasks was not changed + return flushingOrFinishedTasksCounter == allTasks.size() && flushingTaskWasObserved; + } + + @GuardedBy("allTasksLock") + private boolean isFinished(long finishedTasksCounter) + { + // allTasks is protected by allTasksLock so that number of total tasks was not changed + return finishedTasksCounter == allTasks.size(); } @Override @@ -287,69 +306,74 @@ public synchronized Optional scheduleTask( int partition, Multimap initialSplits) { - if (stateMachine.getState().isDone()) { - return Optional.empty(); - } + try { + allTasksLock.writeLock().lock(); - checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition); + if (stateMachine.getState().isDone()) { + return Optional.empty(); + } - OutputBuffers outputBuffers = outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers(); + checkArgument(!tasks.containsKey(partition), "A task for partition %s already exists", partition); - Optional optionalTask = stage.createTask( - node, - partition, - attempt, - bucketToPartition, - outputBuffers, - initialSplits, - ImmutableSet.of(), - Optional.empty()); + OutputBuffers outputBuffers = outputBufferManagers.get(stage.getFragment().getId()).getOutputBuffers(); - if (optionalTask.isEmpty()) { - return Optional.empty(); - } + Optional optionalTask = stage.createTask( + node, + partition, + attempt, + bucketToPartition, + outputBuffers, + initialSplits, + ImmutableSet.of(), + Optional.empty()); - RemoteTask task = optionalTask.get(); + if (optionalTask.isEmpty()) { + return Optional.empty(); + } - tasks.put(partition, task); + RemoteTask task = optionalTask.get(); + checkArgument(task.getTaskStatus().getState() != TaskState.FINISHED && task.getTaskStatus().getState() != TaskState.FLUSHING); - 
ImmutableMultimap.Builder exchangeSplits = ImmutableMultimap.builder(); - sourceTasks.forEach((fragmentId, sourceTask) -> { - TaskStatus status = sourceTask.getTaskStatus(); - if (status.getState() != TaskState.FINISHED) { - PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId(); - exchangeSplits.put(planNodeId, createExchangeSplit(sourceTask, task)); - } - }); + tasks.put(partition, task); - allTasks.add(task.getTaskId()); + ImmutableMultimap.Builder exchangeSplits = ImmutableMultimap.builder(); + sourceTasks.forEach((fragmentId, sourceTask) -> { + TaskStatus status = sourceTask.getTaskStatus(); + if (status.getState() != TaskState.FINISHED) { + PlanNodeId planNodeId = exchangeSources.get(fragmentId).getId(); + exchangeSplits.put(planNodeId, createExchangeSplit(sourceTask, task)); + } + }); - task.addSplits(exchangeSplits.build()); - completeSources.forEach(task::noMoreSplits); + allTasks.add(task.getTaskId()); - task.addStateChangeListener(this::updateTaskStatus); + task.addSplits(exchangeSplits.build()); + completeSources.forEach(task::noMoreSplits); - task.start(); + task.addStateChangeListener(this::updateTaskStatus); - taskLifecycleListener.taskCreated(stage.getFragment().getId(), task); + task.start(); - // update output buffers - OutputBufferId outputBufferId = new OutputBufferId(task.getTaskId().getPartitionId()); - updateSourceTasksOutputBuffers(outputBufferManager -> outputBufferManager.addOutputBuffer(outputBufferId)); + taskLifecycleListener.taskCreated(stage.getFragment().getId(), task); - return Optional.of(task); + // update output buffers + OutputBufferId outputBufferId = new OutputBufferId(task.getTaskId().getPartitionId()); + updateSourceTasksOutputBuffers(outputBufferManager -> outputBufferManager.addOutputBuffer(outputBufferId)); + + return Optional.of(task); + } + finally { + allTasksLock.writeLock().unlock(); + } } - private synchronized void updateTaskStatus(TaskStatus taskStatus) + private synchronized void 
updateNotSuccessfulTaskStatus(TaskStatus taskStatus) { State stageState = stateMachine.getState(); if (stageState.isDone()) { return; } - - TaskState taskState = taskStatus.getState(); - - switch (taskState) { + switch (taskStatus.getState()) { case FAILED: RuntimeException failure = taskStatus.getFailures().stream() .findFirst() @@ -366,25 +390,72 @@ private synchronized void updateTaskStatus(TaskStatus taskStatus) // A task should only be in the aborted state if the STAGE is done (ABORTED or FAILED) fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState)); break; - case FLUSHING: - flushingTasks.add(taskStatus.getTaskId()); - break; - case FINISHED: - finishedTasks.add(taskStatus.getTaskId()); - flushingTasks.remove(taskStatus.getTaskId()); - break; default: + break; + } + } + + private void updateTaskStatus(TaskStatus taskStatus) + { + State stageState = stateMachine.getState(); + TaskState taskState = taskStatus.getState(); + if (stageState.isDone()) { + return; + } + + if (taskState == TaskState.FAILED || taskState == TaskState.CANCELED || taskState == TaskState.ABORTED) { + updateNotSuccessfulTaskStatus(taskStatus); + return; + } + + long flushingOrFinishedTasksCounter = 0; + long finishedTasksCounter = 0; + + /* + * We are counting the unique number of flushing/finished tasks using the atomic counter flushingOrFinishedTasksCounter. + * When some thread detects that number of finished or flushing tasks is equal + * to allTasks.size() then it decides to move stateMachine's state to FLUSHING/FINISHED. To distinguish between + * FLUSHING or FINISHED we use `flushingTaskWasObserved` to mark that there was at least one flushing task. 
+ */ + if (taskState == TaskState.FLUSHING) { + boolean addedNow = flushingOrFinishedTasks.add(taskStatus.getTaskId()); + if (addedNow) { + flushingTaskWasObserved = true; + flushingOrFinishedTasksCounter = this.flushingOrFinishedTasksCounter.incrementAndGet(); + } + } + + /* + * We are counting the unique number of finished tasks using the atomic counter finishedTasksCounter. + * When some thread detects that number of finished tasks is equal to allTasks.size() then it decides + * to move stateMachine's state to FINISHED. + */ + else if (taskState == TaskState.FINISHED) { + boolean addedNow = flushingOrFinishedTasks.add(taskStatus.getTaskId()); + if (addedNow) { + flushingOrFinishedTasksCounter = this.flushingOrFinishedTasksCounter.incrementAndGet(); + } + addedNow = this.finishedTasks.add(taskStatus.getTaskId()); + if (addedNow) { + finishedTasksCounter = this.finishedTasksCounter.incrementAndGet(); + } } if (stageState == SCHEDULED || stageState == RUNNING || stageState == FLUSHING) { if (taskState == TaskState.RUNNING) { stateMachine.transitionToRunning(); } - if (isFlushing()) { - stateMachine.transitionToFlushing(); + try { + allTasksLock.readLock().lock(); + if (isFlushing(flushingOrFinishedTasksCounter)) { + stateMachine.transitionToFlushing(); + } + if (isFinished(finishedTasksCounter)) { + stateMachine.transitionToFinished(); + } } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); + finally { + allTasksLock.readLock().unlock(); } } }