@@ -191,9 +191,10 @@ public TransportChannelHandler initializePipeline(
TransportChannelHandler channelHandler = createChannelHandler(channel, channelRpcHandler);
ChunkFetchRequestHandler chunkFetchHandler =
createChunkFetchHandler(channelHandler, channelRpcHandler);

ChannelPipeline pipeline = channel.pipeline()
.addLast("encoder", ENCODER)
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder())
.addLast(TransportFrameDecoder.HANDLER_NAME, NettyUtils.createFrameDecoder(conf.consolidateBufsThreshold()))
.addLast("decoder", DECODER)
.addLast("idleStateHandler",
new IdleStateHandler(0, 0, conf.connectionTimeoutMs() / 1000))
@@ -100,7 +100,11 @@ public static Class<? extends ServerChannel> getServerChannelClass(IOMode mode)
* This is used before all decoders.
*/
public static TransportFrameDecoder createFrameDecoder() {
return new TransportFrameDecoder();
return new TransportFrameDecoder(-1L);
}

public static TransportFrameDecoder createFrameDecoder(long consolidateBufsThreshold) {
return new TransportFrameDecoder(consolidateBufsThreshold);
}

/** Returns the remote address on the channel or "<unknown remote>" if none exists. */
@@ -43,6 +43,7 @@ public class TransportConf {
private final String SPARK_NETWORK_IO_LAZYFD_KEY;
private final String SPARK_NETWORK_VERBOSE_METRICS;
private final String SPARK_NETWORK_IO_ENABLETCPKEEPALIVE_KEY;
private final String SPARK_NETWORK_IO_CONSOLIDATEBUFS_THRESHOLD_KEY;

private final ConfigProvider conf;

@@ -66,6 +67,7 @@ public TransportConf(String module, ConfigProvider conf) {
SPARK_NETWORK_IO_LAZYFD_KEY = getConfKey("io.lazyFD");
SPARK_NETWORK_VERBOSE_METRICS = getConfKey("io.enableVerboseMetrics");
SPARK_NETWORK_IO_ENABLETCPKEEPALIVE_KEY = getConfKey("io.enableTcpKeepAlive");
SPARK_NETWORK_IO_CONSOLIDATEBUFS_THRESHOLD_KEY = getConfKey("io.consolidateBufsThreshold");
}

public int getInt(String name, int defaultValue) {
@@ -94,6 +96,23 @@ public boolean preferDirectBufs() {
return conf.getBoolean(SPARK_NETWORK_IO_PREFERDIRECTBUFS_KEY, true);
}

/** The threshold for consolidation; the default is derived from the memoryOverhead used in YARN mode. */
Member

This replicates a lot of logic from elsewhere with hard-coded constants. Is it really important to vary it so finely and add a whole new conf? It seems like whether consolidating a buffer of size X is worthwhile ought to be pretty independent of the environment.

Author

@srowen Thanks, yes, I think you are right, we can just make it some fixed factor of the frame size.
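For illustration only, here is a minimal sketch of the alternative suggested above: defaulting the threshold to a fixed fraction of the frame size rather than deriving it from driver/executor memory. The class name, the factor value, and the frameSizeBytes parameter are hypothetical and not part of this PR.

// Hypothetical sketch (not the PR's implementation): default the consolidation
// threshold to a fixed fraction of the frame size, independent of the environment.
final class ConsolidateThresholdSketch {
  private static final double CONSOLIDATE_FACTOR = 0.5; // assumed fraction, for illustration

  static long defaultThreshold(long frameSizeBytes) {
    return (long) (frameSizeBytes * CONSOLIDATE_FACTOR);
  }
}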

public long consolidateBufsThreshold() {
boolean isDriver = conf.get("spark.executor.id").equals("driver");
final long MEMORY_OVERHEAD_MIN = 384L;
final double MEMORY_OVERHEAD_FACTOR = 0.1;
final double SHUFFLE_MEMORY_OVERHEAD_FACTOR = MEMORY_OVERHEAD_FACTOR * 0.6;
final double SHUFFLE_MEMORY_OVERHEAD_SAFE_FACTOR = SHUFFLE_MEMORY_OVERHEAD_FACTOR * 0.5;
long memory;
if (isDriver) {
memory = Math.max(JavaUtils.byteStringAsBytes(conf.get("spark.driver.memory")), MEMORY_OVERHEAD_MIN);
} else {
memory = Math.max(JavaUtils.byteStringAsBytes(conf.get("spark.executor.memory")), MEMORY_OVERHEAD_MIN);
}
long defaultConsolidateBufsThreshold = (long)(memory * SHUFFLE_MEMORY_OVERHEAD_SAFE_FACTOR);
return conf.getLong(SPARK_NETWORK_IO_CONSOLIDATEBUFS_THRESHOLD_KEY, defaultConsolidateBufsThreshold);
}
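As a rough worked example of the default above (illustrative, assuming spark.executor.memory=10g on an executor): 10 GiB * 0.1 * 0.6 * 0.5 ≈ 307 MiB, i.e. the default threshold works out to about 3% of the configured executor memory.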

/** Connect timeout in milliseconds. Default 120 secs. */
public int connectionTimeoutMs() {
long defaultNetworkTimeoutS = JavaUtils.timeStringAsSec(
@@ -56,6 +56,18 @@ public class TransportFrameDecoder extends ChannelInboundHandlerAdapter {
private long nextFrameSize = UNKNOWN_FRAME_SIZE;
private volatile Interceptor interceptor;

private long consolidateBufsThreshold = Long.MAX_VALUE;
long consolidatedCount = 0L;
long consolidatedTotalTime = 0L;

public TransportFrameDecoder() {}

public TransportFrameDecoder(long consolidateBufsThreshold) {
if (consolidateBufsThreshold > 0) {
this.consolidateBufsThreshold = consolidateBufsThreshold;
}
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object data) throws Exception {
ByteBuf in = (ByteBuf) data;
@@ -141,10 +153,20 @@ private ByteBuf decodeNext() {

// Otherwise, create a composite buffer.
CompositeByteBuf frame = buffers.getFirst().alloc().compositeBuffer(Integer.MAX_VALUE);
long lastConsolidatedCapacity = 0L;
while (remaining > 0) {
ByteBuf next = nextBufferForFrame(remaining);
remaining -= next.readableBytes();
frame.addComponent(next).writerIndex(frame.writerIndex() + next.readableBytes());
if (frame.capacity() - lastConsolidatedCapacity >= consolidateBufsThreshold) {
// Because the created ByteBuf's readable bytes are often far smaller than its capacity,
// consolidating the components reduces memory consumption.
long start = System.currentTimeMillis();
frame.consolidate();
consolidatedCount += 1;
consolidatedTotalTime += System.currentTimeMillis() - start;
lastConsolidatedCapacity = frame.capacity();
}
}
assert remaining == 0;
return frame;
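As a side note on how often this branch fires (illustrative only, not part of the PR): the composite's capacity grows by each component's readable bytes, so with the shape used by testConsolidationPerf below (1 MiB readable per component, 300 MiB threshold, ~1000 MiB frame) consolidation runs about three times. A small standalone sketch of that arithmetic, with hypothetical names:

// Illustrative only: simulate how often the consolidation branch above fires while
// building a frame from equally sized components.
final class ConsolidationTriggerSketch {
  static int consolidationCount(long frameSize, long readablePerComponent, long threshold) {
    long capacity = 0L;         // composite capacity tracks the components' readable bytes
    long lastConsolidated = 0L;
    int count = 0;
    for (long built = 0; built < frameSize; built += readablePerComponent) {
      capacity += readablePerComponent;
      if (capacity - lastConsolidated >= threshold) {
        count++;
        lastConsolidated = capacity;
      }
    }
    return count;
  }

  public static void main(String[] args) {
    // 1000 MiB frame, 1 MiB readable per component, 300 MiB threshold -> prints 3
    System.out.println(consolidationCount(1000L << 20, 1L << 20, 300L << 20));
  }
}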
@@ -47,6 +47,67 @@ public void testFrameDecoding() throws Exception {
verifyAndCloseDecoder(decoder, ctx, data);
}

@Test
public void testConsolidationForDecodingNonFullyWrittenByteBuf() {
Contributor

If I understand correctly, this is testing that consolidation is reducing the amount of memory needed to hold a frame? But since you're writing just 1 MB to the decoder, that's not triggering consolidation, is it?

Playing with CompositeByteBuf, it adjusts its internal capacity based on the readable bytes of the components, but the component buffers remain unchanged, so they still hold on to the original amount of memory:

scala> cb.numComponents()
res4: Int = 2

scala> cb.capacity()
res5: Int = 8

scala> cb.component(0).capacity()
res6: Int = 1048576

So I'm not sure this test is testing anything useful.

Also it would be nice not to use so many magic numbers.

Author

@vanzin I think the test should be refined, but I have a question about your test.
CompositeByteBuf.capacity returns the last component's endOffset, so I think using the capacity for testing is OK.
https://github.com/netty/netty/blob/8fecbab2c56d3f49d0353d58ee1681f3e6d3feca/buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java#L730

Contributor

Maybe my question wasn't clear. I'm asking what part of the Spark code this test is testing.

As far as I can see, it's testing netty code, and these are not netty unit tests.

Author

@vanzin I think this test largely duplicates testConsolidationPerf; we can just remove it. I will update soon. Sorry for that.
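For context on the CompositeByteBuf behavior being debated here, a small standalone sketch (not part of the PR, assuming Netty's io.netty.buffer API as used above): before consolidate(), the composite's capacity reflects only the components' readable bytes while each component keeps its originally allocated memory; consolidate() copies the readable bytes into a single right-sized buffer and releases the oversized components.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;

final class ConsolidateDemo {
  public static void main(String[] args) {
    CompositeByteBuf frame = Unpooled.compositeBuffer(Integer.MAX_VALUE);
    for (int i = 0; i < 2; i++) {
      ByteBuf component = Unpooled.buffer(1024 * 1024);  // 1 MiB allocated per component
      component.writeInt(i);                             // only 4 bytes readable each
      frame.addComponent(component).writerIndex(frame.writerIndex() + component.readableBytes());
    }
    System.out.println(frame.numComponents());           // 2
    System.out.println(frame.capacity());                // 8: capacity tracks readable bytes
    System.out.println(frame.component(0).capacity());   // 1048576: component memory still held

    frame.consolidate();                                  // copy readable bytes into one buffer
    System.out.println(frame.numComponents());           // 1
    System.out.println(frame.component(0).capacity());   // now only the readable bytes remain
    frame.release();
  }
}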

TransportFrameDecoder decoder = new TransportFrameDecoder();
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
List<ByteBuf> retained = new ArrayList<>();
when(ctx.fireChannelRead(any())).thenAnswer(in -> {
ByteBuf buf = (ByteBuf) in.getArguments()[0];
retained.add(buf);
return null;
});
ByteBuf data1 = Unpooled.buffer(1024 * 1024);
data1.writeLong(1024 * 1024 + 8);
data1.writeByte(127);
ByteBuf data2 = Unpooled.buffer(1024 * 1024);
for (int i = 0; i < 1024 * 1024 - 1; i++) {
data2.writeByte(128);
}
int originalCapacity = data1.capacity() + data2.capacity();
try {
decoder.channelRead(ctx, data1);
decoder.channelRead(ctx, data2);
assertEquals(1, retained.size());
assert(retained.get(0).capacity() < originalCapacity);
} catch (Exception e) {
release(data1);
release(data2);
}
}

@Test
public void testConsolidationPerf() {
TransportFrameDecoder decoder = new TransportFrameDecoder(300 * 1024 * 1024);
ChannelHandlerContext ctx = mock(ChannelHandlerContext.class);
List<ByteBuf> retained = new ArrayList<>();
when(ctx.fireChannelRead(any())).thenAnswer(in -> {
ByteBuf buf = (ByteBuf) in.getArguments()[0];
retained.add(buf);
return null;
});

ByteBuf buf = Unpooled.buffer(8);
try {
buf.writeLong(8 + 1024 * 1024 * 1000);
decoder.channelRead(ctx, buf);
for (int i = 0; i < 1000; i++) {
buf = Unpooled.buffer(1024 * 1024 * 2);
ByteBuf writtenBuf = Unpooled.buffer(1024 * 1024).writerIndex(1024 * 1024);
buf.writeBytes(writtenBuf);
writtenBuf.release();
decoder.channelRead(ctx, buf);
}
assertEquals(1, retained.size());
assertEquals(1024 * 1024 * 1000, retained.get(0).capacity());
System.out.println("consolidated " + decoder.consolidatedCount + " times cost " + decoder.consolidatedTotalTime + " milis");
} catch (Exception e) {
if (buf != null) {
release(buf);
}
}
}

@Test
public void testInterception() throws Exception {
int interceptedReads = 3;