diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 272ea84e6180..5889562dd970 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -56,32 +56,43 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception buffer = in.alloc().compositeBuffer(); } - buffer.writeBytes(in); + buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); while (buffer.isReadable()) { - feedInterceptor(); - if (interceptor != null) { - continue; - } + discardReadBytes(); + if (!feedInterceptor()) { + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } - ByteBuf frame = decodeNext(); - if (frame != null) { ctx.fireChannelRead(frame); - } else { - break; } } - // We can't discard read sub-buffers if there are other references to the buffer (e.g. - // through slices used for framing). This assumes that code that retains references - // will call retain() from the thread that called "fireChannelRead()" above, otherwise - // ref counting will go awry. - if (buffer != null && buffer.refCnt() == 1) { + discardReadBytes(); + } + + private void discardReadBytes() { + // If the buffer's been retained by downstream code, then make a copy of the remaining + // bytes into a new buffer. Otherwise, just discard stale components. + if (buffer.refCnt() > 1) { + CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + + if (buffer.readableBytes() > 0) { + ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); + spillBuf.writeBytes(buffer); + newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + } + + buffer.release(); + buffer = newBuffer; + } else { buffer.discardReadComponents(); } } - protected ByteBuf decodeNext() throws Exception { + private ByteBuf decodeNext() throws Exception { if (buffer.readableBytes() < LENGTH_SIZE) { return null; } @@ -127,10 +138,14 @@ public void setInterceptor(Interceptor interceptor) { this.interceptor = interceptor; } - private void feedInterceptor() throws Exception { + /** + * @return Whether the interceptor is still active after processing the data. + */ + private boolean feedInterceptor() throws Exception { if (interceptor != null && !interceptor.handle(buffer)) { interceptor = null; } + return interceptor != null; } public static interface Interceptor { diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index ca74f0a00cf9..19475c21ffce 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -18,41 +18,36 @@ package org.apache.spark.network.util; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; +import org.junit.AfterClass; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; public class TransportFrameDecoderSuite { + private static Random RND = new Random(); + + @AfterClass + public static void cleanup() { + RND = null; + } + @Test public void testFrameDecoding() throws Exception { - Random rnd = new Random(); TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); - - final int frameCount = 100; - ByteBuf data = Unpooled.buffer(); - try { - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)]; - data.writeLong(frame.length + 8); - data.writeBytes(frame); - } - - while (data.isReadable()) { - int size = rnd.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size))); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - } finally { - data.release(); - } + ChannelHandlerContext ctx = mockChannelHandlerContext(); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); } @Test @@ -60,7 +55,7 @@ public void testInterception() throws Exception { final int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelHandlerContext ctx = mockChannelHandlerContext(); byte[] data = new byte[8]; ByteBuf len = Unpooled.copyLong(8 + data.length); @@ -70,16 +65,56 @@ public void testInterception() throws Exception { decoder.setInterceptor(interceptor); for (int i = 0; i < interceptedReads; i++) { decoder.channelRead(ctx, dataBuf); - dataBuf.release(); + assertEquals(0, dataBuf.refCnt()); dataBuf = Unpooled.wrappedBuffer(data); } decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); verify(ctx).fireChannelRead(any(ByteBuffer.class)); + assertEquals(0, len.refCnt()); + assertEquals(0, dataBuf.refCnt()); } finally { - len.release(); - dataBuf.release(); + release(len); + release(dataBuf); + } + } + + @Test + public void testRetainedFrames() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + + final AtomicInteger count = new AtomicInteger(); + final List retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); + } + return null; + } + }); + + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + try { + // Verify all retained buffers are readable. + for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + verifyAndCloseDecoder(decoder, ctx, data); + } finally { + for (ByteBuf b : retained) { + release(b); + } } } @@ -100,6 +135,47 @@ public void testLargeFrame() throws Exception { testInvalidFrame(Integer.MAX_VALUE + 9); } + /** + * Creates a number of randomly sized frames and feed them to the given decoder, verifying + * that the frames were read. + */ + private ByteBuf createAndFeedFrames( + int frameCount, + TransportFrameDecoder decoder, + ChannelHandlerContext ctx) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; + } + return data; + } + + private void verifyAndCloseDecoder( + TransportFrameDecoder decoder, + ChannelHandlerContext ctx, + ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); + } finally { + release(data); + } + } + private void testInvalidFrame(long size) throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); @@ -111,6 +187,25 @@ private void testInvalidFrame(long size) throws Exception { } } + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + return null; + } + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); + } + } + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { private int remainingReads;