diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index df8c21fb837e..867bf93e75fc 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -56,7 +56,8 @@ class NettyBlockRpcServer( case openBlocks: OpenBlocks => val blocks: Seq[ManagedBuffer] = openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) - val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava, + client.getChannel) logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) diff --git a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index e671854da1ca..12919ea1bcbd 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -48,15 +48,16 @@ private static class StreamState { final Iterator buffers; // The channel associated to the stream - Channel associatedChannel = null; + final Channel associatedChannel; // Used to keep track of the index of the buffer that the user has retrieved, just to ensure // that the caller only requests each chunk one at a time, in order. int curChunk = 0; - StreamState(String appId, Iterator buffers) { + StreamState(String appId, Iterator buffers, Channel associatedChannel) { this.appId = appId; this.buffers = Preconditions.checkNotNull(buffers); + this.associatedChannel = associatedChannel; } } @@ -67,13 +68,6 @@ public OneForOneStreamManager() { streams = new ConcurrentHashMap(); } - @Override - public void registerChannel(Channel channel, long streamId) { - if (streams.containsKey(streamId)) { - streams.get(streamId).associatedChannel = channel; - } - } - @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); @@ -135,10 +129,9 @@ public void checkAuthorization(TransportClient client, long streamId) { * If an app ID is provided, only callers who've authenticated with the given app ID will be * allowed to fetch from this stream. */ - public long registerStream(String appId, Iterator buffers) { + public long registerStream(String appId, Iterator buffers, Channel channel) { long myStreamId = nextStreamId.getAndIncrement(); - streams.put(myStreamId, new StreamState(appId, buffers)); + streams.put(myStreamId, new StreamState(appId, buffers, channel)); return myStreamId; } - } diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java index 3f0155957a14..fed112da016c 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -59,16 +59,6 @@ public ManagedBuffer openStream(String streamId) { throw new UnsupportedOperationException(); } - /** - * Associates a stream with a single client connection, which is guaranteed to be the only reader - * of the stream. The getChunk() method will be called serially on this connection and once the - * connection is closed, the stream will never be used again, enabling cleanup. - * - * This must be called before the first getChunk() on the stream, but it may be invoked multiple - * times with the same channel and stream id. - */ - public void registerChannel(Channel channel, long streamId) { } - /** * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not * to read from the associated streams again, so any state can be cleaned up. diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index c864d7ce16bd..bcb0fda399e3 100644 --- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -117,7 +117,6 @@ private void processFetchRequest(final ChunkFetchRequest req) { ManagedBuffer buf; try { streamManager.checkAuthorization(reverseClient, req.streamChunkId.streamId); - streamManager.registerChannel(channel, req.streamChunkId.streamId); buf = streamManager.getChunk(req.streamChunkId.streamId, req.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format( diff --git a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index f22187a01db0..0173d90c0a01 100644 --- a/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -84,7 +84,7 @@ protected void handleMessage( for (String blockId : msg.blockIds) { blocks.add(blockManager.getBlockData(msg.appId, msg.execId, blockId)); } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator(), client.getChannel()); logger.trace("Registered streamId {} with {} buffers", streamId, msg.blockIds.length); callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); diff --git a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 9379412155e8..bf3e06f4c9be 100644 --- a/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/network/shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -20,6 +20,7 @@ import java.nio.ByteBuffer; import java.util.Iterator; +import io.netty.channel.Channel; import org.junit.Before; import org.junit.Test; import org.mockito.ArgumentCaptor; @@ -94,7 +95,8 @@ public void testOpenShuffleBlocks() { @SuppressWarnings("unchecked") ArgumentCaptor> stream = (ArgumentCaptor>) (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); - verify(streamManager, times(1)).registerStream(anyString(), stream.capture()); + verify(streamManager, times(1)).registerStream(anyString(), stream.capture(), + (Channel)any()); Iterator buffers = stream.getValue(); assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next());