diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
index 67f64d7962035..0f6d019968ff8 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java
@@ -45,7 +45,7 @@ public class OneForOneStreamManager extends StreamManager {
   private final ConcurrentHashMap<Long, StreamState> streams;
 
   /** State of a single stream. */
-  private static class StreamState {
+  public static class StreamState {
     final String appId;
     final Iterator<ManagedBuffer> buffers;
 
@@ -59,11 +59,19 @@ private static class StreamState {
     // Used to keep track of the number of chunks being transferred and not finished yet.
     volatile long chunksBeingTransferred = 0L;
 
-    StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
+    public StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
       this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
       this.associatedChannel = channel;
     }
+
+    public String getAppId() {
+      return appId;
+    }
+
+    public Iterator<ManagedBuffer> getBuffers() {
+      return buffers;
+    }
   }
 
   public OneForOneStreamManager() {
@@ -208,4 +216,12 @@ public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channe
   public int numStreamStates() {
     return streams.size();
   }
+
+  /**
+   * Returns the {@code streams} map itself rather than a copy, so changes made by the caller
+   * are changes to this manager's state.
+   */
+  public ConcurrentHashMap<Long, StreamState> getStreams() {
+    return streams;
+  }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
index 8c05288fb4111..edfd3412aaf70 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
@@ -23,6 +23,7 @@
 import java.util.HashMap;
 import java.util.Iterator;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Function;
 
 import com.codahale.metrics.Gauge;
@@ -39,6 +40,7 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager.StreamState;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.server.StreamManager;
 import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
@@ -109,15 +111,29 @@ protected void handleMessage(
               numBlockIds += ids.length;
             }
           }
-          streamId = streamManager.registerStream(client.getClientId(),
-            new ShuffleManagedBufferIterator(msg), client.getChannel());
+          if (shouldRegisterStream(msg.appId, msg.execId, client)) {
+            streamId = streamManager.registerStream(msg.appId,
+              new ShuffleManagedBufferIterator(msg), client.getChannel());
+          } else {
+            Exception e = new RuntimeException("can not register stream since the app: "
+              + msg.appId + " has already terminated");
+            callback.onFailure(e);
+            return;
+          }
         } else {
           // For the compatibility with the old version, still keep the support for OpenBlocks.
           OpenBlocks msg = (OpenBlocks) msgObj;
           numBlockIds = msg.blockIds.length;
           checkAuth(client, msg.appId);
-          streamId = streamManager.registerStream(client.getClientId(),
-            new ManagedBufferIterator(msg), client.getChannel());
+          if (shouldRegisterStream(msg.appId, msg.execId, client)) {
+            streamId = streamManager.registerStream(msg.appId,
+              new ManagedBufferIterator(msg), client.getChannel());
+          } else {
+            Exception e = new RuntimeException("can not register stream since the app: "
+              + msg.appId + " has already terminated");
+            callback.onFailure(e);
+            return;
+          }
         }
         if (logger.isTraceEnabled()) {
           logger.trace(
@@ -180,6 +196,27 @@ public StreamManager getStreamManager() {
    * local directories associated with the executors of that application in a separate thread.
    */
   public void applicationRemoved(String appId, boolean cleanupLocalDirs) {
+    logger.info("Current numStreamStates size: {}", streamManager.numStreamStates());
+    ConcurrentHashMap<Long, StreamState> streams = streamManager.getStreams();
+    for (Map.Entry<Long, StreamState> entry: streams.entrySet()) {
+      StreamState state = entry.getValue();
+      // Compare from the argument side so a stream holding a null appId cannot throw here.
+      if (appId.equals(state.getAppId())) {
+        logger.warn("Found finished app: {} , " +
+          "but streamState is still in memory, clean it now", appId);
+        streams.remove(entry.getKey());
+
+        // Release all remaining buffers; skip null entries (e.g. a missing block).
+        Iterator<ManagedBuffer> remaining = state.getBuffers();
+        while (remaining.hasNext()) {
+          ManagedBuffer buffer = remaining.next();
+          if (buffer != null) {
+            buffer.release();
+          }
+        }
+      }
+    }
+
     blockManager.applicationRemoved(appId, cleanupLocalDirs);
   }
 
@@ -388,4 +425,20 @@ public void channelInactive(TransportClient client) {
     super.channelInactive(client);
   }
 
+  /**
+   * Decides whether a stream may be registered for the given executor. If the executor is not
+   * present in the block resolver's executor map, this logs a warning, closes the client's
+   * channel and returns false; otherwise it returns true.
+   */
+  private boolean shouldRegisterStream(String appId, String execId, TransportClient client) {
+    if (!blockManager.getExecutors().containsKey(new AppExecId(appId, execId))) {
+      logger.warn("the App {} with execId {} is not exist, so do not register stream",
+        appId, execId);
+      // NOTE(review): callers invoke callback.onFailure after this returns, i.e. after the
+      // channel has been closed - confirm the failure reply is not expected to reach the peer.
+      client.getChannel().close();
+      return false;
+    }
+    return true;
+  }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 657774c1b468f..1f0d59ac7a209 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -76,6 +76,11 @@ public class ExternalShuffleBlockResolver {
   @VisibleForTesting
   final ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors;
 
+  /** Returns the {@code executors} map itself (not a copy). */
+  public ConcurrentMap<AppExecId, ExecutorShuffleInfo> getExecutors() {
+    return executors;
+  }
+
   /**
    * Caches index file information so that we can avoid open/close the index files
    * for each block fetch.
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
index 455351fcf767c..4d3a531ff36a9 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java
@@ -18,10 +18,15 @@
 package org.apache.spark.network.shuffle;
 
 import java.nio.ByteBuffer;
+import java.util.ArrayList;
 import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.ConcurrentMap;
 
 import com.codahale.metrics.Meter;
 import com.codahale.metrics.Timer;
+import com.google.common.collect.Maps;
+import io.netty.channel.Channel;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
@@ -35,6 +40,7 @@
 import org.apache.spark.network.client.RpcResponseCallback;
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.server.OneForOneStreamManager;
+import org.apache.spark.network.server.OneForOneStreamManager.StreamState;
 import org.apache.spark.network.server.RpcHandler;
 import org.apache.spark.network.shuffle.protocol.BlockTransferMessage;
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo;
@@ -43,6 +49,7 @@
 import org.apache.spark.network.shuffle.protocol.RegisterExecutor;
 import org.apache.spark.network.shuffle.protocol.StreamHandle;
 import org.apache.spark.network.shuffle.protocol.UploadBlock;
+import org.apache.spark.network.shuffle.ExternalShuffleBlockResolver.AppExecId;
 
 public class ExternalBlockHandlerSuite {
   TransportClient client = mock(TransportClient.class);
@@ -161,6 +168,13 @@ public void testOpenDiskPersistedRDDBlocksWithMissingBlock() {
   private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] blockMarkers) {
     when(client.getClientId()).thenReturn("app0");
 
+    // add app info to executors
+    ExecutorShuffleInfo shuffleInfo = new ExecutorShuffleInfo(new String[] {"/a", "/b"},
+      16, "sort");
+    ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors = Maps.newConcurrentMap();
+    executors.put(new AppExecId("app0", "exec1"), shuffleInfo);
+    when(blockResolver.getExecutors()).thenReturn(executors);
+
     RpcResponseCallback callback = mock(RpcResponseCallback.class);
     handler.receive(client, msg.toByteBuffer(), callback);
 
@@ -222,4 +236,57 @@ public void testBadMessages() {
     verify(callback, never()).onSuccess(any(ByteBuffer.class));
     verify(callback, never()).onFailure(any(Throwable.class));
   }
+
+  @Test
+  public void testDoNotRegisterStreamWhenAppHasFinished() {
+    // add app info to executors
+    ExecutorShuffleInfo shuffleInfo = new ExecutorShuffleInfo(new String[] {"/a", "/b"},
+      16, "sort");
+    ConcurrentMap<AppExecId, ExecutorShuffleInfo> executors = Maps.newConcurrentMap();
+    executors.put(new AppExecId("app0", "exec1"), shuffleInfo);
+    when(blockResolver.getExecutors()).thenReturn(executors);
+
+    Channel dummyChannel = mock(Channel.class, RETURNS_SMART_NULLS);
+    when(client.getChannel()).thenReturn(dummyChannel);
+    when(client.getClientId()).thenReturn("app1");
+
+    // suppose app1's info has been cleaned
+    // for OpenBlocks msg case
+    OpenBlocks openBlocks = new OpenBlocks(
+      "app1", "exec1", new String[] { "shuffle_0_0_0", "shuffle_0_0_1" });
+    RpcResponseCallback callback = mock(RpcResponseCallback.class);
+    handler.receive(client, openBlocks.toByteBuffer(), callback);
+    verify(callback, times(1)).onFailure(any(Throwable.class));
+    assertEquals(0, streamManager.numStreamStates());
+
+    // for FetchShuffleBlocks msg case
+    FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks(
+      "app1", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, false);
+    handler.receive(client, fetchShuffleBlocks.toByteBuffer(), callback);
+    // The shared callback must now have been failed once per rejected request.
+    verify(callback, times(2)).onFailure(any(Throwable.class));
+    verify(callback, never()).onSuccess(any(ByteBuffer.class));
+    assertEquals(0, streamManager.numStreamStates());
+  }
+
+  @Test
+  public void testWhenApplicationRemovedCleanRelatedStreamState() {
+    OneForOneStreamManager oneForOneStreamManager = new OneForOneStreamManager();
+    ExternalBlockHandler externalBlockHandler = new ExternalBlockHandler(
+      oneForOneStreamManager,
+      blockResolver);
+    Channel dummyChannel = mock(Channel.class, RETURNS_SMART_NULLS);
+
+    List<ManagedBuffer> buffers = new ArrayList<>();
+    buffers.add(new NioManagedBuffer(ByteBuffer.wrap(new byte[3])));
+    buffers.add(new NioManagedBuffer(ByteBuffer.wrap(new byte[7])));
+    // A null entry stands in for a missing block; cleanup must skip it rather than throw.
+    buffers.add(null);
+    oneForOneStreamManager.getStreams().put(1L,
+      new StreamState("app0", buffers.iterator(), dummyChannel));
+    assertEquals(1, oneForOneStreamManager.numStreamStates());
+
+    externalBlockHandler.applicationRemoved("app0", false);
+    assertEquals(0, oneForOneStreamManager.numStreamStates());
+  }
 }