From ed0c1d7e6df3357344574b3ef2dccb526ed3f9b7 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Tue, 10 Nov 2015 22:00:27 -0800 Subject: [PATCH 1/5] [SPARK-11617] [network] Fix leak in TransportFrameDecoder. The code was using the wrong API to add data to the internal composite buffer, causing buffers to leak in certain situations. Use the right API and enhance the tests to catch memory leaks. --- .../network/util/TransportFrameDecoder.java | 2 +- .../util/TransportFrameDecoderSuite.java | 41 +++++++++++++++---- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 272ea84e6180..a1c0d56c5424 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -56,7 +56,7 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception buffer = in.alloc().compositeBuffer(); } - buffer.writeBytes(in); + buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); while (buffer.isReadable()) { feedInterceptor(); diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index ca74f0a00cf9..2317ae69dca1 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -24,6 +24,8 @@ import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import static org.junit.Assert.*; import static org.mockito.Mockito.*; @@ -33,7 
+35,7 @@ public class TransportFrameDecoderSuite { public void testFrameDecoding() throws Exception { Random rnd = new Random(); TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelHandlerContext ctx = mockChannelHandlerContext(); final int frameCount = 100; ByteBuf data = Unpooled.buffer(); @@ -46,12 +48,15 @@ public void testFrameDecoding() throws Exception { while (data.isReadable()) { int size = rnd.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size))); + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); } verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); } finally { - data.release(); + release(data); } } @@ -60,7 +65,7 @@ public void testInterception() throws Exception { final int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + ChannelHandlerContext ctx = mockChannelHandlerContext(); byte[] data = new byte[8]; ByteBuf len = Unpooled.copyLong(8 + data.length); @@ -70,16 +75,18 @@ public void testInterception() throws Exception { decoder.setInterceptor(interceptor); for (int i = 0; i < interceptedReads; i++) { decoder.channelRead(ctx, dataBuf); - dataBuf.release(); + assertEquals(0, dataBuf.refCnt()); dataBuf = Unpooled.wrappedBuffer(data); } decoder.channelRead(ctx, len); decoder.channelRead(ctx, dataBuf); verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class)); verify(ctx).fireChannelRead(any(ByteBuffer.class)); + assertEquals(0, len.refCnt()); + assertEquals(0, dataBuf.refCnt()); } finally { - len.release(); - dataBuf.release(); + release(len); + 
release(dataBuf); } } @@ -111,6 +118,26 @@ private void testInvalidFrame(long size) throws Exception { } } + private ChannelHandlerContext mockChannelHandlerContext() { + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.readerIndex(buf.readerIndex() + buf.readableBytes()); + buf.release(); + return null; + } + }); + return ctx; + } + + private void release(ByteBuf buf) { + if (buf.refCnt() > 0) { + buf.release(buf.refCnt()); + } + } + private static class MockInterceptor implements TransportFrameDecoder.Interceptor { private int remainingReads; From 8a7d19496c8c61f0e3e2c5180857f1773ba10372 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 12 Nov 2015 13:39:28 -0800 Subject: [PATCH 2/5] Copy bytes when handlers keep references to data. This makes the frame decoder behave more like netty's ByteToMessageDecoder, at the expense of copying some data in a few cases. 
--- .../network/util/TransportFrameDecoder.java | 45 ++++++++++++------- .../util/TransportFrameDecoderSuite.java | 35 ++++++++++++--- 2 files changed, 59 insertions(+), 21 deletions(-) diff --git a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index a1c0d56c5424..5889562dd970 100644 --- a/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/network/common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -59,29 +59,40 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes()); while (buffer.isReadable()) { - feedInterceptor(); - if (interceptor != null) { - continue; - } + discardReadBytes(); + if (!feedInterceptor()) { + ByteBuf frame = decodeNext(); + if (frame == null) { + break; + } - ByteBuf frame = decodeNext(); - if (frame != null) { ctx.fireChannelRead(frame); - } else { - break; } } - // We can't discard read sub-buffers if there are other references to the buffer (e.g. - // through slices used for framing). This assumes that code that retains references - // will call retain() from the thread that called "fireChannelRead()" above, otherwise - // ref counting will go awry. - if (buffer != null && buffer.refCnt() == 1) { + discardReadBytes(); + } + + private void discardReadBytes() { + // If the buffer's been retained by downstream code, then make a copy of the remaining + // bytes into a new buffer. Otherwise, just discard stale components. 
+ if (buffer.refCnt() > 1) { + CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer(); + + if (buffer.readableBytes() > 0) { + ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes()); + spillBuf.writeBytes(buffer); + newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes()); + } + + buffer.release(); + buffer = newBuffer; + } else { buffer.discardReadComponents(); } } - protected ByteBuf decodeNext() throws Exception { + private ByteBuf decodeNext() throws Exception { if (buffer.readableBytes() < LENGTH_SIZE) { return null; } @@ -127,10 +138,14 @@ public void setInterceptor(Interceptor interceptor) { this.interceptor = interceptor; } - private void feedInterceptor() throws Exception { + /** + * @return Whether the interceptor is still active after processing the data. + */ + private boolean feedInterceptor() throws Exception { if (interceptor != null && !interceptor.handle(buffer)) { interceptor = null; } + return interceptor != null; } public static interface Interceptor { diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 2317ae69dca1..cc824871229e 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -19,6 +19,7 @@ import java.nio.ByteBuffer; import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; @@ -35,7 +36,7 @@ public class TransportFrameDecoderSuite { public void testFrameDecoding() throws Exception { Random rnd = new Random(); TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(); + ChannelHandlerContext ctx = mockChannelHandlerContext(true); final int frameCount = 100; ByteBuf data 
= Unpooled.buffer(); @@ -65,7 +66,7 @@ public void testInterception() throws Exception { final int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mockChannelHandlerContext(); + ChannelHandlerContext ctx = mockChannelHandlerContext(true); byte[] data = new byte[8]; ByteBuf len = Unpooled.copyLong(8 + data.length); @@ -90,6 +91,27 @@ public void testInterception() throws Exception { } } + @Test + public void testRetainedByteBuf() throws Exception { + TransportFrameDecoder decoder = new TransportFrameDecoder(); + ChannelHandlerContext ctx = mockChannelHandlerContext(false); + + byte[] frame = new byte[1024]; + ByteBuf data = Unpooled.buffer(); + data.writeLong(frame.length + 8); + data.writeBytes(frame); + + try { + decoder.channelRead(ctx, data); + // Because the mock context is not releasing the buffer slice passed to it, the frame + // decoder should not clear the read data from its internal composite buffer, so there + // should still be a reference to the original buffer. 
+ assertEquals(1, data.refCnt()); + } finally { + release(data); + } + } + @Test(expected = IllegalArgumentException.class) public void testNegativeFrameSize() throws Exception { testInvalidFrame(-1); @@ -118,14 +140,15 @@ private void testInvalidFrame(long size) throws Exception { } } - private ChannelHandlerContext mockChannelHandlerContext() { + private ChannelHandlerContext mockChannelHandlerContext(final boolean releaseBuffer) { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { @Override public Void answer(InvocationOnMock in) { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.readerIndex(buf.readerIndex() + buf.readableBytes()); - buf.release(); + if (releaseBuffer) { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); + } return null; } }); From 180456ce3d25a126a74fa0f57c22665ef002b6cd Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Fri, 13 Nov 2015 16:41:19 -0800 Subject: [PATCH 3/5] Remove stray import. --- .../apache/spark/network/util/TransportFrameDecoderSuite.java | 1 - 1 file changed, 1 deletion(-) diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index cc824871229e..9cc7e6cc5dc0 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -19,7 +19,6 @@ import java.nio.ByteBuffer; import java.util.Random; -import java.util.concurrent.atomic.AtomicReference; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; From dcdfc31bfa7461c16a5a3cd17ae66cc999697664 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 16 Nov 2015 12:51:06 -0800 Subject: [PATCH 4/5] Better test for retained frames. 
This test actually fails with java.lang.IndexOutOfBoundsException if the fix in this patch set is disabled. --- .../util/TransportFrameDecoderSuite.java | 94 +++++++++++++------ 1 file changed, 67 insertions(+), 27 deletions(-) diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 9cc7e6cc5dc0..cabcbd34db86 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -18,11 +18,15 @@ package org.apache.spark.network.util; import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; import io.netty.buffer.ByteBuf; import io.netty.buffer.Unpooled; import io.netty.channel.ChannelHandlerContext; +import org.junit.AfterClass; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; @@ -31,23 +35,23 @@ public class TransportFrameDecoderSuite { + private static Random RND = new Random(); + + @AfterClass + public static void cleanup() { + RND = null; + } + @Test public void testFrameDecoding() throws Exception { - Random rnd = new Random(); TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(true); + ChannelHandlerContext ctx = mockChannelHandlerContext(); final int frameCount = 100; - ByteBuf data = Unpooled.buffer(); + ByteBuf data = createInputBuffer(frameCount); try { - for (int i = 0; i < frameCount; i++) { - byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)]; - data.writeLong(frame.length + 8); - data.writeBytes(frame); - } - while (data.isReadable()) { - int size = rnd.nextInt(16 * 1024) + 256; + int size = RND.nextInt(16 * 1024) + 256; decoder.channelRead(ctx, 
data.readSlice(Math.min(data.readableBytes(), size)).retain()); } @@ -65,7 +69,7 @@ public void testInterception() throws Exception { final int interceptedReads = 3; TransportFrameDecoder decoder = new TransportFrameDecoder(); TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads)); - ChannelHandlerContext ctx = mockChannelHandlerContext(true); + ChannelHandlerContext ctx = mockChannelHandlerContext(); byte[] data = new byte[8]; ByteBuf len = Unpooled.copyLong(8 + data.length); @@ -91,22 +95,50 @@ public void testInterception() throws Exception { } @Test - public void testRetainedByteBuf() throws Exception { + public void testRetainedFrames() throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); - ChannelHandlerContext ctx = mockChannelHandlerContext(false); - byte[] frame = new byte[1024]; - ByteBuf data = Unpooled.buffer(); - data.writeLong(frame.length + 8); - data.writeBytes(frame); + final AtomicInteger count = new AtomicInteger(); + final List retained = new ArrayList<>(); + + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) { + // Retain a few frames but not others. + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + if (count.incrementAndGet() % 2 == 0) { + retained.add(buf); + } else { + buf.release(); + } + return null; + } + }); + final int frameCount = 100; + ByteBuf data = createInputBuffer(frameCount); try { - decoder.channelRead(ctx, data); - // Because the mock context is not releasing the buffer slice passed to it, the frame - // decoder should not clear the read data from its internal composite buffer, so there - // should still be a reference to the original buffer. 
- assertEquals(1, data.refCnt()); + while (data.isReadable()) { + int size = RND.nextInt(16 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + + // Verify all retained buffers are readable. + for (ByteBuf b : retained) { + byte[] tmp = new byte[b.readableBytes()]; + b.readBytes(tmp); + b.release(); + } + + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); } finally { + for (ByteBuf b : retained) { + release(b); + } release(data); } } @@ -128,6 +160,16 @@ public void testLargeFrame() throws Exception { testInvalidFrame(Integer.MAX_VALUE + 9); } + private ByteBuf createInputBuffer(int frameCount) throws Exception { + ByteBuf data = Unpooled.buffer(); + for (int i = 0; i < frameCount; i++) { + byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; + data.writeLong(frame.length + 8); + data.writeBytes(frame); + } + return data; + } + private void testInvalidFrame(long size) throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); @@ -139,15 +181,13 @@ private void testInvalidFrame(long size) throws Exception { } } - private ChannelHandlerContext mockChannelHandlerContext(final boolean releaseBuffer) { + private ChannelHandlerContext mockChannelHandlerContext() { ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); when(ctx.fireChannelRead(any())).thenAnswer(new Answer() { @Override public Void answer(InvocationOnMock in) { - if (releaseBuffer) { - ByteBuf buf = (ByteBuf) in.getArguments()[0]; - buf.release(); - } + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + buf.release(); return null; } }); From 7fe96174edde015c8c3bc43ca734e809c5041192 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Mon, 16 Nov 2015 13:28:57 -0800 Subject: [PATCH 5/5] Refactor some common test code. 
--- .../util/TransportFrameDecoderSuite.java | 66 ++++++++++--------- 1 file changed, 36 insertions(+), 30 deletions(-) diff --git a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index cabcbd34db86..19475c21ffce 100644 --- a/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/network/common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -46,22 +46,8 @@ public static void cleanup() { public void testFrameDecoding() throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); ChannelHandlerContext ctx = mockChannelHandlerContext(); - - final int frameCount = 100; - ByteBuf data = createInputBuffer(frameCount); - try { - while (data.isReadable()) { - int size = RND.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - - decoder.channelInactive(ctx); - assertTrue("There shouldn't be dangling references to the data.", data.release()); - } finally { - release(data); - } + ByteBuf data = createAndFeedFrames(100, decoder, ctx); + verifyAndCloseDecoder(decoder, ctx, data); } @Test @@ -116,30 +102,19 @@ public Void answer(InvocationOnMock in) { } }); - final int frameCount = 100; - ByteBuf data = createInputBuffer(frameCount); + ByteBuf data = createAndFeedFrames(100, decoder, ctx); try { - while (data.isReadable()) { - int size = RND.nextInt(16 * 1024) + 256; - decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); - } - - verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); - // Verify all retained buffers are readable. 
for (ByteBuf b : retained) { byte[] tmp = new byte[b.readableBytes()]; b.readBytes(tmp); b.release(); } - - decoder.channelInactive(ctx); - assertTrue("There shouldn't be dangling references to the data.", data.release()); + verifyAndCloseDecoder(decoder, ctx, data); } finally { for (ByteBuf b : retained) { release(b); } - release(data); } } @@ -160,16 +135,47 @@ public void testLargeFrame() throws Exception { testInvalidFrame(Integer.MAX_VALUE + 9); } - private ByteBuf createInputBuffer(int frameCount) throws Exception { + /** + * Creates a number of randomly sized frames and feeds them to the given decoder, verifying + * that the frames were read. + */ + private ByteBuf createAndFeedFrames( + int frameCount, + TransportFrameDecoder decoder, + ChannelHandlerContext ctx) throws Exception { ByteBuf data = Unpooled.buffer(); for (int i = 0; i < frameCount; i++) { byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)]; data.writeLong(frame.length + 8); data.writeBytes(frame); } + + try { + while (data.isReadable()) { + int size = RND.nextInt(4 * 1024) + 256; + decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain()); + } + + verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class)); + } catch (Exception e) { + release(data); + throw e; + } return data; } + private void verifyAndCloseDecoder( + TransportFrameDecoder decoder, + ChannelHandlerContext ctx, + ByteBuf data) throws Exception { + try { + decoder.channelInactive(ctx); + assertTrue("There shouldn't be dangling references to the data.", data.release()); + } finally { + release(data); + } + } + private void testInvalidFrame(long size) throws Exception { TransportFrameDecoder decoder = new TransportFrameDecoder(); ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);