diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java index 8e73ab077a5c..1980361a1552 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportFrameDecoder.java @@ -19,6 +19,7 @@ import java.util.LinkedList; +import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.netty.buffer.ByteBuf; import io.netty.buffer.CompositeByteBuf; @@ -48,14 +49,30 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter { private static final int LENGTH_SIZE = 8; private static final int MAX_FRAME_SIZE = Integer.MAX_VALUE; private static final int UNKNOWN_FRAME_SIZE = -1; + private static final long CONSOLIDATE_THRESHOLD = 20 * 1024 * 1024; private final LinkedList buffers = new LinkedList<>(); private final ByteBuf frameLenBuf = Unpooled.buffer(LENGTH_SIZE, LENGTH_SIZE); + private final long consolidateThreshold; + + private CompositeByteBuf frameBuf = null; + private long consolidatedFrameBufSize = 0; + private int consolidatedNumComponents = 0; private long totalSize = 0; private long nextFrameSize = UNKNOWN_FRAME_SIZE; + private int frameRemainingBytes = UNKNOWN_FRAME_SIZE; private volatile Interceptor interceptor; + public TransportFrameDecoder() { + this(CONSOLIDATE_THRESHOLD); + } + + @VisibleForTesting + TransportFrameDecoder(long consolidateThreshold) { + this.consolidateThreshold = consolidateThreshold; + } + @Override public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception { ByteBuf in = (ByteBuf) data; @@ -123,30 +140,56 @@ private long decodeFrameSize() { private ByteBuf decodeNext() { long frameSize = decodeFrameSize(); - if (frameSize == UNKNOWN_FRAME_SIZE || totalSize < frameSize) { + if (frameSize == UNKNOWN_FRAME_SIZE) { return null; } - // Reset size for next frame. - nextFrameSize = UNKNOWN_FRAME_SIZE; - - Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, "Too large frame: %s", frameSize); - Preconditions.checkArgument(frameSize > 0, "Frame length should be positive: %s", frameSize); + if (frameBuf == null) { + Preconditions.checkArgument(frameSize < MAX_FRAME_SIZE, + "Too large frame: %s", frameSize); + Preconditions.checkArgument(frameSize > 0, + "Frame length should be positive: %s", frameSize); + frameRemainingBytes = (int) frameSize; - // If the first buffer holds the entire frame, return it. - int remaining = (int) frameSize; - if (buffers.getFirst().readableBytes() >= remaining) { - return nextBufferForFrame(remaining); + // If buffers is empty, then return immediately for more input data. + if (buffers.isEmpty()) { + return null; + } + // Otherwise, if the first buffer holds the entire frame, we attempt to + // build frame with it and return. + if (buffers.getFirst().readableBytes() >= frameRemainingBytes) { + // Reset buf and size for next frame. + frameBuf = null; + nextFrameSize = UNKNOWN_FRAME_SIZE; + return nextBufferForFrame(frameRemainingBytes); + } + // Other cases, create a composite buffer to manage all the buffers. + frameBuf = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE); } - // Otherwise, create a composite buffer. - CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE); - while (remaining > 0) { - ByteBuf next = nextBufferForFrame(remaining); - remaining -= next.readableBytes(); - frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes()); + while (frameRemainingBytes > 0 && !buffers.isEmpty()) { + ByteBuf next = nextBufferForFrame(frameRemainingBytes); + frameRemainingBytes -= next.readableBytes(); + frameBuf.addComponent(true, next); } - assert remaining == 0; + // If the delta size of frameBuf exceeds the threshold, then we do consolidation + // to reduce memory consumption. + if (frameBuf.capacity() - consolidatedFrameBufSize > consolidateThreshold) { + int newNumComponents = frameBuf.numComponents() - consolidatedNumComponents; + frameBuf.consolidate(consolidatedNumComponents, newNumComponents); + consolidatedFrameBufSize = frameBuf.capacity(); + consolidatedNumComponents = frameBuf.numComponents(); + } + if (frameRemainingBytes > 0) { + return null; + } + + // Reset buf and size for next frame. + ByteBuf frame = frameBuf; + frameBuf = null; + consolidatedFrameBufSize = 0; + consolidatedNumComponents = 0; + nextFrameSize = UNKNOWN_FRAME_SIZE; return frame; } diff --git a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java index 7d40387c5f1a..4b67aa80351d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/util/TransportFrameDecoderSuite.java @@ -27,11 +27,15 @@ import io.netty.channel.ChannelHandlerContext; import org.junit.AfterClass; import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import static org.junit.Assert.*; import static org.mockito.Mockito.*; public class TransportFrameDecoderSuite { + private static final Logger logger = LoggerFactory.getLogger(TransportFrameDecoderSuite.class); private static Random RND = new Random(); @AfterClass @@ -47,6 +51,69 @@ public void testFrameDecoding() throws Exception { verifyAndCloseDecoder(decoder, ctx, data); } + @Test + public void testConsolidationPerf() throws Exception { + long[] testingConsolidateThresholds = new long[] { + ByteUnit.MiB.toBytes(1), + ByteUnit.MiB.toBytes(5), + ByteUnit.MiB.toBytes(10), + ByteUnit.MiB.toBytes(20), + ByteUnit.MiB.toBytes(30), + ByteUnit.MiB.toBytes(50), + ByteUnit.MiB.toBytes(80), + ByteUnit.MiB.toBytes(100), + ByteUnit.MiB.toBytes(300), + ByteUnit.MiB.toBytes(500), + Long.MAX_VALUE }; + for (long threshold : testingConsolidateThresholds) { + TransportFrameDecoder decoder = new TransportFrameDecoder(threshold); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + List retained = new ArrayList<>(); + when(ctx.fireChannelRead(any())).thenAnswer(in -> { + ByteBuf buf = (ByteBuf) in.getArguments()[0]; + retained.add(buf); + return null; + }); + + // Testing multiple messages + int numMessages = 3; + long targetBytes = ByteUnit.MiB.toBytes(300); + int pieceBytes = (int) ByteUnit.KiB.toBytes(32); + for (int i = 0; i < numMessages; i++) { + try { + long writtenBytes = 0; + long totalTime = 0; + ByteBuf buf = Unpooled.buffer(8); + buf.writeLong(8 + targetBytes); + decoder.channelRead(ctx, buf); + while (writtenBytes < targetBytes) { + buf = Unpooled.buffer(pieceBytes * 2); + ByteBuf writtenBuf = Unpooled.buffer(pieceBytes).writerIndex(pieceBytes); + buf.writeBytes(writtenBuf); + writtenBuf.release(); + long start = System.currentTimeMillis(); + decoder.channelRead(ctx, buf); + long elapsedTime = System.currentTimeMillis() - start; + totalTime += elapsedTime; + writtenBytes += pieceBytes; + } + logger.info("Writing 300MiB frame buf with consolidation of threshold " + threshold + + " took " + totalTime + " milis"); + } finally { + for (ByteBuf buf : retained) { + release(buf); + } + } + } + long totalBytesGot = 0; + for (ByteBuf buf : retained) { + totalBytesGot += buf.capacity(); + } + assertEquals(numMessages, retained.size()); + assertEquals(targetBytes * numMessages, totalBytesGot); + } + } + @Test public void testInterception() throws Exception { int interceptedReads = 3;