diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java index 36c0f10c7c7a..892d0dc28785 100644 --- a/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/PipelinedStageExecution.java @@ -118,10 +118,8 @@ public class PipelinedStageExecution // current stage task tracking @GuardedBy("this") private final Set allTasks = new HashSet<>(); - @GuardedBy("this") - private final Set finishedTasks = new HashSet<>(); - @GuardedBy("this") - private final Set flushingTasks = new HashSet<>(); + private final Set finishedTasks = ConcurrentHashMap.newKeySet(); + private final Set flushingTasks = ConcurrentHashMap.newKeySet(); // source task tracking @GuardedBy("this") @@ -219,16 +217,16 @@ public synchronized void transitionToSchedulingSplits() } @Override - public synchronized void schedulingComplete() + public void schedulingComplete() { if (!stateMachine.transitionToScheduled()) { return; } - if (isFlushing()) { + if (isStageFlushing()) { stateMachine.transitionToFlushing(); } - if (finishedTasks.containsAll(allTasks)) { + if (isStageFinished()) { stateMachine.transitionToFinished(); } @@ -237,13 +235,6 @@ public synchronized void schedulingComplete() } } - private synchronized boolean isFlushing() - { - // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. 
- return !flushingTasks.isEmpty() - && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); - } - @Override public synchronized void schedulingComplete(PlanNodeId partitionedSource) { @@ -340,13 +331,13 @@ public synchronized Optional scheduleTask( return Optional.of(task); } - private synchronized void updateTaskStatus(TaskStatus taskStatus) + private void updateTaskStatus(TaskStatus taskStatus) { State stageState = stateMachine.getState(); if (stageState.isDone()) { return; } - + boolean newFlushingOrFinishedTaskObserved = false; TaskState taskState = taskStatus.getState(); switch (taskState) { @@ -367,11 +358,10 @@ private synchronized void updateTaskStatus(TaskStatus taskStatus) fail(new TrinoException(GENERIC_INTERNAL_ERROR, "A task is in the ABORTED state but stage is " + stageState)); break; case FLUSHING: - flushingTasks.add(taskStatus.getTaskId()); + newFlushingOrFinishedTaskObserved = addFlushingTask(taskStatus.getTaskId()); break; case FINISHED: - finishedTasks.add(taskStatus.getTaskId()); - flushingTasks.remove(taskStatus.getTaskId()); + newFlushingOrFinishedTaskObserved = addFinishedTask(taskStatus.getTaskId()); break; default: } @@ -380,13 +370,54 @@ private synchronized void updateTaskStatus(TaskStatus taskStatus) if (taskState == TaskState.RUNNING) { stateMachine.transitionToRunning(); } - if (isFlushing()) { - stateMachine.transitionToFlushing(); + // avoid extra synchronization if no new flushing or finished task was observed + if (newFlushingOrFinishedTaskObserved) { + if (isStageFlushing()) { + stateMachine.transitionToFlushing(); + } + if (isStageFinished()) { + stateMachine.transitionToFinished(); + } } - if (finishedTasks.containsAll(allTasks)) { - stateMachine.transitionToFinished(); + } + } + + private synchronized boolean isStageFlushing() + { + // to transition to flushing, there must be at least one flushing task, and all others must be flushing or finished. 
+ return !flushingTasks.isEmpty() + && allTasks.stream().allMatch(taskId -> finishedTasks.contains(taskId) || flushingTasks.contains(taskId)); + } + + private synchronized boolean isStageFinished() + { + return finishedTasks.containsAll(allTasks); + } + + private boolean addFlushingTask(TaskId taskId) + { + if (!flushingTasks.contains(taskId) && !finishedTasks.contains(taskId)) { + synchronized (this) { + // Re-check under the lock that the task is not already finished. This can happen when task status events + // arrive out of order + if (!finishedTasks.contains(taskId)) { + return flushingTasks.add(taskId); + } + } + } + return false; + } + + private boolean addFinishedTask(TaskId taskId) + { + if (!finishedTasks.contains(taskId)) { + synchronized (this) { + boolean added = finishedTasks.add(taskId); + flushingTasks.remove(taskId); + return added; } } + return false; } private ExecutionFailureInfo rewriteTransportFailure(ExecutionFailureInfo executionFailureInfo)