@@ -56,32 +56,43 @@ public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception
buffer = in.alloc().compositeBuffer();
}

buffer.writeBytes(in);
buffer.addComponent(in).writerIndex(buffer.writerIndex() + in.readableBytes());

while (buffer.isReadable()) {
feedInterceptor();
if (interceptor != null) {
continue;
}
discardReadBytes();
if (!feedInterceptor()) {
ByteBuf frame = decodeNext();
if (frame == null) {
break;
}

ByteBuf frame = decodeNext();
if (frame != null) {
ctx.fireChannelRead(frame);
} else {
break;
}
}

// We can't discard read sub-buffers if there are other references to the buffer (e.g.
// through slices used for framing). This assumes that code that retains references
// will call retain() from the thread that called "fireChannelRead()" above, otherwise
// ref counting will go awry.
if (buffer != null && buffer.refCnt() == 1) {
discardReadBytes();
}

private void discardReadBytes() {
// If the buffer's been retained by downstream code, then make a copy of the remaining
// bytes into a new buffer. Otherwise, just discard stale components.
if (buffer.refCnt() > 1) {
Member:
@vanzin could you give a real case? Or is this just for correctness, even if downstream code in Spark doesn't use retain()?

Contributor Author:
Actually, Spark does use retain() when fetching shuffle blocks, and for some reason that causes problems. I think the real problem is somewhere in the Netty code, but this is the workaround that Netty itself uses (see ByteToMessageDecoder).

Member:
Yeah. Just saw retain() in ChunkFetchSuccess.

CompositeByteBuf newBuffer = buffer.alloc().compositeBuffer();

if (buffer.readableBytes() > 0) {
ByteBuf spillBuf = buffer.alloc().buffer(buffer.readableBytes());
spillBuf.writeBytes(buffer);
newBuffer.addComponent(spillBuf).writerIndex(spillBuf.readableBytes());
}

buffer.release();
buffer = newBuffer;
} else {
buffer.discardReadComponents();
}
}
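For context on the retain() discussion in the review thread above: the sketch below is hypothetical and not part of this change, but it shows the kind of downstream consumer that keeps a reference to a decoded frame, which is what pushes the composite buffer's refCnt() above 1 and forces discardReadBytes() to copy rather than discard. The pattern loosely mirrors the ChunkFetchSuccess case mentioned in the comments; RetainingFrameHandler and its queue are made-up names, while SimpleChannelInboundHandler, retain() and release() are standard Netty API.

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

// Hypothetical downstream handler. SimpleChannelInboundHandler releases the frame once
// channelRead0 returns, so the handler retains the payload to hand it to another thread.
// While that extra reference is alive, the decoder cannot discard its read components.
public class RetainingFrameHandler extends SimpleChannelInboundHandler<ByteBuf> {

  private final Queue<ByteBuf> pending = new ConcurrentLinkedQueue<>();

  @Override
  protected void channelRead0(ChannelHandlerContext ctx, ByteBuf frame) {
    // retain() runs on the event loop thread, i.e. the thread that called
    // fireChannelRead(), matching the assumption documented in channelRead() above.
    pending.add(frame.retain());
  }

  // Called later, possibly from a different thread, once the payload has been consumed.
  public void releaseOne() {
    ByteBuf frame = pending.poll();
    if (frame != null) {
      frame.release();
    }
  }
}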

protected ByteBuf decodeNext() throws Exception {
private ByteBuf decodeNext() throws Exception {
if (buffer.readableBytes() < LENGTH_SIZE) {
return null;
}
@@ -127,10 +138,14 @@ public void setInterceptor(Interceptor interceptor) {
this.interceptor = interceptor;
}

private void feedInterceptor() throws Exception {
/**
* @return Whether the interceptor is still active after processing the data.
*/
private boolean feedInterceptor() throws Exception {
if (interceptor != null && !interceptor.handle(buffer)) {
interceptor = null;
}
return interceptor != null;
}

public static interface Interceptor {
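To make the feedInterceptor() contract above concrete, here is a hedged sketch of an Interceptor implementation that drains a fixed number of bytes from the stream and then steps aside. Only boolean handle(ByteBuf) is inferred from the code shown (returning false causes the decoder to drop the interceptor and resume frame decoding); the class name and constructor are made up, and the sketch is declared abstract so it compiles without spelling out any other callbacks the real interface may declare.

import io.netty.buffer.ByteBuf;

// Hypothetical interceptor: consumes byteCount bytes that should bypass frame decoding
// (for example, a raw stream being forwarded elsewhere) and then deactivates itself.
// Declared abstract so this sketch does not have to list other Interceptor callbacks.
public abstract class FixedLengthInterceptor implements TransportFrameDecoder.Interceptor {

  private long remaining;

  public FixedLengthInterceptor(long byteCount) {
    this.remaining = byteCount;
  }

  @Override
  public boolean handle(ByteBuf data) {
    int toRead = (int) Math.min(remaining, data.readableBytes());
    // Advance the reader index so the decoder never treats these bytes as frame data.
    data.skipBytes(toRead);
    remaining -= toRead;
    // false tells feedInterceptor() the interceptor is no longer active.
    return remaining > 0;
  }
}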
@@ -18,49 +18,44 @@
package org.apache.spark.network.util;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import org.junit.AfterClass;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.junit.Assert.*;
import static org.mockito.Mockito.*;

public class TransportFrameDecoderSuite {

private static Random RND = new Random();

@AfterClass
public static void cleanup() {
RND = null;
}

@Test
public void testFrameDecoding() throws Exception {
Random rnd = new Random();
TransportFrameDecoder decoder = new TransportFrameDecoder();
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);

final int frameCount = 100;
ByteBuf data = Unpooled.buffer();
try {
for (int i = 0; i < frameCount; i++) {
byte[] frame = new byte[1024 * (rnd.nextInt(31) + 1)];
data.writeLong(frame.length + 8);
data.writeBytes(frame);
}

while (data.isReadable()) {
int size = rnd.nextInt(16 * 1024) + 256;
decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)));
}

verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
} finally {
data.release();
}
ChannelHandlerContext ctx = mockChannelHandlerContext();
ByteBuf data = createAndFeedFrames(100, decoder, ctx);
verifyAndCloseDecoder(decoder, ctx, data);
}

@Test
public void testInterception() throws Exception {
final int interceptedReads = 3;
TransportFrameDecoder decoder = new TransportFrameDecoder();
TransportFrameDecoder.Interceptor interceptor = spy(new MockInterceptor(interceptedReads));
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
ChannelHandlerContext ctx = mockChannelHandlerContext();

byte[] data = new byte[8];
ByteBuf len = Unpooled.copyLong(8 + data.length);
@@ -70,16 +65,56 @@ public void testInterception() throws Exception {
decoder.setInterceptor(interceptor);
for (int i = 0; i < interceptedReads; i++) {
decoder.channelRead(ctx, dataBuf);
dataBuf.release();
assertEquals(0, dataBuf.refCnt());
dataBuf = Unpooled.wrappedBuffer(data);
}
decoder.channelRead(ctx, len);
decoder.channelRead(ctx, dataBuf);
verify(interceptor, times(interceptedReads)).handle(any(ByteBuf.class));
verify(ctx).fireChannelRead(any(ByteBuffer.class));
assertEquals(0, len.refCnt());
assertEquals(0, dataBuf.refCnt());
} finally {
len.release();
dataBuf.release();
release(len);
release(dataBuf);
}
}

@Test
public void testRetainedFrames() throws Exception {
TransportFrameDecoder decoder = new TransportFrameDecoder();

final AtomicInteger count = new AtomicInteger();
final List<ByteBuf> retained = new ArrayList<>();

ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) {
// Retain a few frames but not others.
ByteBuf buf = (ByteBuf) in.getArguments()[0];
if (count.incrementAndGet() % 2 == 0) {
retained.add(buf);
} else {
buf.release();
}
return null;
}
});

ByteBuf data = createAndFeedFrames(100, decoder, ctx);
try {
// Verify all retained buffers are readable.
for (ByteBuf b : retained) {
byte[] tmp = new byte[b.readableBytes()];
b.readBytes(tmp);
b.release();
}
verifyAndCloseDecoder(decoder, ctx, data);
} finally {
for (ByteBuf b : retained) {
release(b);
}
}
}

Expand All @@ -100,6 +135,47 @@ public void testLargeFrame() throws Exception {
testInvalidFrame(Integer.MAX_VALUE + 9);
}

/**
* Creates a number of randomly sized frames and feeds them to the given decoder, verifying
* that the frames were read.
*/
private ByteBuf createAndFeedFrames(
int frameCount,
TransportFrameDecoder decoder,
ChannelHandlerContext ctx) throws Exception {
ByteBuf data = Unpooled.buffer();
for (int i = 0; i < frameCount; i++) {
byte[] frame = new byte[1024 * (RND.nextInt(31) + 1)];
data.writeLong(frame.length + 8);
data.writeBytes(frame);
}

try {
while (data.isReadable()) {
int size = RND.nextInt(4 * 1024) + 256;
decoder.channelRead(ctx, data.readSlice(Math.min(data.readableBytes(), size)).retain());
}

verify(ctx, times(frameCount)).fireChannelRead(any(ByteBuf.class));
} catch (Exception e) {
release(data);
throw e;
}
return data;
}

private void verifyAndCloseDecoder(
TransportFrameDecoder decoder,
ChannelHandlerContext ctx,
ByteBuf data) throws Exception {
try {
decoder.channelInactive(ctx);
assertTrue("There shouldn't be dangling references to the data.", data.release());
} finally {
release(data);
}
}

private void testInvalidFrame(long size) throws Exception {
TransportFrameDecoder decoder = new TransportFrameDecoder();
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
Expand All @@ -111,6 +187,25 @@ private void testInvalidFrame(long size) throws Exception {
}
}

private ChannelHandlerContext mockChannelHandlerContext() {
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
when(ctx.fireChannelRead(any())).thenAnswer(new Answer<Void>() {
@Override
public Void answer(InvocationOnMock in) {
ByteBuf buf = (ByteBuf) in.getArguments()[0];
buf.release();
return null;
}
});
return ctx;
}

private void release(ByteBuf buf) {
if (buf.refCnt() > 0) {
buf.release(buf.refCnt());
}
}

private static class MockInterceptor implements TransportFrameDecoder.Interceptor {

private int remainingReads;
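Finally, as an illustration of the frame format these tests rely on (an 8-byte length prefix that counts itself, followed by the payload), here is a hedged sketch that pushes one frame through the decoder using Netty's EmbeddedChannel. It is not part of the PR; the class and main method are made up, and it assumes the decoder's no-arg constructor used in the tests above and that emitted frames carry only the payload bytes.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;

// Standalone sketch: frame one payload, feed it to the decoder, and read the result back.
public class FrameDecoderExample {

  public static void main(String[] args) {
    EmbeddedChannel channel = new EmbeddedChannel(new TransportFrameDecoder());

    byte[] payload = new byte[] { 1, 2, 3, 4 };
    ByteBuf in = Unpooled.buffer();
    in.writeLong(payload.length + 8);  // frame length includes the 8-byte prefix itself
    in.writeBytes(payload);

    channel.writeInbound(in);          // decoder fires the decoded frame downstream

    ByteBuf frame = channel.readInbound();
    if (frame == null || frame.readableBytes() != payload.length) {
      throw new IllegalStateException("Unexpected decoder output: " + frame);
    }
    frame.release();
    channel.finish();
  }
}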