diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java index 1c5176626063..64f3526c2a17 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTask.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTask.java @@ -182,7 +182,7 @@ private void initialize(Consumer onDone, CounterStat failedTasks) if (newState == FAILED || newState == ABORTED) { // don't close buffers for a failed query // closed buffers signal to upstream tasks that everything finished cleanly - outputBuffer.fail(); + outputBuffer.abort(); } else { outputBuffer.destroy(); @@ -488,12 +488,12 @@ public void acknowledgeTaskResults(OutputBufferId bufferId, long sequenceId) outputBuffer.acknowledge(bufferId, sequenceId); } - public TaskInfo abortTaskResults(OutputBufferId bufferId) + public TaskInfo destroyTaskResults(OutputBufferId bufferId) { requireNonNull(bufferId, "bufferId is null"); log.debug("Aborting task %s output %s", taskId, bufferId); - outputBuffer.abort(bufferId); + outputBuffer.destroy(bufferId); return getTaskInfo(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java index b852d98f6338..c636b20aa1fb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskExecution.java @@ -38,6 +38,7 @@ import io.trino.operator.StageExecutionDescriptor; import io.trino.operator.TaskContext; import io.trino.spi.SplitWeight; +import io.trino.spi.TrinoException; import io.trino.sql.planner.LocalExecutionPlanner.LocalExecutionPlan; import io.trino.sql.planner.plan.PlanNodeId; @@ -76,6 +77,7 @@ import static io.trino.execution.SqlTaskExecution.SplitsState.FINISHED; import static io.trino.execution.SqlTaskExecution.SplitsState.NO_MORE_SPLITS; import static io.trino.operator.PipelineExecutionStrategy.UNGROUPED_EXECUTION; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -639,14 +641,28 @@ private synchronized void checkTaskCompletion() // no more output will be created outputBuffer.setNoMorePages(); - // are there still pages in the output buffer - if (!outputBuffer.isFinished()) { + BufferState bufferState = outputBuffer.getState(); + if (!bufferState.isTerminal()) { taskStateMachine.transitionToFlushing(); return; } - // Cool! All done! - taskStateMachine.finished(); + if (bufferState == BufferState.FINISHED) { + // Cool! All done! + taskStateMachine.finished(); + return; + } + + if (bufferState == BufferState.FAILED) { + Throwable failureCause = outputBuffer.getFailureCause() + .orElseGet(() -> new TrinoException(GENERIC_INTERNAL_ERROR, "Output buffer is failed but the failure cause is missing")); + taskStateMachine.failed(failureCause); + return; + } + + // The only terminal state that remains is ABORTED. + // Buffer is expected to be aborted only if the task itself is aborted. In this scenario the following statement is expected to be noop. + taskStateMachine.failed(new TrinoException(GENERIC_INTERNAL_ERROR, "Unexpected buffer state: " + bufferState)); } @Override @@ -1111,7 +1127,7 @@ public CheckTaskCompletionOnBufferFinish(SqlTaskExecution sqlTaskExecution) @Override public void stateChanged(BufferState newState) { - if (newState == BufferState.FINISHED) { + if (newState.isTerminal()) { SqlTaskExecution sqlTaskExecution = sqlTaskExecutionReference.get(); if (sqlTaskExecution != null) { sqlTaskExecution.checkTaskCompletion(); diff --git a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java index 38352a26e149..685148e1ba8f 100644 --- a/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/SqlTaskManager.java @@ -448,12 +448,12 @@ public void acknowledgeTaskResults(TaskId taskId, OutputBufferId bufferId, long } @Override - public TaskInfo abortTaskResults(TaskId taskId, OutputBufferId bufferId) + public TaskInfo destroyTaskResults(TaskId taskId, OutputBufferId bufferId) { requireNonNull(taskId, "taskId is null"); requireNonNull(bufferId, "bufferId is null"); - return tasks.getUnchecked(taskId).abortTaskResults(bufferId); + return tasks.getUnchecked(taskId).destroyTaskResults(bufferId); } @Override diff --git a/core/trino-main/src/main/java/io/trino/execution/TaskManager.java b/core/trino-main/src/main/java/io/trino/execution/TaskManager.java index cc3abf39333f..d45928526179 100644 --- a/core/trino-main/src/main/java/io/trino/execution/TaskManager.java +++ b/core/trino-main/src/main/java/io/trino/execution/TaskManager.java @@ -137,7 +137,7 @@ TaskInfo updateTask( * NOTE: this design assumes that only tasks and buffers that will * eventually exist are queried. */ - TaskInfo abortTaskResults(TaskId taskId, OutputBufferId bufferId); + TaskInfo destroyTaskResults(TaskId taskId, OutputBufferId bufferId); /** * Adds a state change listener to the specified task. diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java index b9619dd5c52c..1d2c3232348b 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/ArbitraryOutputBuffer.java @@ -20,7 +20,6 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.airlift.units.DataSize; -import io.trino.execution.StateMachine; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.ClientBuffer.PagesSupplier; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; @@ -45,12 +44,9 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static io.trino.execution.buffer.BufferState.FAILED; import static io.trino.execution.buffer.BufferState.FINISHED; import static io.trino.execution.buffer.BufferState.FLUSHING; -import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; import static io.trino.execution.buffer.BufferState.NO_MORE_PAGES; -import static io.trino.execution.buffer.BufferState.OPEN; import static io.trino.execution.buffer.OutputBuffers.BufferType.ARBITRARY; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; @@ -77,7 +73,7 @@ public class ArbitraryOutputBuffer // The index of the first client buffer that should be polled private final AtomicInteger nextClientBufferIndex = new AtomicInteger(0); - private final StateMachine state; + private final OutputBufferStateMachine stateMachine; private final String taskInstanceId; private final AtomicLong totalPagesAdded = new AtomicLong(); @@ -85,13 +81,13 @@ public class ArbitraryOutputBuffer public ArbitraryOutputBuffer( String taskInstanceId, - StateMachine state, + OutputBufferStateMachine stateMachine, DataSize maxBufferSize, Supplier memoryContextSupplier, Executor notificationExecutor) { this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null"); - this.state = requireNonNull(state, "state is null"); + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); requireNonNull(maxBufferSize, "maxBufferSize is null"); checkArgument(maxBufferSize.toBytes() > 0, "maxBufferSize must be at least 1"); this.memoryManager = new OutputBufferMemoryManager( @@ -105,13 +101,7 @@ public ArbitraryOutputBuffer( @Override public void addStateChangeListener(StateChangeListener stateChangeListener) { - state.addStateChangeListener(stateChangeListener); - } - - @Override - public boolean isFinished() - { - return state.get() == FINISHED; + stateMachine.addStateChangeListener(stateChangeListener); } @Override @@ -123,7 +113,7 @@ public double getUtilization() @Override public boolean isOverutilized() { - return (memoryManager.getUtilization() >= 0.5) || !state.get().canAddPages(); + return (memoryManager.getUtilization() >= 0.5) || !stateMachine.getState().canAddPages(); } @Override @@ -134,7 +124,7 @@ public OutputBufferInfo getInfo() // // always get the state first before any other stats - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); // buffers it a concurrent collection so it is safe to access out side of guard // in this case we only want a snapshot of the current buffers @@ -163,6 +153,12 @@ public OutputBufferInfo getInfo() infos.build()); } + @Override + public BufferState getState() + { + return stateMachine.getState(); + } + @Override public void setOutputBuffers(OutputBuffers newOutputBuffers) { @@ -172,7 +168,7 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) synchronized (this) { // ignore buffers added after query finishes, which can happen when a query is canceled // also ignore old versions, which is normal - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); if (state.isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { return; } @@ -190,12 +186,11 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) // update state if no more buffers is set if (outputBuffers.isNoMoreBufferIds()) { - this.state.compareAndSet(OPEN, NO_MORE_BUFFERS); - this.state.compareAndSet(NO_MORE_PAGES, FLUSHING); + stateMachine.noMoreBuffers(); } } - if (!state.get().canAddBuffers()) { + if (!stateMachine.getState().canAddBuffers()) { noMoreBuffers(); } @@ -216,7 +211,7 @@ public void enqueue(List pages) // ignore pages after "no more pages" is set // this can happen with a limit query - if (!state.get().canAddPages()) { + if (!stateMachine.getState().canAddPages()) { return; } @@ -287,9 +282,9 @@ public void acknowledge(OutputBufferId bufferId, long sequenceId) } @Override - public void abort(OutputBufferId bufferId) + public void destroy(OutputBufferId bufferId) { - checkState(!Thread.holdsLock(this), "Cannot abort while holding a lock on this"); + checkState(!Thread.holdsLock(this), "Cannot destroy while holding a lock on this"); requireNonNull(bufferId, "bufferId is null"); getBuffer(bufferId).destroy(); @@ -301,8 +296,7 @@ public void abort(OutputBufferId bufferId) public void setNoMorePages() { checkState(!Thread.holdsLock(this), "Cannot set no more pages while holding a lock on this"); - state.compareAndSet(OPEN, NO_MORE_PAGES); - state.compareAndSet(NO_MORE_BUFFERS, FLUSHING); + stateMachine.noMorePages(); memoryManager.setNoBlockOnFull(); masterBuffer.setNoMorePages(); @@ -321,7 +315,7 @@ public void destroy() checkState(!Thread.holdsLock(this), "Cannot destroy while holding a lock on this"); // ignore destroy if the buffer already in a terminal state. - if (state.setIf(FINISHED, oldState -> !oldState.isTerminal())) { + if (stateMachine.finish()) { noMoreBuffers(); masterBuffer.destroy(); @@ -334,10 +328,10 @@ public void destroy() } @Override - public void fail() + public void abort() { - // ignore fail if the buffer already in a terminal state. - if (state.setIf(FAILED, oldState -> !oldState.isTerminal())) { + // ignore abort if the buffer already in a terminal state. + if (stateMachine.abort()) { memoryManager.setNoBlockOnFull(); forceFreeMemory(); // DO NOT destroy buffers or set no more pages. The coordinator manages the teardown of failed queries. @@ -350,6 +344,12 @@ public long getPeakMemoryUsage() return memoryManager.getPeakMemoryUsage(); } + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + @VisibleForTesting void forceFreeMemory() { @@ -366,14 +366,14 @@ private synchronized ClientBuffer getBuffer(OutputBufferId id) // NOTE: buffers are allowed to be created in the FINISHED state because destroy() can move to the finished state // without a clean "no-more-buffers" message from the scheduler. This happens with limit queries and is ok because // the buffer will be immediately destroyed. - checkState(state.get().canAddBuffers() || !outputBuffers.isNoMoreBufferIds(), "No more buffers already set"); + checkState(stateMachine.getState().canAddBuffers() || !outputBuffers.isNoMoreBufferIds(), "No more buffers already set"); // NOTE: buffers are allowed to be created before they are explicitly declared by setOutputBuffers // When no-more-buffers is set, we verify that all created buffers have been declared buffer = new ClientBuffer(taskInstanceId, id, onPagesReleased); // buffer may have finished immediately before calling this method - if (state.get() == FINISHED) { + if (stateMachine.getState() == FINISHED) { buffer.destroy(); } @@ -400,7 +400,7 @@ private void checkFlushComplete() // This buffer type assigns each page to a single, arbitrary reader, // so we don't need to wait for no-more-buffers to finish the buffer. // Any readers added after finish will simply receive no data. - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); if ((state == FLUSHING) || ((state == NO_MORE_PAGES) && masterBuffer.isEmpty())) { if (safeGetBuffersSnapshot().stream().allMatch(ClientBuffer::isDestroyed)) { destroy(); diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java index 7dd749199dff..95dbff4a4d03 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/BroadcastOutputBuffer.java @@ -20,7 +20,6 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.airlift.units.DataSize; -import io.trino.execution.StateMachine; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.buffer.SerializedPageReference.PagesReleasedListener; @@ -33,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; +import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; @@ -41,13 +41,13 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.execution.buffer.BufferState.ABORTED; import static io.trino.execution.buffer.BufferState.FAILED; import static io.trino.execution.buffer.BufferState.FINISHED; import static io.trino.execution.buffer.BufferState.FLUSHING; import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; -import static io.trino.execution.buffer.BufferState.NO_MORE_PAGES; -import static io.trino.execution.buffer.BufferState.OPEN; import static io.trino.execution.buffer.OutputBuffers.BufferType.BROADCAST; import static io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; import static io.trino.execution.buffer.SerializedPageReference.dereferencePages; @@ -57,7 +57,7 @@ public class BroadcastOutputBuffer implements OutputBuffer { private final String taskInstanceId; - private final StateMachine state; + private final OutputBufferStateMachine stateMachine; private final OutputBufferMemoryManager memoryManager; private final PagesReleasedListener onPagesReleased; @@ -79,14 +79,14 @@ public class BroadcastOutputBuffer public BroadcastOutputBuffer( String taskInstanceId, - StateMachine state, + OutputBufferStateMachine stateMachine, DataSize maxBufferSize, Supplier memoryContextSupplier, Executor notificationExecutor, Runnable notifyStatusChanged) { this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null"); - this.state = requireNonNull(state, "state is null"); + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); this.memoryManager = new OutputBufferMemoryManager( requireNonNull(maxBufferSize, "maxBufferSize is null").toBytes(), requireNonNull(memoryContextSupplier, "memoryContextSupplier is null"), @@ -101,13 +101,7 @@ public BroadcastOutputBuffer( @Override public void addStateChangeListener(StateChangeListener stateChangeListener) { - state.addStateChangeListener(stateChangeListener); - } - - @Override - public boolean isFinished() - { - return state.get() == FINISHED; + stateMachine.addStateChangeListener(stateChangeListener); } @Override @@ -119,7 +113,7 @@ public double getUtilization() @Override public boolean isOverutilized() { - return (getUtilization() > 0.5) && state.get().canAddPages(); + return (getUtilization() > 0.5) && stateMachine.getState().canAddPages(); } @Override @@ -130,7 +124,7 @@ public OutputBufferInfo getInfo() // // always get the state first before any other stats - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); // buffer it a concurrent collection so it is safe to access out side of guard // in this case we only want a snapshot of the current buffers @@ -151,6 +145,12 @@ public OutputBufferInfo getInfo() .collect(toImmutableList())); } + @Override + public BufferState getState() + { + return stateMachine.getState(); + } + @Override public void setOutputBuffers(OutputBuffers newOutputBuffers) { @@ -160,7 +160,7 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) synchronized (this) { // ignore buffers added after query finishes, which can happen when a query is canceled // also ignore old versions, which is normal - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); if (state.isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { return; } @@ -181,12 +181,11 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) // update state if no more buffers is set if (outputBuffers.isNoMoreBufferIds()) { - this.state.compareAndSet(OPEN, NO_MORE_BUFFERS); - this.state.compareAndSet(NO_MORE_PAGES, FLUSHING); + stateMachine.noMoreBuffers(); } } - if (!state.get().canAddBuffers()) { + if (!stateMachine.getState().canAddBuffers()) { noMoreBuffers(); } @@ -207,7 +206,7 @@ public void enqueue(List pages) // ignore pages after "no more pages" is set // this can happen with a limit query - if (!state.get().canAddPages()) { + if (!stateMachine.getState().canAddPages()) { return; } @@ -234,7 +233,7 @@ public void enqueue(List pages) // if we can still add buffers, remember the pages for the future buffers Collection buffers; synchronized (this) { - if (state.get().canAddBuffers()) { + if (stateMachine.getState().canAddBuffers()) { serializedPageReferences.forEach(SerializedPageReference::addReference); initialPagesForNewBuffers.addAll(serializedPageReferences); } @@ -252,7 +251,7 @@ public void enqueue(List pages) // if the buffer is full for first time and more clients are expected, update the task status // notifying a status change will lead to the SourcePartitionedScheduler sending 'no-more-buffers' to unblock if (!hasBlockedBefore.get() - && state.get().canAddBuffers() + && stateMachine.getState().canAddBuffers() && !isFull().isDone() && hasBlockedBefore.compareAndSet(false, true)) { notifyStatusChanged.run(); @@ -286,9 +285,9 @@ public void acknowledge(OutputBufferId bufferId, long sequenceId) } @Override - public void abort(OutputBufferId bufferId) + public void destroy(OutputBufferId bufferId) { - checkState(!Thread.holdsLock(this), "Cannot abort while holding a lock on this"); + checkState(!Thread.holdsLock(this), "Cannot destroy while holding a lock on this"); requireNonNull(bufferId, "bufferId is null"); getBuffer(bufferId).destroy(); @@ -300,8 +299,7 @@ public void abort(OutputBufferId bufferId) public void setNoMorePages() { checkState(!Thread.holdsLock(this), "Cannot set no more pages while holding a lock on this"); - state.compareAndSet(OPEN, NO_MORE_PAGES); - state.compareAndSet(NO_MORE_BUFFERS, FLUSHING); + stateMachine.noMorePages(); memoryManager.setNoBlockOnFull(); safeGetBuffersSnapshot().forEach(ClientBuffer::setNoMorePages); @@ -315,7 +313,7 @@ public void destroy() checkState(!Thread.holdsLock(this), "Cannot destroy while holding a lock on this"); // ignore destroy if the buffer already in a terminal state. - if (state.setIf(FINISHED, oldState -> !oldState.isTerminal())) { + if (stateMachine.finish()) { noMoreBuffers(); safeGetBuffersSnapshot().forEach(ClientBuffer::destroy); @@ -326,10 +324,10 @@ public void destroy() } @Override - public void fail() + public void abort() { - // ignore fail if the buffer already in a terminal state. - if (state.setIf(FAILED, oldState -> !oldState.isTerminal())) { + // ignore abort if the buffer already in a terminal state. + if (stateMachine.abort()) { memoryManager.setNoBlockOnFull(); forceFreeMemory(); // DO NOT destroy buffers or set no more pages. The coordinator manages the teardown of failed queries. @@ -342,6 +340,12 @@ public long getPeakMemoryUsage() return memoryManager.getPeakMemoryUsage(); } + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + @VisibleForTesting void forceFreeMemory() { @@ -358,15 +362,17 @@ private synchronized ClientBuffer getBuffer(OutputBufferId id) // NOTE: buffers are allowed to be created in the FINISHED state because destroy() can move to the finished state // without a clean "no-more-buffers" message from the scheduler. This happens with limit queries and is ok because // the buffer will be immediately destroyed. - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); checkState(state.canAddBuffers() || !outputBuffers.isNoMoreBufferIds(), "No more buffers already set"); // NOTE: buffers are allowed to be created before they are explicitly declared by setOutputBuffers // When no-more-buffers is set, we verify that all created buffers have been declared buffer = new ClientBuffer(taskInstanceId, id, onPagesReleased); - // do not setup the new buffer if we are already failed - if (state != FAILED) { + // do not setup the new buffer if we are already aborted + if (state != ABORTED) { + verify(state != FAILED, "broadcast output buffer is not expected to fail internally"); + // add initial pages buffer.enqueuePages(initialPagesForNewBuffers); @@ -412,7 +418,8 @@ private void noMoreBuffers() private void checkFlushComplete() { - if (state.get() != FLUSHING && state.get() != NO_MORE_BUFFERS) { + BufferState state = stateMachine.getState(); + if (state != FLUSHING && state != NO_MORE_BUFFERS) { return; } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/BufferState.java b/core/trino-main/src/main/java/io/trino/execution/buffer/BufferState.java index bd5b8610f350..0994336b1cd1 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/BufferState.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/BufferState.java @@ -47,11 +47,17 @@ public enum BufferState */ FINISHED(false, false, true), /** - * Buffer has failed. No more buffers or pages can be added. Readers + * Buffer has been aborted. No more buffers or pages can be added. Readers * will be blocked, as to not communicate a finished state. It is * assumed that the reader will be cleaned up elsewhere. * This is the terminal state. */ + ABORTED(false, false, true), + + /** + * Buffer is failed. No more buffers or pages can be added. The task will be failed. + * This is the terminal state. + */ FAILED(false, false, true); public static final Set TERMINAL_BUFFER_STATES = Stream.of(BufferState.values()).filter(BufferState::isTerminal).collect(toImmutableSet()); diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java index 3116c3656ae9..0d96e19de567 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/LazyOutputBuffer.java @@ -20,7 +20,6 @@ import io.airlift.slice.Slice; import io.airlift.units.DataSize; import io.trino.exchange.ExchangeManagerRegistry; -import io.trino.execution.StateMachine; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.TaskId; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; @@ -35,6 +34,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.concurrent.Executor; import java.util.function.Supplier; @@ -43,16 +43,13 @@ import static com.google.common.base.Preconditions.checkState; import static com.google.common.util.concurrent.Futures.immediateFuture; import static io.trino.execution.buffer.BufferResult.emptyResults; -import static io.trino.execution.buffer.BufferState.FAILED; import static io.trino.execution.buffer.BufferState.FINISHED; -import static io.trino.execution.buffer.BufferState.OPEN; -import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; import static java.util.Objects.requireNonNull; public class LazyOutputBuffer implements OutputBuffer { - private final StateMachine state; + private final OutputBufferStateMachine stateMachine; private final String taskInstanceId; private final DataSize maxBufferSize; private final DataSize maxBroadcastBufferSize; @@ -67,7 +64,7 @@ public class LazyOutputBuffer private volatile OutputBuffer delegate; @GuardedBy("this") - private final Set abortedBuffers = new HashSet<>(); + private final Set destroyedBuffers = new HashSet<>(); @GuardedBy("this") private final List pendingReads = new ArrayList<>(); @@ -84,7 +81,7 @@ public LazyOutputBuffer( { this.taskInstanceId = requireNonNull(taskInstanceId, "taskInstanceId is null"); this.executor = requireNonNull(executor, "executor is null"); - state = new StateMachine<>(taskId + "-buffer", executor, OPEN, TERMINAL_BUFFER_STATES); + stateMachine = new OutputBufferStateMachine(taskId, executor); this.maxBufferSize = requireNonNull(maxBufferSize, "maxBufferSize is null"); this.maxBroadcastBufferSize = requireNonNull(maxBroadcastBufferSize, "maxBroadcastBufferSize is null"); checkArgument(maxBufferSize.toBytes() > 0, "maxBufferSize must be at least 1"); @@ -96,13 +93,7 @@ public LazyOutputBuffer( @Override public void addStateChangeListener(StateChangeListener stateChangeListener) { - state.addStateChangeListener(stateChangeListener); - } - - @Override - public boolean isFinished() - { - return state.get() == FINISHED; + stateMachine.addStateChangeListener(stateChangeListener); } @Override @@ -135,7 +126,7 @@ public OutputBufferInfo getInfo() // // NOTE: this code must be lock free to not hanging state machine updates // - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); return new OutputBufferInfo( "UNINITIALIZED", @@ -151,10 +142,16 @@ public OutputBufferInfo getInfo() return outputBuffer.getInfo(); } + @Override + public BufferState getState() + { + return stateMachine.getState(); + } + @Override public void setOutputBuffers(OutputBuffers newOutputBuffers) { - Set abortedBuffers = ImmutableSet.of(); + Set destroyedBuffers = ImmutableSet.of(); List pendingReads = ImmutableList.of(); OutputBuffer outputBuffer = delegate; if (outputBuffer == null) { @@ -162,33 +159,33 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) outputBuffer = delegate; if (outputBuffer == null) { // ignore set output if buffer was already destroyed or failed - if (state.get().isTerminal()) { + if (stateMachine.getState().isTerminal()) { return; } switch (newOutputBuffers.getType()) { case PARTITIONED: - outputBuffer = new PartitionedOutputBuffer(taskInstanceId, state, newOutputBuffers, maxBufferSize, memoryContextSupplier, executor); + outputBuffer = new PartitionedOutputBuffer(taskInstanceId, stateMachine, newOutputBuffers, maxBufferSize, memoryContextSupplier, executor); break; case BROADCAST: - outputBuffer = new BroadcastOutputBuffer(taskInstanceId, state, maxBroadcastBufferSize, memoryContextSupplier, executor, notifyStatusChanged); + outputBuffer = new BroadcastOutputBuffer(taskInstanceId, stateMachine, maxBroadcastBufferSize, memoryContextSupplier, executor, notifyStatusChanged); break; case ARBITRARY: - outputBuffer = new ArbitraryOutputBuffer(taskInstanceId, state, maxBufferSize, memoryContextSupplier, executor); + outputBuffer = new ArbitraryOutputBuffer(taskInstanceId, stateMachine, maxBufferSize, memoryContextSupplier, executor); break; case SPOOL: ExchangeSinkInstanceHandle exchangeSinkInstanceHandle = newOutputBuffers.getExchangeSinkInstanceHandle() .orElseThrow(() -> new IllegalArgumentException("exchange sink handle is expected to be present for buffer type EXTERNAL")); ExchangeManager exchangeManager = exchangeManagerRegistry.getExchangeManager(); ExchangeSink exchangeSink = exchangeManager.createSink(exchangeSinkInstanceHandle, false); - outputBuffer = new SpoolingExchangeOutputBuffer(state, newOutputBuffers, exchangeSink, memoryContextSupplier); + outputBuffer = new SpoolingExchangeOutputBuffer(stateMachine, newOutputBuffers, exchangeSink, memoryContextSupplier); break; default: throw new IllegalArgumentException("Unexpected output buffer type: " + newOutputBuffers.getType()); } // process pending aborts and reads outside of synchronized lock - abortedBuffers = ImmutableSet.copyOf(this.abortedBuffers); - this.abortedBuffers.clear(); + destroyedBuffers = ImmutableSet.copyOf(this.destroyedBuffers); + this.destroyedBuffers.clear(); pendingReads = ImmutableList.copyOf(this.pendingReads); this.pendingReads.clear(); // Must be assigned last to avoid a race condition with unsynchronized readers @@ -200,7 +197,7 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) outputBuffer.setOutputBuffers(newOutputBuffers); // process pending aborts and reads outside of synchronized lock - abortedBuffers.forEach(outputBuffer::abort); + destroyedBuffers.forEach(outputBuffer::destroy); for (PendingRead pendingRead : pendingReads) { pendingRead.process(outputBuffer); } @@ -213,7 +210,7 @@ public ListenableFuture get(OutputBufferId bufferId, long token, D if (outputBuffer == null) { synchronized (this) { if (delegate == null) { - if (state.get() == FINISHED) { + if (stateMachine.getState() == FINISHED) { return immediateFuture(emptyResults(taskInstanceId, 0, true)); } @@ -235,13 +232,13 @@ public void acknowledge(OutputBufferId bufferId, long token) } @Override - public void abort(OutputBufferId bufferId) + public void destroy(OutputBufferId bufferId) { OutputBuffer outputBuffer = delegate; if (outputBuffer == null) { synchronized (this) { if (delegate == null) { - abortedBuffers.add(bufferId); + destroyedBuffers.add(bufferId); // Normally, we should free any pending readers for this buffer, // but we assume that the real buffer will be created quickly. return; @@ -249,7 +246,7 @@ public void abort(OutputBufferId bufferId) outputBuffer = delegate; } } - outputBuffer.abort(bufferId); + outputBuffer.destroy(bufferId); } @Override @@ -289,7 +286,7 @@ public void destroy() synchronized (this) { if (delegate == null) { // ignore destroy if the buffer already in a terminal state. - if (!state.setIf(FINISHED, state -> !state.isTerminal())) { + if (!stateMachine.finish()) { return; } @@ -312,14 +309,14 @@ public void destroy() } @Override - public void fail() + public void abort() { OutputBuffer outputBuffer = delegate; if (outputBuffer == null) { synchronized (this) { if (delegate == null) { - // ignore fail if the buffer already in a terminal state. - state.setIf(FAILED, state -> !state.isTerminal()); + // ignore abort if the buffer already in a terminal state. + stateMachine.abort(); // Do not free readers on fail return; @@ -327,7 +324,7 @@ public void fail() outputBuffer = delegate; } } - outputBuffer.fail(); + outputBuffer.abort(); } @Override @@ -341,6 +338,12 @@ public long getPeakMemoryUsage() return 0; } + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + @Nullable private OutputBuffer getDelegateOutputBuffer() { diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffer.java index c49a13cdbfbf..37ce169516cb 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBuffer.java @@ -20,6 +20,7 @@ import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import java.util.List; +import java.util.Optional; public interface OutputBuffer { @@ -30,10 +31,9 @@ public interface OutputBuffer OutputBufferInfo getInfo(); /** - * A buffer is finished once no-more-pages has been set and all buffers have been closed - * with an abort call. + * Get buffer state */ - boolean isFinished(); + BufferState getState(); /** * Get the memory utilization percentage. @@ -73,9 +73,9 @@ public interface OutputBuffer void acknowledge(OutputBufferId bufferId, long token); /** - * Closes the specified output buffer. + * Destroys the specified output buffer, discarding all pages. */ - void abort(OutputBufferId bufferId); + void destroy(OutputBufferId bufferId); /** * Get a future that will be completed when the buffer is not full. @@ -106,13 +106,18 @@ public interface OutputBuffer void destroy(); /** - * Fail the buffer, discarding all pages, but blocking readers. It is expected that + * Abort the buffer, discarding all pages, but blocking readers. It is expected that * readers will be unblocked when the failed query is cleaned up. */ - void fail(); + void abort(); /** * @return the peak memory usage of this output buffer. */ long getPeakMemoryUsage(); + + /** + * Returns non empty failure cause if the buffer is in state {@link BufferState#FAILED} + */ + Optional getFailureCause(); } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferStateMachine.java b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferStateMachine.java new file mode 100644 index 000000000000..33b071ed1c21 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/OutputBufferStateMachine.java @@ -0,0 +1,91 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.buffer; + +import io.trino.execution.StateMachine; +import io.trino.execution.TaskId; + +import java.util.Optional; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; + +import static io.trino.execution.buffer.BufferState.ABORTED; +import static io.trino.execution.buffer.BufferState.FAILED; +import static io.trino.execution.buffer.BufferState.FINISHED; +import static io.trino.execution.buffer.BufferState.FLUSHING; +import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; +import static io.trino.execution.buffer.BufferState.NO_MORE_PAGES; +import static io.trino.execution.buffer.BufferState.OPEN; +import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; +import static java.util.Objects.requireNonNull; + +public class OutputBufferStateMachine +{ + private final StateMachine state; + private final AtomicReference failureCause = new AtomicReference<>(); + + public OutputBufferStateMachine(TaskId taskId, Executor executor) + { + state = new StateMachine<>(taskId + "-buffer", executor, OPEN, TERMINAL_BUFFER_STATES); + } + + public BufferState getState() + { + return state.get(); + } + + public void addStateChangeListener(StateMachine.StateChangeListener stateChangeListener) + { + state.addStateChangeListener(stateChangeListener); + } + + public boolean noMoreBuffers() + { + if (state.compareAndSet(OPEN, NO_MORE_BUFFERS)) { + return true; + } + return state.compareAndSet(NO_MORE_PAGES, FLUSHING); + } + + public boolean noMorePages() + { + if (state.compareAndSet(OPEN, NO_MORE_PAGES)) { + return true; + } + return state.compareAndSet(NO_MORE_BUFFERS, FLUSHING); + } + + public boolean finish() + { + return state.setIf(FINISHED, oldState -> !oldState.isTerminal()); + } + + public boolean abort() + { + return state.setIf(ABORTED, oldState -> !oldState.isTerminal()); + } + + public boolean fail(Throwable throwable) + { + requireNonNull(throwable, "throwable is null"); + + failureCause.compareAndSet(null, throwable); + return state.setIf(FAILED, oldState -> !oldState.isTerminal()); + } + + public Optional getFailureCause() + { + return Optional.ofNullable(failureCause.get()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java index b0771437c296..4cf5b3aae851 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/PartitionedOutputBuffer.java @@ -18,25 +18,21 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.slice.Slice; import io.airlift.units.DataSize; -import io.trino.execution.StateMachine; import io.trino.execution.StateMachine.StateChangeListener; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.buffer.SerializedPageReference.PagesReleasedListener; import io.trino.memory.context.LocalMemoryContext; import java.util.List; +import java.util.Optional; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; -import static io.trino.execution.buffer.BufferState.FAILED; -import static io.trino.execution.buffer.BufferState.FINISHED; import static io.trino.execution.buffer.BufferState.FLUSHING; import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; -import static io.trino.execution.buffer.BufferState.NO_MORE_PAGES; -import static io.trino.execution.buffer.BufferState.OPEN; import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; import static io.trino.execution.buffer.SerializedPageReference.dereferencePages; @@ -45,7 +41,7 @@ public class PartitionedOutputBuffer implements OutputBuffer { - private final StateMachine state; + private final OutputBufferStateMachine stateMachine; private final OutputBuffers outputBuffers; private final OutputBufferMemoryManager memoryManager; private final PagesReleasedListener onPagesReleased; @@ -57,13 +53,13 @@ public class PartitionedOutputBuffer public PartitionedOutputBuffer( String taskInstanceId, - StateMachine state, + OutputBufferStateMachine stateMachine, OutputBuffers outputBuffers, DataSize maxBufferSize, Supplier memoryContextSupplier, Executor notificationExecutor) { - this.state = requireNonNull(state, "state is null"); + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); requireNonNull(outputBuffers, "outputBuffers is null"); checkArgument(outputBuffers.getType() == PARTITIONED, "Expected a PARTITIONED output buffer descriptor"); @@ -82,21 +78,14 @@ public PartitionedOutputBuffer( } this.partitions = partitions.build(); - state.compareAndSet(OPEN, NO_MORE_BUFFERS); - state.compareAndSet(NO_MORE_PAGES, FLUSHING); + stateMachine.noMoreBuffers(); checkFlushComplete(); } @Override public void addStateChangeListener(StateChangeListener stateChangeListener) { - state.addStateChangeListener(stateChangeListener); - } - - @Override - public boolean isFinished() - { - return state.get() == FINISHED; + stateMachine.addStateChangeListener(stateChangeListener); } @Override @@ -119,7 +108,7 @@ public OutputBufferInfo getInfo() // // always get the state first before any other stats - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); int totalBufferedPages = 0; ImmutableList.Builder infos = ImmutableList.builderWithExpectedSize(partitions.size()); @@ -141,6 +130,12 @@ public OutputBufferInfo getInfo() infos.build()); } + @Override + public BufferState getState() + { + return stateMachine.getState(); + } + @Override public void setOutputBuffers(OutputBuffers newOutputBuffers) { @@ -148,7 +143,7 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) // ignore buffers added after query finishes, which can happen when a query is canceled // also ignore old versions, which is normal - if (state.get().isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { + if (stateMachine.getState().isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { return; } @@ -176,7 +171,7 @@ public void enqueue(int partitionNumber, List pages) // ignore pages after "no more pages" is set // this can happen with a limit query - if (!state.get().canAddPages()) { + if (!stateMachine.getState().canAddPages()) { return; } @@ -224,7 +219,7 @@ public void acknowledge(OutputBufferId outputBufferId, long sequenceId) } @Override - public void abort(OutputBufferId bufferId) + public void destroy(OutputBufferId bufferId) { requireNonNull(bufferId, "bufferId is null"); @@ -236,8 +231,7 @@ public void abort(OutputBufferId bufferId) @Override public void setNoMorePages() { - state.compareAndSet(OPEN, NO_MORE_PAGES); - state.compareAndSet(NO_MORE_BUFFERS, FLUSHING); + stateMachine.noMorePages(); memoryManager.setNoBlockOnFull(); partitions.forEach(ClientBuffer::setNoMorePages); @@ -249,7 +243,7 @@ public void setNoMorePages() public void destroy() { // ignore destroy if the buffer already in a terminal state. - if (state.setIf(FINISHED, oldState -> !oldState.isTerminal())) { + if (stateMachine.finish()) { partitions.forEach(ClientBuffer::destroy); memoryManager.setNoBlockOnFull(); forceFreeMemory(); @@ -257,10 +251,10 @@ public void destroy() } @Override - public void fail() + public void abort() { - // ignore fail if the buffer already in a terminal state. - if (state.setIf(FAILED, oldState -> !oldState.isTerminal())) { + // ignore abort if the buffer already in a terminal state. + if (stateMachine.abort()) { memoryManager.setNoBlockOnFull(); forceFreeMemory(); // DO NOT destroy buffers or set no more pages. The coordinator manages the teardown of failed queries. @@ -273,6 +267,12 @@ public long getPeakMemoryUsage() return memoryManager.getPeakMemoryUsage(); } + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + @VisibleForTesting void forceFreeMemory() { @@ -281,7 +281,8 @@ void forceFreeMemory() private void checkFlushComplete() { - if (state.get() != FLUSHING && state.get() != NO_MORE_BUFFERS) { + BufferState state = stateMachine.getState(); + if (state != FLUSHING && state != NO_MORE_BUFFERS) { return; } diff --git a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java index 0a6772b8e286..0657135ecd5e 100644 --- a/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java +++ b/core/trino-main/src/main/java/io/trino/execution/buffer/SpoolingExchangeOutputBuffer.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.units.DataSize; import io.trino.execution.StateMachine; @@ -22,17 +23,13 @@ import io.trino.spi.exchange.ExchangeSink; import java.util.List; +import java.util.Optional; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Supplier; import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.concurrent.MoreFutures.asVoid; import static io.airlift.concurrent.MoreFutures.toListenableFuture; -import static io.trino.execution.buffer.BufferState.FAILED; -import static io.trino.execution.buffer.BufferState.FINISHED; -import static io.trino.execution.buffer.BufferState.FLUSHING; -import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; -import static io.trino.execution.buffer.BufferState.OPEN; import static io.trino.execution.buffer.OutputBuffers.BufferType.SPOOL; import static io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; import static java.util.Objects.requireNonNull; @@ -40,7 +37,9 @@ public class SpoolingExchangeOutputBuffer implements OutputBuffer { - private final StateMachine state; + private static final Logger log = Logger.get(SpoolingExchangeOutputBuffer.class); + + private final OutputBufferStateMachine stateMachine; private final OutputBuffers outputBuffers; private final ExchangeSink exchangeSink; private final Supplier memoryContextSupplier; @@ -50,24 +49,24 @@ public class SpoolingExchangeOutputBuffer private final AtomicLong totalRowsAdded = new AtomicLong(); public SpoolingExchangeOutputBuffer( - StateMachine state, + OutputBufferStateMachine stateMachine, OutputBuffers outputBuffers, ExchangeSink exchangeSink, Supplier memoryContextSupplier) { - this.state = requireNonNull(state, "state is null"); + this.stateMachine = requireNonNull(stateMachine, "stateMachine is null"); this.outputBuffers = requireNonNull(outputBuffers, "outputBuffers is null"); checkArgument(outputBuffers.getType() == SPOOL, "Expected a SPOOL output buffer"); this.exchangeSink = requireNonNull(exchangeSink, "exchangeSink is null"); this.memoryContextSupplier = requireNonNull(memoryContextSupplier, "memoryContextSupplier is null"); - state.compareAndSet(OPEN, NO_MORE_BUFFERS); + stateMachine.noMoreBuffers(); } @Override public OutputBufferInfo getInfo() { - BufferState state = this.state.get(); + BufferState state = stateMachine.getState(); return new OutputBufferInfo( "EXTERNAL", state, @@ -81,9 +80,9 @@ public OutputBufferInfo getInfo() } @Override - public boolean isFinished() + public BufferState getState() { - return state.get() == FINISHED; + return stateMachine.getState(); } @Override @@ -101,7 +100,7 @@ public boolean isOverutilized() @Override public void addStateChangeListener(StateMachine.StateChangeListener stateChangeListener) { - state.addStateChangeListener(stateChangeListener); + stateMachine.addStateChangeListener(stateChangeListener); } @Override @@ -111,7 +110,7 @@ public void setOutputBuffers(OutputBuffers newOutputBuffers) // ignore buffers added after query finishes, which can happen when a query is canceled // also ignore old versions, which is normal - if (state.get().isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { + if (stateMachine.getState().isTerminal() || outputBuffers.getVersion() >= newOutputBuffers.getVersion()) { return; } @@ -132,7 +131,7 @@ public void acknowledge(OutputBuffers.OutputBufferId bufferId, long token) } @Override - public void abort(OutputBuffers.OutputBufferId bufferId) + public void destroy(OutputBuffers.OutputBufferId bufferId) { throw new UnsupportedOperationException(); } @@ -156,7 +155,7 @@ public void enqueue(int partition, List pages) // ignore pages after "no more pages" is set // this can happen with a limit query - if (!state.get().canAddPages()) { + if (!stateMachine.getState().canAddPages()) { return; } @@ -171,35 +170,45 @@ public void enqueue(int partition, List pages) @Override public void setNoMorePages() { - if (state.compareAndSet(NO_MORE_BUFFERS, FLUSHING)) { - destroy(); + if (!stateMachine.noMorePages()) { + return; } + + exchangeSink.finish().whenComplete((value, failure) -> { + if (failure != null) { + stateMachine.fail(failure); + } + else { + stateMachine.finish(); + } + updateMemoryUsage(0); + }); } @Override public void destroy() { - if (state.setIf(FINISHED, oldState -> !oldState.isTerminal())) { - try { - exchangeSink.finish(); - } - finally { - updateMemoryUsage(exchangeSink.getMemoryUsage()); - } - } + // Abort the buffer if it hasn't been finished. This is possible when a task is cancelled early by the coordinator. + // Task cancellation is not supported (and not expected to be requested by the coordinator when the spooling exchange + // is in use) as the task output is expected to be deterministic. + // In a scenario when due to a bug in coordinator logic a cancellation is requested it is better to invalidate the sink + // to avoid publishing incomplete data to the downstream stage that could potentially cause a correctness problem + abort(); } @Override - public void fail() + public void abort() { - if (state.setIf(FAILED, oldState -> !oldState.isTerminal())) { - try { - exchangeSink.abort(); - } - finally { - updateMemoryUsage(0); - } + if (!stateMachine.abort()) { + return; } + + exchangeSink.abort().whenComplete((value, failure) -> { + if (failure != null) { + log.warn(failure, "Error aborting exchange sink"); + } + updateMemoryUsage(0); + }); } @Override @@ -208,6 +217,12 @@ public long getPeakMemoryUsage() return peakMemoryUsage.get(); } + @Override + public Optional getFailureCause() + { + return stateMachine.getFailureCause(); + } + private void updateMemoryUsage(long bytes) { LocalMemoryContext context = getSystemMemoryContextOrNull(); diff --git a/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java b/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java index 9599e2e3eef4..d42553127890 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java +++ b/core/trino-main/src/main/java/io/trino/operator/DeduplicatingDirectExchangeBuffer.java @@ -17,6 +17,8 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import com.google.common.io.Closer; +import com.google.common.util.concurrent.FluentFuture; +import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.airlift.log.Logger; @@ -37,7 +39,6 @@ import io.trino.spi.exchange.ExchangeSinkHandle; import io.trino.spi.exchange.ExchangeSinkInstanceHandle; import io.trino.spi.exchange.ExchangeSource; -import io.trino.spi.exchange.ExchangeSourceHandle; import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.NotThreadSafe; @@ -53,6 +54,7 @@ import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicBoolean; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -61,10 +63,13 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.collect.Multimaps.asMap; +import static com.google.common.util.concurrent.Futures.addCallback; import static com.google.common.util.concurrent.Futures.getUnchecked; import static com.google.common.util.concurrent.Futures.immediateVoidFuture; import static com.google.common.util.concurrent.Futures.nonCancellationPropagating; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.MoreFutures.asVoid; +import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.airlift.concurrent.MoreFutures.toListenableFuture; import static io.trino.operator.RetryPolicy.NONE; import static io.trino.operator.RetryPolicy.QUERY; @@ -124,11 +129,12 @@ public DeduplicatingDirectExchangeBuffer( exchangeManagerRegistry, queryId, exchangeId, + executor, bufferCapacity); } @Override - public ListenableFuture isBlocked() + public synchronized ListenableFuture isBlocked() { if (failure != null || closed) { return immediateVoidFuture(); @@ -452,6 +458,7 @@ private static class PageBuffer private final ExchangeManagerRegistry exchangeManagerRegistry; private final QueryId queryId; private final ExchangeId exchangeId; + private final Executor executor; private final long pageBufferCapacityInBytes; private final ListMultimap pageBuffer = ArrayListMultimap.create(); @@ -470,15 +477,19 @@ private static class PageBuffer private boolean inputFinished; private boolean closed; + private final AtomicBoolean exchangeSinkFinished = new AtomicBoolean(); + private PageBuffer( ExchangeManagerRegistry exchangeManagerRegistry, QueryId queryId, ExchangeId exchangeId, + Executor executor, DataSize pageBufferCapacity) { this.exchangeManagerRegistry = requireNonNull(exchangeManagerRegistry, "exchangeManagerRegistry is null"); this.queryId = requireNonNull(queryId, "queryId is null"); this.exchangeId = requireNonNull(exchangeId, "exchangeId is null"); + this.executor = requireNonNull(executor, "executor is null"); this.pageBufferCapacityInBytes = requireNonNull(pageBufferCapacity, "pageBufferCapacity is null").toBytes(); } @@ -600,10 +611,15 @@ public OutputSource createOutputSource(Set selectedTasks) verify(exchange != null, "exchange is expected to be initialized"); verify(sinkInstanceHandle != null, "sinkInstanceHandle is expected to be initialized"); - exchangeSink.finish(); - exchangeSink = null; - exchange.sinkFinished(sinkInstanceHandle); - return new ExchangeOutputSource(selectedTasks, exchangeManager, exchange, queryId); + // Finish ExchangeSink and create ExchangeSource asynchronously to avoid blocking an ExchangeClient thread for potentially substantial amount of time + ListenableFuture exchangeSourceFuture = FluentFuture.from(toListenableFuture(exchangeSink.finish())) + .transformAsync((ignored) -> { + exchangeSinkFinished.set(true); + exchange.sinkFinished(sinkInstanceHandle); + return toListenableFuture(exchange.getSourceHandles()); + }, executor) + .transform(exchangeManager::createSource, executor); + return new ExchangeOutputSource(selectedTasks, queryId, exchangeSourceFuture); } public long getRetainedSizeInBytes() @@ -645,14 +661,21 @@ public void close() pageBufferRetainedSizeInBytes = 0; bufferedPageCount = 0; writeBuffer = null; - try (Closer closer = Closer.create()) { - closer.register(exchange); - if (exchangeSink != null) { - closer.register(exchangeSink::abort); + + if (exchangeSink != null && !exchangeSinkFinished.get()) { + try { + exchangeSink.abort().whenComplete((result, failure) -> { + if (failure != null) { + log.warn(failure, "Error aborting exchange sink"); + } + }); + } + catch (RuntimeException e) { + log.warn(e, "Error aborting exchange sink"); } } - catch (IOException e) { - throw new UncheckedIOException(e); + if (exchange != null) { + exchange.close(); } } } @@ -719,23 +742,20 @@ private static class ExchangeOutputSource implements OutputSource { private final Set selectedTasks; - private final ExchangeManager exchangeManager; - private final Exchange exchange; private final QueryId queryId; + private final ListenableFuture exchangeSourceFuture; private ExchangeSource exchangeSource; private boolean finished; private ExchangeOutputSource( Set selectedTasks, - ExchangeManager exchangeManager, - Exchange exchange, - QueryId queryId) + QueryId queryId, + ListenableFuture exchangeSourceFuture) { this.selectedTasks = ImmutableSet.copyOf(requireNonNull(selectedTasks, "selectedTasks is null")); - this.exchangeManager = requireNonNull(exchangeManager, "exchangeManager is null"); - this.exchange = requireNonNull(exchange, "exchange is null"); this.queryId = requireNonNull(queryId, "queryId is null"); + this.exchangeSourceFuture = requireNonNull(exchangeSourceFuture, "exchangeSourceFuture is null"); } @Override @@ -745,12 +765,10 @@ public Slice getNext() return null; } if (exchangeSource == null) { - CompletableFuture> sourceHandlesFuture = exchange.getSourceHandles(); - if (!sourceHandlesFuture.isDone()) { + if (!exchangeSourceFuture.isDone()) { return null; } - List handles = getUnchecked(sourceHandlesFuture); - exchangeSource = exchangeManager.createSource(handles); + exchangeSource = getFutureValue(exchangeSourceFuture); } while (!exchangeSource.isFinished()) { if (!exchangeSource.isBlocked().isDone()) { @@ -785,16 +803,15 @@ public ListenableFuture isBlocked() if (finished) { return immediateVoidFuture(); } + if (!exchangeSourceFuture.isDone()) { + return nonCancellationPropagating(asVoid(exchangeSourceFuture)); + } if (exchangeSource != null) { CompletableFuture blocked = exchangeSource.isBlocked(); if (!blocked.isDone()) { return nonCancellationPropagating(asVoid(toListenableFuture(blocked))); } } - CompletableFuture> sourceHandles = exchange.getSourceHandles(); - if (!sourceHandles.isDone()) { - return nonCancellationPropagating(asVoid(toListenableFuture(sourceHandles))); - } return immediateVoidFuture(); } @@ -814,9 +831,26 @@ public void close() return; } finished = true; - if (exchangeSource != null) { - exchangeSource.close(); - } + addCallback(exchangeSourceFuture, new FutureCallback<>() + { + @Override + public void onSuccess(ExchangeSource exchangeSource) + { + try { + exchangeSource.close(); + } + catch (RuntimeException e) { + log.warn(e, "error closing exchange source"); + } + } + + @Override + public void onFailure(Throwable ignored) + { + // The callback is needed to safely close the exchange source + // It a failure occurred it is expected to be propagated by the getNext method + } + }, directExecutor()); } } } diff --git a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java index 869586dd9723..d8aa54c3f15f 100644 --- a/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/DirectExchangeClient.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.http.client.HttpClient; +import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.airlift.units.DataSize; import io.airlift.units.Duration; @@ -50,6 +51,8 @@ public class DirectExchangeClient implements Closeable { + private static final Logger log = Logger.get(DirectExchangeClient.class); + private final String selfAddress; private final DataIntegrityVerification dataIntegrityVerification; private final DataSize maxResponseSize; @@ -240,8 +243,15 @@ public synchronized void close() for (HttpPageBufferClient client : allClients.values()) { closeQuietly(client); } - buffer.close(); - memoryContext.setBytes(0); + try { + buffer.close(); + } + catch (RuntimeException e) { + log.warn(e, "error closing buffer"); + } + finally { + memoryContext.setBytes(0); + } } private synchronized void scheduleRequestIfNecessary() diff --git a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java index 1429c5649a20..9ab0e43b4cfd 100644 --- a/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java +++ b/core/trino-main/src/main/java/io/trino/operator/HttpPageBufferClient.java @@ -268,10 +268,10 @@ public synchronized boolean isRunning() @Override public void close() { - boolean shouldSendDelete; + boolean shouldDestroyTaskResults; Future future; synchronized (this) { - shouldSendDelete = !closed; + shouldDestroyTaskResults = !closed; closed = true; @@ -286,9 +286,9 @@ public void close() future.cancel(true); } - // abort the output buffer on the remote node; response of delete is ignored - if (shouldSendDelete) { - sendDelete(); + // destroy task results on the remote node; response is ignored + if (shouldDestroyTaskResults) { + destroyTaskResults(); } } @@ -325,7 +325,7 @@ private synchronized void initiateRequest() } if (completed) { - sendDelete(); + destroyTaskResults(); } else { sendGetResults(); @@ -478,7 +478,7 @@ public void onFailure(Throwable t) }, pageBufferClientCallbackExecutor); } - private synchronized void sendDelete() + private synchronized void destroyTaskResults() { HttpResponseFuture resultFuture = httpClient.executeAsync(prepareDelete().setUri(location).build(), createStatusResponseHandler()); future = resultFuture; diff --git a/core/trino-main/src/main/java/io/trino/server/TaskResource.java b/core/trino-main/src/main/java/io/trino/server/TaskResource.java index 5e5b1f5d48a2..2f9e00872e36 100644 --- a/core/trino-main/src/main/java/io/trino/server/TaskResource.java +++ b/core/trino-main/src/main/java/io/trino/server/TaskResource.java @@ -375,7 +375,7 @@ public void acknowledgeResults( @ResourceSecurity(INTERNAL_ONLY) @DELETE @Path("{taskId}/results/{bufferId}") - public void abortResults( + public void destroyTaskResults( @PathParam("taskId") TaskId taskId, @PathParam("bufferId") OutputBufferId bufferId, @Context UriInfo uriInfo, @@ -384,11 +384,11 @@ public void abortResults( requireNonNull(taskId, "taskId is null"); requireNonNull(bufferId, "bufferId is null"); - if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.ABORT_RESULTS, asyncResponse)) { + if (injectFailure(taskManager.getTraceToken(taskId), taskId, RequestType.DESTROY_RESULTS, asyncResponse)) { return; } - taskManager.abortTaskResults(taskId, bufferId); + taskManager.destroyTaskResults(taskId, bufferId); asyncResponse.resume(Response.noContent().build()); } @@ -461,7 +461,7 @@ private enum RequestType GET_TASK_STATUS(true), ACKNOWLEDGE_AND_GET_NEW_DYNAMIC_FILTER_DOMAINS(true), GET_RESULTS(false), - ABORT_RESULTS(false); + DESTROY_RESULTS(false); private final boolean taskManagement; diff --git a/core/trino-main/src/main/java/io/trino/server/testing/exchange/LocalFileSystemExchangeSink.java b/core/trino-main/src/main/java/io/trino/server/testing/exchange/LocalFileSystemExchangeSink.java index 114030b38c6e..943298f58375 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/exchange/LocalFileSystemExchangeSink.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/exchange/LocalFileSystemExchangeSink.java @@ -41,6 +41,8 @@ import static java.lang.Math.toIntExact; import static java.nio.file.Files.createFile; import static java.util.Objects.requireNonNull; +import static java.util.concurrent.CompletableFuture.completedFuture; +import static java.util.concurrent.CompletableFuture.failedFuture; public class LocalFileSystemExchangeSink implements ExchangeSink @@ -109,10 +111,10 @@ public synchronized long getMemoryUsage() } @Override - public synchronized void finish() + public synchronized CompletableFuture finish() { if (closed) { - return; + return completedFuture(null); } try { for (SliceOutput output : outputs.values()) { @@ -133,17 +135,18 @@ public synchronized void finish() } catch (Throwable t) { abort(); - throw t; + return failedFuture(t); } committed = true; closed = true; + return completedFuture(null); } @Override - public synchronized void abort() + public synchronized CompletableFuture abort() { if (closed) { - return; + return completedFuture(null); } closed = true; for (SliceOutput output : outputs.values()) { @@ -161,5 +164,6 @@ public synchronized void abort() catch (IOException e) { log.warn(e, "Error cleaning output directory"); } + return completedFuture(null); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java index 41979d19ea79..dcf1a9cf7c76 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTask.java @@ -173,8 +173,8 @@ public void testSimpleQuery() } assertEquals(results.getSerializedPages().size(), 0); - // complete the task by calling abort on it - TaskInfo info = sqlTask.abortTaskResults(OUT); + // complete the task by calling destroy on it + TaskInfo info = sqlTask.destroyTaskResults(OUT); assertEquals(info.getOutputBuffers().getState(), BufferState.FINISHED); taskInfo = sqlTask.getTaskInfo(info.getTaskStatus().getVersion()).get(); @@ -233,7 +233,7 @@ public void testAbort() assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FLUSHING); assertEquals(taskInfo.getTaskStatus().getVersion(), STARTING_VERSION + 1); - sqlTask.abortTaskResults(OUT); + sqlTask.destroyTaskResults(OUT); taskInfo = sqlTask.getTaskInfo(taskInfo.getTaskStatus().getVersion()).get(); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FINISHED); @@ -258,7 +258,7 @@ public void testBufferCloseOnFinish() updateTask(sqlTask, ImmutableList.of(new SplitAssignment(TABLE_SCAN_NODE_ID, ImmutableSet.of(), true)), outputBuffers); // finish the task by calling abort on it - sqlTask.abortTaskResults(OUT); + sqlTask.destroyTaskResults(OUT); // buffer will be closed by cancel event (wait for event to fire) bufferResult.get(1, SECONDS); diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java index bc86ed1ac435..9041863a97a9 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskExecution.java @@ -29,6 +29,7 @@ import io.trino.execution.buffer.BufferResult; import io.trino.execution.buffer.BufferState; import io.trino.execution.buffer.OutputBuffer; +import io.trino.execution.buffer.OutputBufferStateMachine; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.PartitionedOutputBuffer; @@ -95,8 +96,6 @@ import static io.trino.execution.TaskState.RUNNING; import static io.trino.execution.TaskTestUtils.TABLE_SCAN_NODE_ID; import static io.trino.execution.TaskTestUtils.createTestSplitMonitor; -import static io.trino.execution.buffer.BufferState.OPEN; -import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.execution.buffer.PagesSerde.getSerializedPagePositionCount; @@ -612,7 +611,7 @@ private PartitionedOutputBuffer newTestingOutputBuffer(ScheduledExecutorService { return new PartitionedOutputBuffer( TASK_ID.toString(), - new StateMachine<>("bufferState", taskNotificationExecutor, OPEN, TERMINAL_BUFFER_STATES), + new OutputBufferStateMachine(TASK_ID, taskNotificationExecutor), createInitialEmptyOutputBuffers(PARTITIONED) .withBuffer(OUTPUT_BUFFER_ID, 0) .withNoMoreBufferIds(), @@ -685,7 +684,7 @@ public void assertBufferComplete(Duration timeout) public void abort() { - outputBuffer.abort(outputBufferId); + outputBuffer.destroy(outputBufferId); assertEquals(outputBuffer.getInfo().getState(), BufferState.FINISHED); } } diff --git a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java index 41042f058db9..c9c97539b626 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestSqlTaskManager.java @@ -135,8 +135,8 @@ public void testSimpleQuery() assertTrue(results.isBufferComplete()); assertEquals(results.getSerializedPages().size(), 0); - // complete the task by calling abort on it - TaskInfo info = sqlTaskManager.abortTaskResults(taskId, OUT); + // complete the task by calling destroy on it + TaskInfo info = sqlTaskManager.destroyTaskResults(taskId, OUT); assertEquals(info.getOutputBuffers().getState(), BufferState.FINISHED); taskInfo = sqlTaskManager.getTaskInfo(taskId, taskInfo.getTaskStatus().getVersion()).get(); @@ -203,7 +203,7 @@ public void testAbortResults() TaskInfo taskInfo = sqlTaskManager.getTaskInfo(taskId, TaskStatus.STARTING_VERSION).get(); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FLUSHING); - sqlTaskManager.abortTaskResults(taskId, OUT); + sqlTaskManager.destroyTaskResults(taskId, OUT); taskInfo = sqlTaskManager.getTaskInfo(taskId, taskInfo.getTaskStatus().getVersion()).get(); assertEquals(taskInfo.getTaskStatus().getState(), TaskState.FINISHED); diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/BufferTestUtils.java b/core/trino-main/src/test/java/io/trino/execution/buffer/BufferTestUtils.java index e9b06924874e..a43fe201561a 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/BufferTestUtils.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/BufferTestUtils.java @@ -30,6 +30,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static io.airlift.concurrent.MoreFutures.tryGetFutureValue; +import static io.trino.execution.buffer.BufferState.FINISHED; import static io.trino.execution.buffer.TestingPagesSerdeFactory.testingPagesSerde; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; @@ -229,7 +230,7 @@ static void assertQueueClosed(OutputBuffer buffer, int unassignedPages, OutputBu static void assertFinished(OutputBuffer buffer) { - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); for (BufferInfo bufferInfo : buffer.getInfo().getBuffers()) { assertTrue(bufferInfo.isFinished()); assertEquals(bufferInfo.getBufferedPages(), 0); diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestArbitraryOutputBuffer.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestArbitraryOutputBuffer.java index 95d179d413b9..65c1de483c46 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/TestArbitraryOutputBuffer.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestArbitraryOutputBuffer.java @@ -17,10 +17,12 @@ import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; import io.airlift.units.Duration; -import io.trino.execution.StateMachine; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.spi.Page; +import io.trino.spi.QueryId; import io.trino.spi.type.BigintType; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -35,8 +37,12 @@ import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.execution.buffer.BufferResult.emptyResults; +import static io.trino.execution.buffer.BufferState.ABORTED; +import static io.trino.execution.buffer.BufferState.FINISHED; +import static io.trino.execution.buffer.BufferState.FLUSHING; +import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; +import static io.trino.execution.buffer.BufferState.NO_MORE_PAGES; import static io.trino.execution.buffer.BufferState.OPEN; -import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; import static io.trino.execution.buffer.BufferTestUtils.MAX_WAIT; import static io.trino.execution.buffer.BufferTestUtils.NO_WAIT; import static io.trino.execution.buffer.BufferTestUtils.acknowledgeBufferResult; @@ -196,33 +202,33 @@ public void testSimple() // // finish the buffer - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); buffer.setNoMorePages(); assertQueueState(buffer, 0, FIRST, 2, 4); assertQueueState(buffer, 0, SECOND, 1, 10); // not fully finished until all pages are consumed - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // acknowledge the pages from the first buffer; buffer should not close automatically assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 6, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 6, true)); assertQueueState(buffer, 0, FIRST, 0, 6); assertQueueState(buffer, 0, SECOND, 1, 10); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // finish first queue - buffer.abort(FIRST); + buffer.destroy(FIRST); assertQueueClosed(buffer, 0, FIRST, 6); assertQueueState(buffer, 0, SECOND, 1, 10); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // acknowledge a page from the second queue; queue should not close automatically assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 11, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 11, true)); assertQueueState(buffer, 0, SECOND, 0, 11); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // finish second queue - buffer.abort(SECOND); + buffer.destroy(SECOND); assertQueueClosed(buffer, 0, FIRST, 6); assertQueueClosed(buffer, 0, SECOND, 11); assertFinished(buffer); @@ -336,7 +342,7 @@ public void testAddQueueAfterCreation() .withNoMoreBufferIds(), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); assertThatThrownBy(() -> buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY) .withBuffer(FIRST, BROADCAST_PARTITION_ID) @@ -364,19 +370,19 @@ public void testAddAfterFinish() public void testAddQueueAfterNoMoreQueues() { ArbitraryOutputBuffer buffer = createArbitraryBuffer(createInitialEmptyOutputBuffers(ARBITRARY), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // tell buffer no more queues will be added buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY).withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // set no more queues a second time to assure that we don't get an exception or such buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY).withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // set no more queues a third time to assure that we don't get an exception or such buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY).withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(ARBITRARY) .withBuffer(FIRST, BROADCAST_PARTITION_ID) @@ -404,7 +410,7 @@ public void testAddAfterDestroy() public void testGetBeforeCreate() { ArbitraryOutputBuffer buffer = createArbitraryBuffer(createInitialEmptyOutputBuffers(ARBITRARY), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // get a page from a buffer that doesn't exist yet ListenableFuture future = buffer.get(FIRST, 0L, sizeOfPages(1)); @@ -427,7 +433,7 @@ public void testResumeFromPreviousPosition() } ArbitraryOutputBuffer buffer = createArbitraryBuffer(outputBuffers, sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); Map> firstReads = new HashMap<>(); for (OutputBufferId id : ids) { @@ -472,7 +478,7 @@ public void testUseUndeclaredBufferAfterFinalBuffersSet() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // get a page from a buffer that was not declared, which will fail assertThatThrownBy(() -> buffer.get(SECOND, 0L, sizeOfPages(1))) @@ -484,14 +490,14 @@ public void testUseUndeclaredBufferAfterFinalBuffersSet() public void testAbortBeforeCreate() { ArbitraryOutputBuffer buffer = createArbitraryBuffer(createInitialEmptyOutputBuffers(ARBITRARY), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // get a page from a buffer that doesn't exist yet ListenableFuture future = buffer.get(FIRST, 0L, sizeOfPages(1)); assertFalse(future.isDone()); - // abort that buffer, and verify the future is finishd - buffer.abort(FIRST); + // destroy that buffer, and verify the future is finished + buffer.destroy(FIRST); assertBufferResultEquals(TYPES, getFuture(future, NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, false)); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); @@ -539,17 +545,17 @@ public void testAbort() // read a page from the first buffer assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(0))); - // abort buffer, and verify page cannot be acknowledged - buffer.abort(FIRST); + // destroy buffer, and verify page cannot be acknowledged + buffer.destroy(FIRST); assertQueueClosed(buffer, 9, FIRST, 0); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 1, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); outputBuffers = outputBuffers.withBuffer(SECOND, 0).withNoMoreBufferIds(); buffer.setOutputBuffers(outputBuffers); - // first page is lost because the first buffer was aborted + // first page is lost because the first buffer was destroyed assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(1))); - buffer.abort(SECOND); + buffer.destroy(SECOND); assertQueueClosed(buffer, 0, SECOND, 0); assertFinished(buffer); assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 1, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); @@ -571,8 +577,8 @@ public void testFinishClosesEmptyQueues() assertQueueState(buffer, 0, FIRST, 0, 0); assertQueueState(buffer, 0, SECOND, 0, 0); - buffer.abort(FIRST); - buffer.abort(SECOND); + buffer.destroy(FIRST); + buffer.destroy(SECOND); assertQueueClosed(buffer, 0, FIRST, 0); assertQueueClosed(buffer, 0, SECOND, 0); @@ -583,7 +589,7 @@ public void testAbortFreesReader() { ArbitraryOutputBuffer buffer = createArbitraryBuffer(createInitialEmptyOutputBuffers(ARBITRARY), sizeOfPages(10)); buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY).withBuffer(FIRST, 0)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -601,8 +607,8 @@ public void testAbortFreesReader() future = buffer.get(FIRST, 1, sizeOfPages(10)); assertFalse(future.isDone()); - // abort the buffer - buffer.abort(FIRST); + // destroy the buffer + buffer.destroy(FIRST); assertQueueClosed(buffer, 0, FIRST, 1); // verify the future completed @@ -614,7 +620,7 @@ public void testFinishFreesReader() { ArbitraryOutputBuffer buffer = createArbitraryBuffer(createInitialEmptyOutputBuffers(ARBITRARY), sizeOfPages(10)); buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY).withBuffer(FIRST, 0)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -634,7 +640,7 @@ public void testFinishFreesReader() // finish the buffer assertQueueState(buffer, 0, FIRST, 0, 1); - buffer.abort(FIRST); + buffer.destroy(FIRST); assertQueueClosed(buffer, 0, FIRST, 1); // verify the future completed @@ -648,7 +654,7 @@ public void testFinishFreesWriter() buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY) .withBuffer(FIRST, 0) .withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -669,7 +675,7 @@ public void testFinishFreesWriter() // finish the query buffer.setNoMorePages(); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // verify futures are complete assertFutureIsDone(firstEnqueuePage); @@ -681,10 +687,10 @@ public void testFinishFreesWriter() assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 7, sizeOfPages(100), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 7, true)); // verify not finished - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // finish the queue - buffer.abort(FIRST); + buffer.destroy(FIRST); // verify finished assertFinished(buffer); @@ -697,7 +703,7 @@ public void testDestroyFreesReader() buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY) .withBuffer(FIRST, 0) .withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -730,7 +736,7 @@ public void testDestroyFreesWriter() buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY) .withBuffer(FIRST, 0) .withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -766,7 +772,7 @@ public void testFailDoesNotFreeReader() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -784,8 +790,8 @@ public void testFailDoesNotFreeReader() future = buffer.get(FIRST, 1, sizeOfPages(10)); assertFalse(future.isDone()); - // fail the buffer - buffer.fail(); + // abort the buffer + buffer.abort(); // future should have not finished assertFalse(future.isDone()); @@ -803,7 +809,7 @@ public void testFailFreesWriter() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -822,9 +828,9 @@ public void testFailFreesWriter() assertFalse(firstEnqueuePage.isDone()); assertFalse(secondEnqueuePage.isDone()); - // fail the buffer (i.e., cancel the query) - buffer.fail(); - assertFalse(buffer.isFinished()); + // abort the buffer (i.e., fail the query) + buffer.abort(); + assertEquals(buffer.getState(), ABORTED); // verify the futures are completed assertFutureIsDone(firstEnqueuePage); @@ -837,7 +843,7 @@ public void testAddBufferAfterFail() OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(ARBITRARY) .withBuffer(FIRST, BROADCAST_PARTITION_ID); ArbitraryOutputBuffer buffer = createArbitraryBuffer(outputBuffers, sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -851,8 +857,8 @@ public void testAddBufferAfterFail() // verify we got one page assertBufferResultEquals(TYPES, getFuture(future, NO_WAIT), bufferResult(0, createPage(0))); - // fail the buffer - buffer.fail(); + // abort the buffer + buffer.abort(); // add a buffer outputBuffers = outputBuffers.withBuffer(SECOND, BROADCAST_PARTITION_ID); @@ -883,7 +889,7 @@ public void testBufferCompletion() .withBuffer(FIRST, 0) .withNoMoreBufferIds()); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer List pages = new ArrayList<>(); @@ -898,17 +904,14 @@ public void testBufferCompletion() // get and acknowledge 5 pages assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(5), MAX_WAIT), createBufferResult(TASK_INSTANCE_ID, 0, pages)); - // buffer is not finished - assertFalse(buffer.isFinished()); - // there are no more pages and no more buffers, but buffer is not finished because it didn't receive an acknowledgement yet - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // ask the buffer to finish - buffer.abort(FIRST); + buffer.destroy(FIRST); // verify that the buffer is finished - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); } @Test @@ -916,7 +919,7 @@ public void testNoMorePagesFreesReader() { ArbitraryOutputBuffer buffer = createArbitraryBuffer(createInitialEmptyOutputBuffers(ARBITRARY), sizeOfPages(10)); buffer.setOutputBuffers(createInitialEmptyOutputBuffers(ARBITRARY).withBuffer(FIRST, 0)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); assertFalse(future.isDone()); @@ -937,24 +940,24 @@ public void testFinishBeforeNoMoreBuffers() addPage(buffer, createPage(i)); } buffer.setNoMorePages(); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_PAGES); // add one output buffer OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(ARBITRARY).withBuffer(FIRST, 0); buffer.setOutputBuffers(outputBuffers); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_PAGES); // read a page from the first buffer assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(0))); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_PAGES); // read remaining pages from the first buffer and acknowledge assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 1, sizeOfPages(10), NO_WAIT), bufferResult(1, createPage(1), createPage(2))); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 3, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 3, true)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_PAGES); // finish first queue - buffer.abort(FIRST); + buffer.destroy(FIRST); assertQueueClosed(buffer, 0, FIRST, 3); assertFinished(buffer); @@ -1063,7 +1066,7 @@ private ArbitraryOutputBuffer createArbitraryBuffer(OutputBuffers buffers, DataS { ArbitraryOutputBuffer buffer = new ArbitraryOutputBuffer( TASK_INSTANCE_ID, - new StateMachine<>("bufferState", stateNotificationExecutor, OPEN, TERMINAL_BUFFER_STATES), + new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), stateNotificationExecutor), dataSize, () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), stateNotificationExecutor); diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestBroadcastOutputBuffer.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestBroadcastOutputBuffer.java index bbe5bea964d5..31ab00ed6cca 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/TestBroadcastOutputBuffer.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestBroadcastOutputBuffer.java @@ -17,12 +17,14 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.airlift.units.DataSize; -import io.trino.execution.StateMachine; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.memory.context.AggregatedMemoryContext; import io.trino.memory.context.MemoryReservationHandler; import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.spi.Page; +import io.trino.spi.QueryId; import io.trino.spi.type.BigintType; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -37,8 +39,11 @@ import static com.google.common.util.concurrent.MoreExecutors.directExecutor; import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.execution.buffer.BufferResult.emptyResults; +import static io.trino.execution.buffer.BufferState.ABORTED; +import static io.trino.execution.buffer.BufferState.FINISHED; +import static io.trino.execution.buffer.BufferState.FLUSHING; +import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; import static io.trino.execution.buffer.BufferState.OPEN; -import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; import static io.trino.execution.buffer.BufferTestUtils.MAX_WAIT; import static io.trino.execution.buffer.BufferTestUtils.NO_WAIT; import static io.trino.execution.buffer.BufferTestUtils.acknowledgeBufferResult; @@ -201,19 +206,19 @@ public void testSimple() // // finish the buffer - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); buffer.setNoMorePages(); assertQueueState(buffer, FIRST, 10, 4); assertQueueState(buffer, SECOND, 4, 10); // not fully finished until all pages are consumed - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // remove a page, not finished assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 5, sizeOfPages(1), NO_WAIT), bufferResult(5, createPage(5))); assertQueueState(buffer, FIRST, 9, 5); assertQueueState(buffer, SECOND, 4, 10); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // remove all remaining pages from first queue, should not be finished BufferResult x = getBufferResult(buffer, FIRST, 6, sizeOfPages(10), NO_WAIT); @@ -229,10 +234,10 @@ public void testSimple() assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 14, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 14, true)); // finish first queue - buffer.abort(FIRST); + buffer.destroy(FIRST); assertQueueClosed(buffer, FIRST, 14); assertQueueState(buffer, SECOND, 4, 10); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // remove all remaining pages from second queue, should be finished assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 10, sizeOfPages(10), NO_WAIT), bufferResult(10, createPage(10), @@ -241,7 +246,7 @@ public void testSimple() createPage(13))); assertQueueState(buffer, SECOND, 4, 10); assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 14, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 14, true)); - buffer.abort(SECOND); + buffer.destroy(SECOND); assertQueueClosed(buffer, FIRST, 14); assertQueueClosed(buffer, SECOND, 14); assertFinished(buffer); @@ -422,7 +427,7 @@ public void testAddQueueAfterCreation() .withNoMoreBufferIds(), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); assertThatThrownBy(() -> buffer.setOutputBuffers(createInitialEmptyOutputBuffers(BROADCAST) .withBuffer(FIRST, BROADCAST_PARTITION_ID) @@ -450,19 +455,19 @@ public void testAddAfterFinish() public void testAddQueueAfterNoMoreQueues() { BroadcastOutputBuffer buffer = createBroadcastBuffer(createInitialEmptyOutputBuffers(BROADCAST), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // tell buffer no more queues will be added buffer.setOutputBuffers(createInitialEmptyOutputBuffers(BROADCAST).withNoMoreBufferIds()); - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); // set no more queues a second time to assure that we don't get an exception or such buffer.setOutputBuffers(createInitialEmptyOutputBuffers(BROADCAST).withNoMoreBufferIds()); - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); // set no more queues a third time to assure that we don't get an exception or such buffer.setOutputBuffers(createInitialEmptyOutputBuffers(BROADCAST).withNoMoreBufferIds()); - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); } @Test @@ -483,7 +488,7 @@ public void testAddAfterDestroy() public void testGetBeforeCreate() { BroadcastOutputBuffer buffer = createBroadcastBuffer(createInitialEmptyOutputBuffers(BROADCAST), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // get a page from a buffer that doesn't exist yet ListenableFuture future = buffer.get(FIRST, 0L, sizeOfPages(1)); @@ -499,7 +504,7 @@ public void testGetBeforeCreate() public void testSetFinalBuffersWihtoutDeclaringUsedBuffer() { BroadcastOutputBuffer buffer = createBroadcastBuffer(createInitialEmptyOutputBuffers(BROADCAST), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // get a page from a buffer that doesn't exist yet ListenableFuture future = buffer.get(FIRST, 0L, sizeOfPages(1)); @@ -515,7 +520,7 @@ public void testSetFinalBuffersWihtoutDeclaringUsedBuffer() // acknowledge the page and verify we are finished assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 1, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 1, true)); - buffer.abort(FIRST); + buffer.destroy(FIRST); // set final buffers to a set that does not contain the buffer, which will fail assertThatThrownBy(() -> buffer.setOutputBuffers(createInitialEmptyOutputBuffers(BROADCAST).withNoMoreBufferIds())) @@ -531,7 +536,7 @@ public void testUseUndeclaredBufferAfterFinalBuffersSet() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // get a page from a buffer that was not declared, which will fail assertThatThrownBy(() -> buffer.get(SECOND, 0L, sizeOfPages(1))) @@ -543,14 +548,14 @@ public void testUseUndeclaredBufferAfterFinalBuffersSet() public void testAbortBeforeCreate() { BroadcastOutputBuffer buffer = createBroadcastBuffer(createInitialEmptyOutputBuffers(BROADCAST), sizeOfPages(2)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // get a page from a buffer that doesn't exist yet ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(1)); assertFalse(future.isDone()); - // abort that buffer, and verify the future is complete and buffer is finished - buffer.abort(FIRST); + // destroy that buffer, and verify the future is complete and buffer is finished + buffer.destroy(FIRST); assertTrue(future.isDone()); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); } @@ -625,12 +630,12 @@ public void testAbort() bufferedBuffer.setNoMorePages(); assertBufferResultEquals(TYPES, getBufferResult(bufferedBuffer, FIRST, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(0))); - bufferedBuffer.abort(FIRST); + bufferedBuffer.destroy(FIRST); assertQueueClosed(bufferedBuffer, FIRST, 0); assertBufferResultEquals(TYPES, getBufferResult(bufferedBuffer, FIRST, 1, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); assertBufferResultEquals(TYPES, getBufferResult(bufferedBuffer, SECOND, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(0))); - bufferedBuffer.abort(SECOND); + bufferedBuffer.destroy(SECOND); assertQueueClosed(bufferedBuffer, SECOND, 0); assertFinished(bufferedBuffer); assertBufferResultEquals(TYPES, getBufferResult(bufferedBuffer, SECOND, 1, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); @@ -652,8 +657,8 @@ public void testFinishClosesEmptyQueues() assertQueueState(buffer, FIRST, 0, 0); assertQueueState(buffer, SECOND, 0, 0); - buffer.abort(FIRST); - buffer.abort(SECOND); + buffer.destroy(FIRST); + buffer.destroy(SECOND); assertQueueClosed(buffer, FIRST, 0); assertQueueClosed(buffer, SECOND, 0); @@ -668,7 +673,7 @@ public void testAbortFreesReader() .withBuffer(SECOND, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -687,8 +692,8 @@ public void testAbortFreesReader() future = buffer.get(FIRST, 1, sizeOfPages(10)); assertFalse(future.isDone()); - // abort the buffer - buffer.abort(FIRST); + // destroy the buffer + buffer.destroy(FIRST); // verify the future completed // broadcast buffer does not return a "complete" result in this case, but it doesn't mapper @@ -706,7 +711,7 @@ public void testFinishFreesReader() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -740,7 +745,7 @@ public void testFinishFreesWriter() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -761,7 +766,7 @@ public void testFinishFreesWriter() // finish the query buffer.setNoMorePages(); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // verify futures are complete assertFutureIsDone(firstEnqueuePage); @@ -772,7 +777,7 @@ public void testFinishFreesWriter() bufferResult(1, createPage(1), createPage(2), createPage(3), createPage(4), createPage(5), createPage(6))); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 7, sizeOfPages(100), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 7, true)); - buffer.abort(FIRST); + buffer.destroy(FIRST); // verify finished assertFinished(buffer); @@ -786,7 +791,7 @@ public void testDestroyFreesReader() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -820,7 +825,7 @@ public void testDestroyFreesWriter() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -856,7 +861,7 @@ public void testFailDoesNotFreeReader() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -874,8 +879,8 @@ public void testFailDoesNotFreeReader() future = buffer.get(FIRST, 1, sizeOfPages(10)); assertFalse(future.isDone()); - // fail the buffer - buffer.fail(); + // abort the buffer + buffer.abort(); // future should have not finished assertFalse(future.isDone()); @@ -893,7 +898,7 @@ public void testFailFreesWriter() .withBuffer(FIRST, BROADCAST_PARTITION_ID) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -912,9 +917,9 @@ public void testFailFreesWriter() assertFalse(firstEnqueuePage.isDone()); assertFalse(secondEnqueuePage.isDone()); - // fail the buffer (i.e., cancel the query) - buffer.fail(); - assertFalse(buffer.isFinished()); + // abort the buffer (i.e., fail the query) + buffer.abort(); + assertEquals(buffer.getState(), ABORTED); // verify the futures are completed assertFutureIsDone(firstEnqueuePage); @@ -927,7 +932,7 @@ public void testAddBufferAfterFail() OutputBuffers outputBuffers = createInitialEmptyOutputBuffers(BROADCAST) .withBuffer(FIRST, BROADCAST_PARTITION_ID); BroadcastOutputBuffer buffer = createBroadcastBuffer(outputBuffers, sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), OPEN); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -941,8 +946,8 @@ public void testAddBufferAfterFail() // verify we got one page assertBufferResultEquals(TYPES, getFuture(future, NO_WAIT), bufferResult(0, createPage(0))); - // fail the buffer - buffer.fail(); + // abort the buffer + buffer.abort(); // add a buffer outputBuffers = outputBuffers.withBuffer(SECOND, BROADCAST_PARTITION_ID); @@ -974,7 +979,7 @@ public void testBufferCompletion() .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer List pages = new ArrayList<>(); @@ -989,17 +994,14 @@ public void testBufferCompletion() // get and acknowledge 5 pages assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(5), MAX_WAIT), createBufferResult(TASK_INSTANCE_ID, 0, pages)); - // buffer is not finished - assertFalse(buffer.isFinished()); - // there are no more pages and no more buffers, but buffer is not finished because it didn't receive an acknowledgement yet - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // ask the buffer to finish - buffer.abort(FIRST); + buffer.destroy(FIRST); // verify that the buffer is finished - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); } @Test @@ -1143,7 +1145,7 @@ private BroadcastOutputBuffer createBroadcastBuffer(OutputBuffers outputBuffers, { BroadcastOutputBuffer buffer = new BroadcastOutputBuffer( TASK_INSTANCE_ID, - new StateMachine<>("bufferState", stateNotificationExecutor, OPEN, TERMINAL_BUFFER_STATES), + new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), stateNotificationExecutor), dataSize, () -> memoryContext.newLocalMemoryContext("test"), notificationExecutor, @@ -1170,14 +1172,14 @@ public void testBufferFinishesWhenClientBuffersDestroyed() } // the buffer is in the NO_MORE_BUFFERS state now - // and if we abort all the buffers it should destroy itself + // and if we destroy all the buffers it should destroy itself // and move to the FINISHED state - buffer.abort(FIRST); - assertFalse(buffer.isFinished()); - buffer.abort(SECOND); - assertFalse(buffer.isFinished()); - buffer.abort(THIRD); - assertTrue(buffer.isFinished()); + buffer.destroy(FIRST); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); + buffer.destroy(SECOND); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); + buffer.destroy(THIRD); + assertEquals(buffer.getState(), FINISHED); } @Test @@ -1209,7 +1211,7 @@ private BroadcastOutputBuffer createBroadcastBuffer(OutputBuffers outputBuffers, { BroadcastOutputBuffer buffer = new BroadcastOutputBuffer( TASK_INSTANCE_ID, - new StateMachine<>("bufferState", stateNotificationExecutor, OPEN, TERMINAL_BUFFER_STATES), + new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), stateNotificationExecutor), dataSize, () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), stateNotificationExecutor, diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestPartitionedOutputBuffer.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestPartitionedOutputBuffer.java index dc9f5c713bc6..eb1cabd71a1f 100644 --- a/core/trino-main/src/test/java/io/trino/execution/buffer/TestPartitionedOutputBuffer.java +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestPartitionedOutputBuffer.java @@ -16,10 +16,12 @@ import com.google.common.collect.ImmutableList; import com.google.common.util.concurrent.ListenableFuture; import io.airlift.units.DataSize; -import io.trino.execution.StateMachine; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; import io.trino.execution.buffer.OutputBuffers.OutputBufferId; import io.trino.memory.context.SimpleLocalMemoryContext; import io.trino.spi.Page; +import io.trino.spi.QueryId; import io.trino.spi.type.BigintType; import org.testng.annotations.AfterClass; import org.testng.annotations.BeforeClass; @@ -31,8 +33,10 @@ import static io.airlift.concurrent.Threads.daemonThreadsNamed; import static io.trino.execution.buffer.BufferResult.emptyResults; -import static io.trino.execution.buffer.BufferState.OPEN; -import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; +import static io.trino.execution.buffer.BufferState.ABORTED; +import static io.trino.execution.buffer.BufferState.FINISHED; +import static io.trino.execution.buffer.BufferState.FLUSHING; +import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; import static io.trino.execution.buffer.BufferTestUtils.MAX_WAIT; import static io.trino.execution.buffer.BufferTestUtils.NO_WAIT; import static io.trino.execution.buffer.BufferTestUtils.acknowledgeBufferResult; @@ -192,20 +196,20 @@ public void testSimplePartitioned() // // finish the buffer - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); buffer.setNoMorePages(); assertQueueState(buffer, FIRST, 12, 4); assertQueueState(buffer, SECOND, 0, 10); - buffer.abort(SECOND); + buffer.destroy(SECOND); assertQueueClosed(buffer, SECOND, 10); // not fully finished until all pages are consumed - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // remove a page, not finished assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 5, sizeOfPages(1), NO_WAIT), bufferResult(5, createPage(5))); assertQueueState(buffer, FIRST, 11, 5); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // remove all remaining pages from first queue, should not be finished BufferResult x = getBufferResult(buffer, FIRST, 6, sizeOfPages(30), NO_WAIT); @@ -224,7 +228,7 @@ public void testSimplePartitioned() assertQueueState(buffer, FIRST, 10, 6); // acknowledge all pages from the first partition, should transition to finished state assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 16, sizeOfPages(10), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 16, true)); - buffer.abort(FIRST); + buffer.destroy(FIRST); assertQueueClosed(buffer, FIRST, 16); assertFinished(buffer); } @@ -321,7 +325,7 @@ public void testAddQueueAfterCreation() .withNoMoreBufferIds(), sizeOfPages(10)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); assertThatThrownBy(() -> buffer.setOutputBuffers(createInitialEmptyOutputBuffers(PARTITIONED) .withBuffer(FIRST, 0) @@ -430,12 +434,12 @@ public void testAbort() buffer.setNoMorePages(); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(0))); - buffer.abort(FIRST); + buffer.destroy(FIRST); assertQueueClosed(buffer, FIRST, 0); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 1, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 0, sizeOfPages(1), NO_WAIT), bufferResult(0, createPage(0))); - buffer.abort(SECOND); + buffer.destroy(SECOND); assertQueueClosed(buffer, SECOND, 0); assertFinished(buffer); assertBufferResultEquals(TYPES, getBufferResult(buffer, SECOND, 1, sizeOfPages(1), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 0, true)); @@ -457,8 +461,8 @@ public void testFinishClosesEmptyQueues() assertQueueState(buffer, FIRST, 0, 0); assertQueueState(buffer, SECOND, 0, 0); - buffer.abort(FIRST); - buffer.abort(SECOND); + buffer.destroy(FIRST); + buffer.destroy(SECOND); assertQueueClosed(buffer, FIRST, 0); assertQueueClosed(buffer, SECOND, 0); @@ -472,7 +476,7 @@ public void testAbortFreesReader() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -491,8 +495,8 @@ public void testAbortFreesReader() future = buffer.get(FIRST, 1, sizeOfPages(10)); assertFalse(future.isDone()); - // abort the buffer - buffer.abort(FIRST); + // destroy the buffer + buffer.destroy(FIRST); // verify the future completed // partitioned buffer does not return a "complete" result in this case, but it doesn't matter @@ -510,7 +514,7 @@ public void testFinishFreesReader() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -544,7 +548,7 @@ public void testFinishFreesWriter() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -565,7 +569,7 @@ public void testFinishFreesWriter() // finish the query buffer.setNoMorePages(); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // verify futures are complete assertFutureIsDone(firstEnqueuePage); @@ -576,7 +580,7 @@ public void testFinishFreesWriter() bufferResult(1, createPage(1), createPage(2), createPage(3), createPage(4), createPage(5), createPage(6))); assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 7, sizeOfPages(100), NO_WAIT), emptyResults(TASK_INSTANCE_ID, 7, true)); - buffer.abort(FIRST); + buffer.destroy(FIRST); // verify finished assertFinished(buffer); @@ -590,7 +594,7 @@ public void testDestroyFreesReader() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -624,7 +628,7 @@ public void testDestroyFreesWriter() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -660,7 +664,7 @@ public void testFailDoesNotFreeReader() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // attempt to get a page ListenableFuture future = buffer.get(FIRST, 0, sizeOfPages(10)); @@ -678,8 +682,8 @@ public void testFailDoesNotFreeReader() future = buffer.get(FIRST, 1, sizeOfPages(10)); assertFalse(future.isDone()); - // fail the buffer - buffer.fail(); + // abort the buffer + buffer.abort(); // future should have not finished assertFalse(future.isDone()); @@ -697,7 +701,7 @@ public void testFailFreesWriter() .withBuffer(FIRST, 0) .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer for (int i = 0; i < 5; i++) { @@ -716,9 +720,9 @@ public void testFailFreesWriter() assertFalse(firstEnqueuePage.isDone()); assertFalse(secondEnqueuePage.isDone()); - // fail the buffer (i.e., cancel the query) - buffer.fail(); - assertFalse(buffer.isFinished()); + // abort the buffer (i.e., fail the query) + buffer.abort(); + assertEquals(buffer.getState(), ABORTED); // verify the futures are completed assertFutureIsDone(firstEnqueuePage); @@ -734,7 +738,7 @@ public void testBufferCompletion() .withNoMoreBufferIds(), sizeOfPages(5)); - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); // fill the buffer List pages = new ArrayList<>(); @@ -749,17 +753,14 @@ public void testBufferCompletion() // get and acknowledge 5 pages assertBufferResultEquals(TYPES, getBufferResult(buffer, FIRST, 0, sizeOfPages(5), MAX_WAIT), createBufferResult(TASK_INSTANCE_ID, 0, pages)); - // buffer is not finished - assertFalse(buffer.isFinished()); - // there are no more pages and no more buffers, but buffer is not finished because it didn't receive an acknowledgement yet - assertFalse(buffer.isFinished()); + assertEquals(buffer.getState(), FLUSHING); // ask the buffer to finish - buffer.abort(FIRST); + buffer.destroy(FIRST); // verify that the buffer is finished - assertTrue(buffer.isFinished()); + assertEquals(buffer.getState(), FINISHED); } @Test @@ -780,14 +781,14 @@ public void testBufferFinishesWhenClientBuffersDestroyed() } // the buffer is in the NO_MORE_BUFFERS state now - // and if we abort all the buffers it should destroy itself + // and if we destroy all the buffers it should destroy itself // and move to the FINISHED state - buffer.abort(FIRST); - assertFalse(buffer.isFinished()); - buffer.abort(SECOND); - assertFalse(buffer.isFinished()); - buffer.abort(THIRD); - assertTrue(buffer.isFinished()); + buffer.destroy(FIRST); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); + buffer.destroy(SECOND); + assertEquals(buffer.getState(), NO_MORE_BUFFERS); + buffer.destroy(THIRD); + assertEquals(buffer.getState(), FINISHED); } @Test @@ -830,7 +831,7 @@ private PartitionedOutputBuffer createPartitionedBuffer(OutputBuffers buffers, D { return new PartitionedOutputBuffer( TASK_INSTANCE_ID, - new StateMachine<>("bufferState", stateNotificationExecutor, OPEN, TERMINAL_BUFFER_STATES), + new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), stateNotificationExecutor), buffers, dataSize, () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), diff --git a/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java b/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java new file mode 100644 index 000000000000..cf5373b90daa --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/execution/buffer/TestSpoolingExchangeOutputBuffer.java @@ -0,0 +1,403 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.execution.buffer; + +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ListMultimap; +import com.google.common.util.concurrent.ListenableFuture; +import io.airlift.slice.Slice; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.memory.context.LocalMemoryContext; +import io.trino.spi.QueryId; +import io.trino.spi.exchange.ExchangeSink; +import io.trino.spi.exchange.ExchangeSinkInstanceHandle; +import org.testng.annotations.Test; + +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static com.google.common.util.concurrent.Futures.immediateVoidFuture; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.execution.buffer.BufferState.ABORTED; +import static io.trino.execution.buffer.BufferState.FAILED; +import static io.trino.execution.buffer.BufferState.FINISHED; +import static io.trino.execution.buffer.BufferState.FLUSHING; +import static io.trino.execution.buffer.BufferState.NO_MORE_BUFFERS; +import static io.trino.execution.buffer.OutputBuffers.createSpoolingExchangeOutputBuffers; +import static java.util.Objects.requireNonNull; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestSpoolingExchangeOutputBuffer +{ + @Test + public void testIsFull() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + assertNotBlocked(outputBuffer.isFull()); + + CompletableFuture blocked = new CompletableFuture<>(); + exchangeSink.setBlocked(blocked); + + ListenableFuture full = outputBuffer.isFull(); + assertBlocked(full); + + blocked.complete(null); + assertNotBlocked(full); + } + + @Test + public void testFinishSuccess() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.setNoMorePages(); + // call it for the second time to verify that the buffer handles it correctly + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + + finish.complete(null); + assertEquals(outputBuffer.getState(), FINISHED); + } + + @Test + public void testFinishFailure() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.setNoMorePages(); + // call it for the second time to verify that the buffer handles it correctly + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + + RuntimeException failure = new RuntimeException("failure"); + finish.completeExceptionally(failure); + assertEquals(outputBuffer.getState(), FAILED); + assertEquals(outputBuffer.getFailureCause(), Optional.of(failure)); + } + + @Test + public void testDestroyAfterFinishCompletion() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.setNoMorePages(); + // call it for the second time to verify that the buffer handles it correctly + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + + finish.complete(null); + assertEquals(outputBuffer.getState(), FINISHED); + + outputBuffer.destroy(); + assertEquals(outputBuffer.getState(), FINISHED); + } + + @Test + public void testDestroyBeforeFinishCompletion() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + + outputBuffer.destroy(); + assertEquals(outputBuffer.getState(), ABORTED); + + finish.complete(null); + assertEquals(outputBuffer.getState(), ABORTED); + } + + @Test + public void testAbortBeforeNoMorePages() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.abort(); + assertEquals(outputBuffer.getState(), ABORTED); + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), ABORTED); + } + + @Test + public void testAbortBeforeFinishCompletion() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + CompletableFuture abort = new CompletableFuture<>(); + exchangeSink.setAbort(abort); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.setNoMorePages(); + // call it for the second time to verify that the buffer handles it correctly + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + + // if abort is called before finish completes it should abort the buffer + outputBuffer.abort(); + assertEquals(outputBuffer.getState(), ABORTED); + + // abort failure shouldn't impact the buffer state + abort.completeExceptionally(new RuntimeException("failure")); + assertEquals(outputBuffer.getState(), ABORTED); + } + + @Test + public void testAbortAfterFinishCompletion() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + CompletableFuture abort = new CompletableFuture<>(); + exchangeSink.setAbort(abort); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.setNoMorePages(); + // call it for the second time to verify that the buffer handles it correctly + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + + finish.complete(null); + assertEquals(outputBuffer.getState(), FINISHED); + + // abort is no op + outputBuffer.abort(); + assertEquals(outputBuffer.getState(), FINISHED); + + // abort success doesn't change the buffer state + abort.complete(null); + assertEquals(outputBuffer.getState(), FINISHED); + } + + @Test + public void testEnqueueAfterFinish() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture finish = new CompletableFuture<>(); + exchangeSink.setFinish(finish); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page1"))); + outputBuffer.enqueue(1, ImmutableList.of(utf8Slice("page2"), utf8Slice("page3"))); + + ImmutableListMultimap expectedDataBufferState = ImmutableListMultimap.builder() + .put(0, utf8Slice("page1")) + .put(1, utf8Slice("page2")) + .put(1, utf8Slice("page3")) + .build(); + + assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); + + outputBuffer.setNoMorePages(); + assertEquals(outputBuffer.getState(), FLUSHING); + // the buffer is flushing, this page is expected to be rejected + outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page4"))); + assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); + + finish.complete(null); + assertEquals(outputBuffer.getState(), FINISHED); + outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page5"))); + assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); + } + + @Test + public void testEnqueueAfterAbort() + { + TestingExchangeSink exchangeSink = new TestingExchangeSink(); + CompletableFuture abort = new CompletableFuture<>(); + exchangeSink.setAbort(abort); + + OutputBuffer outputBuffer = createSpoolingExchangeOutputBuffer(exchangeSink); + assertEquals(outputBuffer.getState(), NO_MORE_BUFFERS); + + outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page1"))); + outputBuffer.enqueue(1, ImmutableList.of(utf8Slice("page2"), utf8Slice("page3"))); + + ImmutableListMultimap expectedDataBufferState = ImmutableListMultimap.builder() + .put(0, utf8Slice("page1")) + .put(1, utf8Slice("page2")) + .put(1, utf8Slice("page3")) + .build(); + + assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); + + outputBuffer.abort(); + assertEquals(outputBuffer.getState(), ABORTED); + // the buffer is flushing, this page is expected to be rejected + outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page4"))); + assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); + + abort.complete(null); + assertEquals(outputBuffer.getState(), ABORTED); + outputBuffer.enqueue(0, ImmutableList.of(utf8Slice("page5"))); + assertEquals(exchangeSink.getDataBuffer(), expectedDataBufferState); + } + + private static SpoolingExchangeOutputBuffer createSpoolingExchangeOutputBuffer(ExchangeSink exchangeSink) + { + return new SpoolingExchangeOutputBuffer( + new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), directExecutor()), + createSpoolingExchangeOutputBuffers(TestingExchangeSinkInstanceHandle.INSTANCE), + exchangeSink, + TestingLocalMemoryContext::new); + } + + private static void assertNotBlocked(ListenableFuture blocked) + { + assertTrue(blocked.isDone()); + } + + private static void assertBlocked(ListenableFuture blocked) + { + assertFalse(blocked.isDone()); + } + + private static class TestingExchangeSink + implements ExchangeSink + { + private final ListMultimap dataBuffer = ArrayListMultimap.create(); + private CompletableFuture blocked = CompletableFuture.completedFuture(null); + private CompletableFuture finish = CompletableFuture.completedFuture(null); + private CompletableFuture abort = CompletableFuture.completedFuture(null); + + private boolean finishCalled; + private boolean abortCalled; + + @Override + public CompletableFuture isBlocked() + { + return blocked; + } + + public void setBlocked(CompletableFuture blocked) + { + this.blocked = requireNonNull(blocked, "blocked is null"); + } + + @Override + public void add(int partitionId, Slice data) + { + this.dataBuffer.put(partitionId, data); + } + + public ListMultimap getDataBuffer() + { + return dataBuffer; + } + + @Override + + public long getMemoryUsage() + { + return 0; + } + + @Override + public CompletableFuture finish() + { + assertFalse(abortCalled); + assertFalse(finishCalled); + finishCalled = true; + return finish; + } + + public void setFinish(CompletableFuture finish) + { + this.finish = requireNonNull(finish, "finish is null"); + } + + @Override + public CompletableFuture abort() + { + assertFalse(abortCalled); + abortCalled = true; + return abort; + } + + public void setAbort(CompletableFuture abort) + { + this.abort = requireNonNull(abort, "abort is null"); + } + } + + private enum TestingExchangeSinkInstanceHandle + implements ExchangeSinkInstanceHandle + { + INSTANCE + } + + private static class TestingLocalMemoryContext + implements LocalMemoryContext + { + @Override + public long getBytes() + { + return 0; + } + + @Override + public ListenableFuture setBytes(long bytes) + { + return immediateVoidFuture(); + } + + @Override + public boolean trySetBytes(long bytes) + { + return true; + } + + @Override + public void close() + { + } + } +} diff --git a/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java b/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java index a8c78741b5ef..d8385509eae2 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/BenchmarkPartitionedOutputOperator.java @@ -17,8 +17,9 @@ import io.airlift.slice.Slice; import io.airlift.units.DataSize; import io.trino.Session; -import io.trino.execution.StateMachine; -import io.trino.execution.buffer.BufferState; +import io.trino.execution.StageId; +import io.trino.execution.TaskId; +import io.trino.execution.buffer.OutputBufferStateMachine; import io.trino.execution.buffer.OutputBuffers; import io.trino.execution.buffer.PagesSerdeFactory; import io.trino.execution.buffer.PartitionedOutputBuffer; @@ -34,6 +35,7 @@ import io.trino.operator.TaskContext; import io.trino.operator.TrinoOperatorFactories; import io.trino.spi.Page; +import io.trino.spi.QueryId; import io.trino.spi.block.Block; import io.trino.spi.block.RunLengthEncodedBlock; import io.trino.spi.block.TestingBlockEncodingSerde; @@ -85,8 +87,6 @@ import static io.trino.block.BlockAssertions.createRLEBlock; import static io.trino.block.BlockAssertions.createRandomBlockForType; import static io.trino.block.BlockAssertions.createRandomLongsBlock; -import static io.trino.execution.buffer.BufferState.OPEN; -import static io.trino.execution.buffer.BufferState.TERMINAL_BUFFER_STATES; import static io.trino.execution.buffer.OutputBuffers.BufferType.PARTITIONED; import static io.trino.execution.buffer.OutputBuffers.createInitialEmptyOutputBuffers; import static io.trino.memory.context.AggregatedMemoryContext.newSimpleAggregatedMemoryContext; @@ -461,7 +461,7 @@ private TestingPartitionedOutputBuffer createPartitionedBuffer(OutputBuffers buf { return new TestingPartitionedOutputBuffer( "task-instance-id", - new StateMachine<>("bufferState", SCHEDULER, OPEN, TERMINAL_BUFFER_STATES), + new OutputBufferStateMachine(new TaskId(new StageId(new QueryId("query"), 0), 0, 0), SCHEDULER), buffers, dataSize, () -> new SimpleLocalMemoryContext(newSimpleAggregatedMemoryContext(), "test"), @@ -476,14 +476,14 @@ private static class TestingPartitionedOutputBuffer public TestingPartitionedOutputBuffer( String taskInstanceId, - StateMachine state, + OutputBufferStateMachine stateMachine, OutputBuffers outputBuffers, DataSize maxBufferSize, Supplier memoryContextSupplier, Executor notificationExecutor, Blackhole blackhole) { - super(taskInstanceId, state, outputBuffers, maxBufferSize, memoryContextSupplier, notificationExecutor); + super(taskInstanceId, stateMachine, outputBuffers, maxBufferSize, memoryContextSupplier, notificationExecutor); this.blackhole = blackhole; } diff --git a/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java b/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java index 9a25fc63703c..0e66e10aebf7 100644 --- a/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java +++ b/core/trino-main/src/test/java/io/trino/operator/output/TestPartitionedOutputOperator.java @@ -596,9 +596,9 @@ public OutputBufferInfo getInfo() } @Override - public boolean isFinished() + public BufferState getState() { - return false; + return BufferState.NO_MORE_BUFFERS; } @Override @@ -635,7 +635,7 @@ public void acknowledge(OutputBuffers.OutputBufferId bufferId, long token) } @Override - public void abort(OutputBuffers.OutputBufferId bufferId) + public void destroy(OutputBuffers.OutputBufferId bufferId) { } @@ -661,7 +661,7 @@ public void destroy() } @Override - public void fail() + public void abort() { } @@ -670,6 +670,12 @@ public long getPeakMemoryUsage() { return 0; } + + @Override + public Optional getFailureCause() + { + return Optional.empty(); + } } private static class SumModuloPartitionFunction diff --git a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java index 29a2a881f9f8..22c211385da0 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java +++ b/core/trino-spi/src/main/java/io/trino/spi/exchange/ExchangeSink.java @@ -35,6 +35,8 @@ public interface ExchangeSink * Appends arbitrary {@code data} to a partition specified by {@code partitionId}. * The engine is free to reuse the {@code data} buffer. * The implementation is expected to copy the buffer as it may be invalidated and recycled. + * If this method is invoked after {@link #finish()} or {@link #abort()} is initiated the + * invocation should be ignored. */ void add(int partitionId, Slice data); @@ -45,12 +47,21 @@ public interface ExchangeSink long getMemoryUsage(); /** - * Notifies the exchange sink that no more data will be appended + * Notifies the exchange sink that no more data will be appended. + * This method is guaranteed not to be called after {@link #abort()}. + * This method is guaranteed not be called more than once. + * + * @return future that will be resolved when the finish operation either succeeds or fails */ - void finish(); + CompletableFuture finish(); /** - * Notifies the exchange that the write operation has been aborted + * Notifies the exchange that the write operation has been aborted. + * This method may be called when {@link #finish()} is still running. In this situation the implementation + * is free to either cancel the finish operation and abort or let the finish operation succeed. + * This method is guaranteed not be called more than once. + * + * @return future that will be resolved when the abort operation either succeeds or fails */ - void abort(); + CompletableFuture abort(); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java index 084322cbb6f2..fd922af0688c 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestExchangeManager.java @@ -37,6 +37,7 @@ import java.util.function.Function; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.airlift.concurrent.MoreFutures.getFutureValue; import static io.trino.spi.exchange.ExchangeId.createRandomExchangeId; import static org.assertj.core.api.Assertions.assertThat; import static org.testng.Assert.assertTrue; @@ -165,10 +166,10 @@ private void writeData(ExchangeSinkInstanceHandle handle, Multimap