diff --git a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
index 42ffb7e6037ba..dbd918da7bfdc 100644
--- a/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
+++ b/flink-libraries/flink-state-processing-api/src/main/java/org/apache/flink/state/api/runtime/SavepointEnvironment.java
@@ -36,6 +36,7 @@
 import org.apache.flink.runtime.checkpoint.PrioritizedOperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionGraphID;
@@ -70,6 +71,8 @@
 import org.apache.flink.util.UserCodeClassLoader;
 import org.apache.flink.util.concurrent.Executors;
 
+import javax.annotation.Nullable;
+
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
@@ -119,6 +122,8 @@ public class SavepointEnvironment implements Environment {
 
     private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
 
+    @Nullable private ChannelStateWriter channelStateWriter;
+
     private SavepointEnvironment(
             RuntimeContext ctx,
             ExecutionConfig executionConfig,
@@ -440,4 +445,17 @@ public CompletableFuture sendRequestToCoordinator(
             return CompletableFuture.completedFuture(null);
         }
     }
+
+    /** Stores the given {@link ChannelStateWriter} for later retrieval. */
+    @Override
+    public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
+        this.channelStateWriter = channelStateWriter;
+    }
+
+    /** Returns the writer stored by the setter, or {@code null} if none was set. */
+    @Override
+    @Nullable
+    public ChannelStateWriter getChannelStateWriter() {
+        return this.channelStateWriter;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
index a4b203b5b9d4c..4b449a4d12950 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/execution/Environment.java
@@ -31,6 +31,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -280,4 +281,19 @@ default CheckpointStorageAccess getCheckpointStorageAccess() {
     }
 
     ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory();
+
+    /**
+     * Registers the {@link ChannelStateWriter} created by the invokable so that the owning task
+     * can close it when it releases its resources.
+     *
+     * @param channelStateWriter the writer to hand back from {@link #getChannelStateWriter()}
+     */
+    void setChannelStateWriter(ChannelStateWriter channelStateWriter);
+
+    /**
+     * Returns the writer registered via {@link #setChannelStateWriter(ChannelStateWriter)}.
+     *
+     * @return the registered writer, or {@code null} if none has been registered
+     */
+    ChannelStateWriter getChannelStateWriter();
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
index d5e297c54b0ee..69df0116802cd 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/RuntimeEnvironment.java
@@ -31,6 +31,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -118,6 +119,8 @@ public class RuntimeEnvironment implements Environment {
 
     ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
 
+    @Nullable private ChannelStateWriter channelStateWriter;
+
     // ------------------------------------------------------------------------
 
     public RuntimeEnvironment(
@@ -408,4 +411,17 @@ public CheckpointStorageAccess getCheckpointStorageAccess() {
     public ChannelStateWriteRequestExecutorFactory getChannelStateExecutorFactory() {
         return channelStateExecutorFactory;
     }
+
+    /** Registers the writer once; the {@code checkState} guard rejects a second registration. */
+    @Override
+    public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
+        checkState(this.channelStateWriter == null, "Cannot set channelStateWriter twice!");
+        this.channelStateWriter = channelStateWriter;
+    }
+
+    @Override
+    @Nullable
+    public ChannelStateWriter getChannelStateWriter() {
+        return this.channelStateWriter;
+    }
 }
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
index 1be391cd9856b..844820041b22c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/Task.java
@@ -42,6 +42,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointStoreUtil;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -311,6 +312,9 @@ public class Task
      */
     private UserCodeClassLoader userCodeClassLoader;
 
+    /** The channelStateWriter of the env. We obtain it after the invokable is initialized. */
+    @Nullable private volatile ChannelStateWriter channelStateWriter;
+
     /**
      * IMPORTANT: This constructor may not start any work that would need to be undone in the
      * case of a failing task deployment.
@@ -508,6 +512,12 @@ TaskInvokable getInvokable() {
         return invokable;
     }
 
+    @Nullable
+    @VisibleForTesting
+    ChannelStateWriter getChannelStateWriter() {
+        return channelStateWriter;
+    }
+
     public boolean isBackPressured() {
         if (invokable == null
                 || partitionWriters.length == 0
@@ -749,6 +759,10 @@ private void doRun() {
                 FlinkSecurityManager.unmonitorUserSystemExitForCurrentThread();
             }
 
+            // We register a reference to the channelStateWriter
+            // so we can close it after the inputGates close
+            this.channelStateWriter = env.getChannelStateWriter();
+
             // ----------------------------------------------------------------
             // actual task core work
             // ----------------------------------------------------------------
@@ -1011,6 +1025,16 @@ private void releaseResources() {
         }
         closeAllResultPartitions();
         closeAllInputGates();
+        if (this.channelStateWriter != null) {
+            LOG.debug("Closing channelStateWriter for task {}", taskNameWithSubtask);
+            try {
+                this.channelStateWriter.close();
+            } catch (Throwable t) {
+                ExceptionUtils.rethrowIfFatalError(t);
+                LOG.error(
+                        "Failed to close channelStateWriter for task {}.", taskNameWithSubtask, t);
+            }
+        }
 
         try {
             taskStateManager.close();
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index a7e02d458464f..505e67a30d78d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -522,6 +522,8 @@ protected StreamTask(
                                             CheckpointingOptions
                                                     .UNALIGNED_MAX_SUBTASKS_PER_CHANNEL_STATE_FILE))
                             : ChannelStateWriter.NO_OP;
+            environment.setChannelStateWriter(channelStateWriter);
+
             this.subtaskCheckpointCoordinator =
                     new SubtaskCheckpointCoordinatorImpl(
                             checkpointStorageAccess,
diff --git a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
index 24c63a7e5ab5b..811354c47f380 100644
--- a/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
+++ b/flink-runtime/src/main/java/org/apache/flink/streaming/runtime/tasks/SubtaskCheckpointCoordinatorImpl.java
@@ -579,7 +579,6 @@ public void cancel() throws IOException {
             }
         }
         IOUtils.closeAllQuietly(asyncCheckpointRunnables);
-        channelStateWriter.close();
     }
 
     @VisibleForTesting
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 0a6a425c4ac1d..fe291c1b4e5f3 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -32,6 +32,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -61,6 +62,8 @@
 import org.apache.flink.runtime.util.TestingUserCodeClassLoader;
 import org.apache.flink.util.UserCodeClassLoader;
 
+import javax.annotation.Nullable;
+
 import java.util.Collections;
 import java.util.Map;
 import java.util.concurrent.Future;
@@ -88,6 +91,8 @@ public class DummyEnvironment implements Environment {
 
     private CheckpointStorageAccess checkpointStorageAccess;
 
+    @Nullable private ChannelStateWriter channelStateWriter;
+
     public DummyEnvironment() {
         this("Test Job", 1, 0, 1);
     }
@@ -312,4 +317,15 @@ public void setCheckpointStorageAccess(CheckpointStorageAccess checkpointStorage
     public CheckpointStorageAccess getCheckpointStorageAccess() {
         return checkNotNull(checkpointStorageAccess);
     }
+
+    @Override
+    public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
+        this.channelStateWriter = channelStateWriter;
+    }
+
+    @Override
+    @Nullable
+    public ChannelStateWriter getChannelStateWriter() {
+        return this.channelStateWriter;
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 1b4a7bb817252..c50eface312b0 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -33,6 +33,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -65,6 +66,8 @@
 import org.apache.flink.util.UserCodeClassLoader;
 import org.apache.flink.util.concurrent.Executors;
 
+import javax.annotation.Nullable;
+
 import java.util.Collections;
 import java.util.LinkedList;
 import java.util.List;
@@ -149,6 +152,8 @@ public class MockEnvironment implements Environment, AutoCloseable {
 
     private final ChannelStateWriteRequestExecutorFactory channelStateExecutorFactory;
 
+    @Nullable private ChannelStateWriter channelStateWriter;
+
     public static MockEnvironmentBuilder builder() {
         return new MockEnvironmentBuilder();
     }
@@ -495,4 +500,15 @@ public Optional getActualExternalFailureCause() {
     public void setExternalFailureCauseConsumer(Consumer externalFailureCauseConsumer) {
         this.externalFailureCauseConsumer = Optional.of(externalFailureCauseConsumer);
     }
+
+    @Override
+    public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
+        this.channelStateWriter = channelStateWriter;
+    }
+
+    @Override
+    @Nullable
+    public ChannelStateWriter getChannelStateWriter() {
+        return this.channelStateWriter;
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
index 65bc95d4b0c97..36a248934e14b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskTest.java
@@ -26,6 +26,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointType;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.TaskDeploymentDescriptorFactory.ShuffleDescriptorAndIndex;
@@ -77,6 +78,7 @@
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicInteger;
 
 import static org.apache.flink.runtime.testutils.CommonTestUtils.waitUntilCondition;
@@ -1249,6 +1251,34 @@ public void testDeclineCheckpoint() throws Exception {
         assertEquals(ExecutionState.FINISHED, task.getTerminationFuture().getNow(null));
     }
 
+    private void testChannelStateWriterCloses(Class invokable)
+            throws Exception {
+        final Task task =
+                createTaskBuilder()
+                        .setInvokable(invokable)
+                        .setTaskManagerActions(new NoOpTaskManagerActions())
+                        .build(Executors.directExecutor());
+
+        task.startTaskThread();
+        awaitInvokableLatch(task);
+        ChannelStateWriterWithCloseTracker channelStateWriter =
+                (ChannelStateWriterWithCloseTracker) task.getChannelStateWriter();
+        assertFalse(channelStateWriter.isClosed());
+        triggerInvokableLatch(task);
+        task.getExecutingThread().join();
+        assertTrue(channelStateWriter.isClosed());
+    }
+
+    @Test
+    public void testChannelStateWriterClosesOnSuccess() throws Exception {
+        testChannelStateWriterCloses(ChannelStateWriterSetterInvokable.class);
+    }
+
+    @Test
+    public void testChannelStateWriterClosesOnFailure() throws Exception {
+        testChannelStateWriterCloses(FailingChannelStateWriterSetterInvokable.class);
+    }
+
     private void assertCheckpointDeclined(
             Task task,
             TestCheckpointResponder testCheckpointResponder,
@@ -1576,7 +1606,7 @@ public void cleanUp(Throwable throwable) throws Exception {
     }
 
     /** {@link AbstractInvokable} which throws {@link RuntimeException} on invoke. */
-    public static final class InvokableWithExceptionOnTrigger extends TriggerLatchInvokable {
+    public static class InvokableWithExceptionOnTrigger extends TriggerLatchInvokable {
         public InvokableWithExceptionOnTrigger(Environment environment) {
             super(environment);
         }
@@ -1762,4 +1792,34 @@ void awaitTriggerLatch() {
             }
         }
     }
+
+    private static class ChannelStateWriterWithCloseTracker
+            extends ChannelStateWriter.NoOpChannelStateWriter {
+        private final AtomicBoolean closeCalled = new AtomicBoolean(false);
+
+        @Override
+        public void close() {
+            closeCalled.set(true);
+        }
+
+        public boolean isClosed() {
+            return closeCalled.get();
+        }
+    }
+
+    private static class ChannelStateWriterSetterInvokable extends InvokableBlockingWithTrigger {
+
+        public ChannelStateWriterSetterInvokable(Environment environment) {
+            super(environment);
+            environment.setChannelStateWriter(new ChannelStateWriterWithCloseTracker());
+        }
+    }
+
+    private static class FailingChannelStateWriterSetterInvokable
+            extends InvokableWithExceptionOnTrigger {
+        public FailingChannelStateWriterSetterInvokable(Environment environment) {
+            super(environment);
+            environment.setChannelStateWriter(new ChannelStateWriterWithCloseTracker());
+        }
+    }
 }
diff --git a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
index 9fb6b73376275..dd4ef8994c2e5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamMockEnvironment.java
@@ -34,6 +34,7 @@
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriteRequestExecutorFactory;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.externalresource.ExternalResourceInfoProvider;
@@ -143,6 +144,8 @@ public class StreamMockEnvironment implements Environment {
 
     private CheckpointStorageAccess checkpointStorageAccess;
 
+    @Nullable private ChannelStateWriter channelStateWriter;
+
     public StreamMockEnvironment(
             Configuration jobConfig,
             Configuration taskConfig,
@@ -455,4 +458,15 @@ public void setCheckpointStorageAccess(CheckpointStorageAccess checkpointStorage
     public CheckpointStorageAccess getCheckpointStorageAccess() {
         return checkpointStorageAccess;
     }
+
+    @Override
+    public void setChannelStateWriter(ChannelStateWriter channelStateWriter) {
+        this.channelStateWriter = channelStateWriter;
+    }
+
+    @Override
+    @Nullable
+    public ChannelStateWriter getChannelStateWriter() {
+        return this.channelStateWriter;
+    }
 }