diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 2d7a72315cf2..b886fce9be21 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -102,8 +102,12 @@ protected void handleMessage( FetchShuffleBlocks msg = (FetchShuffleBlocks) msgObj; checkAuth(client, msg.appId); numBlockIds = 0; - for (int[] ids: msg.reduceIds) { - numBlockIds += ids.length; + if (msg.batchFetchEnabled) { + numBlockIds = msg.mapIds.length; + } else { + for (int[] ids: msg.reduceIds) { + numBlockIds += ids.length; + } } streamId = streamManager.registerStream(client.getClientId(), new ShuffleManagedBufferIterator(msg), client.getChannel()); @@ -323,6 +327,7 @@ private class ShuffleManagedBufferIterator implements Iterator { private final int shuffleId; private final long[] mapIds; private final int[][] reduceIds; + private final boolean batchFetchEnabled; ShuffleManagedBufferIterator(FetchShuffleBlocks msg) { appId = msg.appId; @@ -330,6 +335,7 @@ private class ShuffleManagedBufferIterator implements Iterator { shuffleId = msg.shuffleId; mapIds = msg.mapIds; reduceIds = msg.reduceIds; + batchFetchEnabled = msg.batchFetchEnabled; } @Override @@ -343,12 +349,20 @@ public boolean hasNext() { @Override public ManagedBuffer next() { - final ManagedBuffer block = blockManager.getBlockData( - appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]); - if (reduceIdx < reduceIds[mapIdx].length - 1) { - reduceIdx += 1; + ManagedBuffer block; + if (!batchFetchEnabled) { + block = blockManager.getBlockData( + appId, execId, shuffleId, mapIds[mapIdx], reduceIds[mapIdx][reduceIdx]); + if (reduceIdx < reduceIds[mapIdx].length - 1) { + reduceIdx += 1; + } else { + reduceIdx = 0; + mapIdx += 1; + } } else { - reduceIdx = 0; + assert(reduceIds[mapIdx].length == 2); + block = blockManager.getContinuousBlocksData(appId, execId, shuffleId, mapIds[mapIdx], + reduceIds[mapIdx][0], reduceIds[mapIdx][1]); mapIdx += 1; } metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 8b0d1e145a81..beca5d6e5a78 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -165,8 +165,7 @@ public void registerExecutor( } /** - * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions - * about how the hash and sort based shuffles store their data. + * Obtains a FileSegmentManagedBuffer from a single block (shuffleId, mapId, reduceId). */ public ManagedBuffer getBlockData( String appId, @@ -174,12 +173,26 @@ public ManagedBuffer getBlockData( int shuffleId, long mapId, int reduceId) { + return getContinuousBlocksData(appId, execId, shuffleId, mapId, reduceId, reduceId + 1); + } + + /** + * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, [startReduceId, endReduceId)). + * We make assumptions about how the hash and sort based shuffles store their data. 
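A minimal Scala sketch of the byte-range arithmetic behind getContinuousBlocksData: it mirrors ShuffleIndexInformation.getIndex(startReduceId, endReduceId) introduced further down in this patch, and the offsets values here are illustrative only.

// Illustrative offsets only: offsets(i) is the start byte of reduce partition i in one
// map task's data file, and offsets.last is the file length (numReduces + 1 entries).
val offsets = Array(0L, 100L, 250L, 400L, 700L)

// Byte segment served for the half-open reduce range [startReduceId, endReduceId).
def segment(startReduceId: Int, endReduceId: Int): (Long, Long) = {
  val offset = offsets(startReduceId)
  (offset, offsets(endReduceId) - offset)
}

assert(segment(1, 2) == (100L, 150L))  // single block, i.e. getBlockData with reduceId = 1
assert(segment(1, 4) == (100L, 600L))  // reduce partitions 1, 2 and 3 served as one buffer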
+ */ + public ManagedBuffer getContinuousBlocksData( + String appId, + String execId, + int shuffleId, + long mapId, + int startReduceId, + int endReduceId) { ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId)); if (executor == null) { throw new RuntimeException( String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId)); } - return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); + return getSortBasedShuffleBlockData(executor, shuffleId, mapId, startReduceId, endReduceId); } public ManagedBuffer getRddBlockData( @@ -296,13 +309,14 @@ private void deleteNonShuffleServiceServedFiles(String[] dirs) { * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId. */ private ManagedBuffer getSortBasedShuffleBlockData( - ExecutorShuffleInfo executor, int shuffleId, long mapId, int reduceId) { + ExecutorShuffleInfo executor, int shuffleId, long mapId, int startReduceId, int endReduceId) { File indexFile = ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, "shuffle_" + shuffleId + "_" + mapId + "_0.index"); try { ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile); - ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId); + ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex( + startReduceId, endReduceId); return new FileSegmentManagedBuffer( conf, ExecutorDiskUtils.getFile(executor.localDirs, executor.subDirsPerLocalDir, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 52854c86be3e..ab373a7f03d9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -25,7 +25,6 @@ import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; -import org.apache.commons.lang3.tuple.ImmutableTriple; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -113,39 +112,47 @@ private boolean isShuffleBlocks(String[] blockIds) { */ private FetchShuffleBlocks createFetchShuffleBlocksMsg( String appId, String execId, String[] blockIds) { - int shuffleId = splitBlockId(blockIds[0]).left; + String[] firstBlock = splitBlockId(blockIds[0]); + int shuffleId = Integer.parseInt(firstBlock[1]); + boolean batchFetchEnabled = firstBlock.length == 5; + HashMap> mapIdToReduceIds = new HashMap<>(); for (String blockId : blockIds) { - ImmutableTriple blockIdParts = splitBlockId(blockId); - if (blockIdParts.left != shuffleId) { + String[] blockIdParts = splitBlockId(blockId); + if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - long mapId = blockIdParts.middle; + long mapId = Long.parseLong(blockIdParts[2]); if (!mapIdToReduceIds.containsKey(mapId)) { mapIdToReduceIds.put(mapId, new ArrayList<>()); } - mapIdToReduceIds.get(mapId).add(blockIdParts.right); + mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[3])); + if (batchFetchEnabled) { + // When we read continuous shuffle blocks in batch, we will reuse reduceIds in + // FetchShuffleBlocks to store the start and end reduce id for range + // [startReduceId, endReduceId). 
+ assert(blockIdParts.length == 5); + mapIdToReduceIds.get(mapId).add(Integer.parseInt(blockIdParts[4])); + } } long[] mapIds = Longs.toArray(mapIdToReduceIds.keySet()); int[][] reduceIdArr = new int[mapIds.length][]; for (int i = 0; i < mapIds.length; i++) { reduceIdArr[i] = Ints.toArray(mapIdToReduceIds.get(mapIds[i])); } - return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIdArr); + return new FetchShuffleBlocks( + appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled); } - /** Split the shuffleBlockId and return shuffleId, mapId and reduceId. */ - private ImmutableTriple splitBlockId(String blockId) { + /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */ + private String[] splitBlockId(String blockId) { String[] blockIdParts = blockId.split("_"); - if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) { + if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) { throw new IllegalArgumentException( "Unexpected shuffle block id format: " + blockId); } - return new ImmutableTriple<>( - Integer.parseInt(blockIdParts[1]), - Long.parseLong(blockIdParts[2]), - Integer.parseInt(blockIdParts[3])); + return blockIdParts; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index 371149bef397..b65aacfcc4b9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -54,8 +54,15 @@ public int getSize() { * Get index offset for a particular reducer. */ public ShuffleIndexRecord getIndex(int reduceId) { - long offset = offsets.get(reduceId); - long nextOffset = offsets.get(reduceId + 1); + return getIndex(reduceId, reduceId + 1); + } + + /** + * Get index offset for the reducer range of [startReduceId, endReduceId). + */ + public ShuffleIndexRecord getIndex(int startReduceId, int endReduceId) { + long offset = offsets.get(startReduceId); + long nextOffset = offsets.get(endReduceId); return new ShuffleIndexRecord(offset, nextOffset - offset); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index faa960d414bc..c0f307af042e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -35,20 +35,32 @@ public class FetchShuffleBlocks extends BlockTransferMessage { // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. public final long[] mapIds; + // When batchFetchEnabled=true, reduceIds[i] contains 2 elements: startReduceId (inclusive) and + // endReduceId (exclusive) for the mapper mapIds[i]. + // When batchFetchEnabled=false, reduceIds[i] contains all the reduce IDs that mapper mapIds[i] + // needs to fetch. 
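To make the reduceIds convention concrete, a small Scala sketch constructing both shapes of the message with the new six-argument constructor; the ids are illustrative.

import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks

// Non-batch: every reduce id is listed per map id
// (e.g. blocks shuffle_0_10_0, shuffle_0_10_1, shuffle_0_11_2).
val oneByOne = new FetchShuffleBlocks(
  "app-1", "exec-1", 0, Array(10L, 11L), Array(Array(0, 1), Array(2)), false)

// Batch: shuffle_0_10_0_3 collapses into the single half-open range [0, 3) for map id 10.
val batched = new FetchShuffleBlocks(
  "app-1", "exec-1", 0, Array(10L), Array(Array(0, 3)), true)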
public final int[][] reduceIds; + public final boolean batchFetchEnabled; public FetchShuffleBlocks( String appId, String execId, int shuffleId, long[] mapIds, - int[][] reduceIds) { + int[][] reduceIds, + boolean batchFetchEnabled) { this.appId = appId; this.execId = execId; this.shuffleId = shuffleId; this.mapIds = mapIds; this.reduceIds = reduceIds; assert(mapIds.length == reduceIds.length); + this.batchFetchEnabled = batchFetchEnabled; + if (batchFetchEnabled) { + for (int[] ids: reduceIds) { + assert(ids.length == 2); + } + } } @Override @@ -62,6 +74,7 @@ public String toString() { .add("shuffleId", shuffleId) .add("mapIds", Arrays.toString(mapIds)) .add("reduceIds", Arrays.deepToString(reduceIds)) + .add("batchFetchEnabled", batchFetchEnabled) .toString(); } @@ -73,6 +86,7 @@ public boolean equals(Object o) { FetchShuffleBlocks that = (FetchShuffleBlocks) o; if (shuffleId != that.shuffleId) return false; + if (batchFetchEnabled != that.batchFetchEnabled) return false; if (!appId.equals(that.appId)) return false; if (!execId.equals(that.execId)) return false; if (!Arrays.equals(mapIds, that.mapIds)) return false; @@ -86,6 +100,7 @@ public int hashCode() { result = 31 * result + shuffleId; result = 31 * result + Arrays.hashCode(mapIds); result = 31 * result + Arrays.deepHashCode(reduceIds); + result = 31 * result + (batchFetchEnabled ? 1 : 0); return result; } @@ -100,7 +115,8 @@ public int encodedLength() { + 4 /* encoded length of shuffleId */ + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ - + encodedLengthOfReduceIds; + + encodedLengthOfReduceIds + + 1; /* encoded length of batchFetchEnabled */ } @Override @@ -113,6 +129,7 @@ public void encode(ByteBuf buf) { for (int[] ids: reduceIds) { Encoders.IntArrays.encode(buf, ids); } + buf.writeBoolean(batchFetchEnabled); } public static FetchShuffleBlocks decode(ByteBuf buf) { @@ -125,6 +142,7 @@ public static FetchShuffleBlocks decode(ByteBuf buf) { for (int i = 0; i < reduceIdsSize; i++) { reduceIds[i] = Encoders.IntArrays.decode(buf); } - return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIds); + boolean batchFetchEnabled = buf.readBoolean(); + return new FetchShuffleBlocks(appId, execId, shuffleId, mapIds, reduceIds, batchFetchEnabled); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java index ba40f4a45ac8..fd2c67a3a270 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java @@ -30,7 +30,10 @@ public void serializeOpenShuffleBlocks() { checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" })); checkSerializeDeserialize(new FetchShuffleBlocks( "app-1", "exec-2", 0, new long[] {0, 1}, - new int[][] {{ 0, 1 }, { 0, 1, 2 }})); + new int[][] {{ 0, 1 }, { 0, 1, 2 }}, false)); + checkSerializeDeserialize(new FetchShuffleBlocks( + "app-1", "exec-2", 0, new long[] {0, 1}, + new int[][] {{ 0, 1 }, { 0, 2 }}, true)); checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo( new String[] { "/local1", "/local2" }, 32, "MyShuffleManager"))); checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 }, diff --git 
a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 6a5d04b6f417..455351fcf767 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -101,7 +101,7 @@ public void testFetchShuffleBlocks() { when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(blockMarkers[1]); FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks( - "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}); + "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, false); checkOpenBlocksReceive(fetchShuffleBlocks, blockMarkers); verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0); @@ -109,6 +109,22 @@ public void testFetchShuffleBlocks() { verifyOpenBlockLatencyMetrics(); } + @Test + public void testFetchShuffleBlocksInBatch() { + ManagedBuffer[] batchBlockMarkers = { + new NioManagedBuffer(ByteBuffer.wrap(new byte[10])) + }; + when(blockResolver.getContinuousBlocksData( + "app0", "exec1", 0, 0, 0, 1)).thenReturn(batchBlockMarkers[0]); + + FetchShuffleBlocks fetchShuffleBlocks = new FetchShuffleBlocks( + "app0", "exec1", 0, new long[] { 0 }, new int[][] {{ 0, 1 }}, true); + checkOpenBlocksReceive(fetchShuffleBlocks, batchBlockMarkers); + + verify(blockResolver, times(1)).getContinuousBlocksData("app0", "exec1", 0, 0, 0, 1); + verifyOpenBlockLatencyMetrics(); + } + @Test public void testOpenDiskPersistedRDDBlocks() { when(blockResolver.getRddBlockData("app0", "exec1", 0, 0)).thenReturn(blockMarkers[0]); @@ -154,7 +170,7 @@ private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] bl StreamHandle handle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); - assertEquals(2, handle.numChunks); + assertEquals(blockMarkers.length, handle.numChunks); @SuppressWarnings("unchecked") ArgumentCaptor> stream = (ArgumentCaptor>) @@ -162,8 +178,9 @@ private void checkOpenBlocksReceive(BlockTransferMessage msg, ManagedBuffer[] bl verify(streamManager, times(1)).registerStream(anyString(), stream.capture(), any()); Iterator buffers = stream.getValue(); - assertEquals(blockMarkers[0], buffers.next()); - assertEquals(blockMarkers[1], buffers.next()); + for (ManagedBuffer blockMarker : blockMarkers) { + assertEquals(blockMarker, buffers.next()); + } assertFalse(buffers.hasNext()); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index 09eb699be305..09b31430b1eb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -111,6 +111,13 @@ public void testSortShuffleBlocks() throws IOException { CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); assertEquals(sortBlock1, block1); } + + try (InputStream blocksStream = resolver.getContinuousBlocksData( + "app0", "exec0", 0, 0, 0, 2).createInputStream()) { + String blocks = + CharStreams.toString(new InputStreamReader(blocksStream, StandardCharsets.UTF_8)); + 
assertEquals(sortBlock0 + sortBlock1, blocks); + } } @Test diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java index 26a11672b806..285eedb39c65 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java @@ -64,7 +64,7 @@ public void testFetchOne() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}), + new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0 }}, false), conf); verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0")); @@ -100,7 +100,8 @@ public void testFetchThreeShuffleBlocks() { BlockFetchingListener listener = fetchBlocks( blocks, blockIds, - new FetchShuffleBlocks("app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}), + new FetchShuffleBlocks( + "app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 1, 2 }}, false), conf); for (int i = 0; i < 3; i ++) { @@ -109,6 +110,23 @@ public void testFetchThreeShuffleBlocks() { } } + @Test + public void testBatchFetchThreeShuffleBlocks() { + LinkedHashMap blocks = Maps.newLinkedHashMap(); + blocks.put("shuffle_0_0_0_3", new NioManagedBuffer(ByteBuffer.wrap(new byte[58]))); + String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); + + BlockFetchingListener listener = fetchBlocks( + blocks, + blockIds, + new FetchShuffleBlocks( + "app-id", "exec-id", 0, new long[] { 0 }, new int[][] {{ 0, 3 }}, true), + conf); + + verify(listener, times(1)).onBlockFetchSuccess( + "shuffle_0_0_0_3", blocks.get("shuffle_0_0_0_3")); + } + @Test public void testFetchThree() { LinkedHashMap blocks = Maps.newLinkedHashMap(); diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index b2ab31488e4c..3a41c5f73c0a 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -29,7 +29,7 @@ import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithI import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager} import org.apache.spark.network.shuffle.protocol._ import org.apache.spark.serializer.Serializer -import org.apache.spark.storage.{BlockId, ShuffleBlockId, StorageLevel} +import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId, StorageLevel} /** * Serves requests to open blocks by simply registering one chunk per block requested. 
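As the fetcher test above shows, a batch request is expressed entirely through the block id string; a quick sketch of the two id shapes that splitBlockId accepts, with illustrative values.

// 4 "_"-separated parts: a single block (shuffleId, mapId, reduceId).
// 5 parts: a batch whose last two numbers are the half-open reduce range.
val single = "shuffle_0_0_0".split("_")
val batch = "shuffle_0_0_0_3".split("_")
assert(single.length == 4 && batch.length == 5)
assert(batch(3).toInt == 0 && batch(4).toInt == 3)  // reduce partitions [0, 3)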
@@ -65,12 +65,29 @@ class NettyBlockRpcServer( case fetchShuffleBlocks: FetchShuffleBlocks => val blocks = fetchShuffleBlocks.mapIds.zipWithIndex.flatMap { case (mapId, index) => - fetchShuffleBlocks.reduceIds.apply(index).map { reduceId => - blockManager.getBlockData( - ShuffleBlockId(fetchShuffleBlocks.shuffleId, mapId, reduceId)) + if (!fetchShuffleBlocks.batchFetchEnabled) { + fetchShuffleBlocks.reduceIds(index).map { reduceId => + blockManager.getBlockData( + ShuffleBlockId(fetchShuffleBlocks.shuffleId, mapId, reduceId)) + } + } else { + val startAndEndId = fetchShuffleBlocks.reduceIds(index) + if (startAndEndId.length != 2) { + throw new IllegalStateException(s"Invalid shuffle fetch request when batch mode " + + s"is enabled: $fetchShuffleBlocks") + } + Array(blockManager.getBlockData( + ShuffleBlockBatchId( + fetchShuffleBlocks.shuffleId, mapId, startAndEndId(0), startAndEndId(1)))) } } - val numBlockIds = fetchShuffleBlocks.reduceIds.map(_.length).sum + + val numBlockIds = if (fetchShuffleBlocks.batchFetchEnabled) { + fetchShuffleBlocks.mapIds.length + } else { + fetchShuffleBlocks.reduceIds.map(_.length).sum + } + val streamId = streamManager.registerStream(appId, blocks.iterator.asJava, client.getChannel) logTrace(s"Registered streamId $streamId with $numBlockIds buffers") diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 3e3c387911d3..623db9d00ab5 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -114,6 +114,7 @@ private[spark] class SerializerManager( case _: RDDBlockId => compressRdds case _: TempLocalBlockId => compressShuffleSpill case _: TempShuffleBlockId => compressShuffle + case _: ShuffleBlockBatchId => compressShuffle case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 3737102a1aba..14080f8822f9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -19,14 +19,14 @@ package org.apache.spark.shuffle import org.apache.spark._ import org.apache.spark.internal.{config, Logging} +import org.apache.spark.io.CompressionCodec import org.apache.spark.serializer.SerializerManager import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalSorter /** - * Fetches and reads the partitions in range [startPartition, endPartition) from a shuffle by - * requesting them from other nodes' block stores. + * Fetches and reads the blocks from a shuffle by requesting them from other nodes' block stores. 
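A small sketch of what the NettyBlockRpcServer change above resolves for one request, with illustrative ids: in batch mode the block manager is asked for one ShuffleBlockBatchId per map id, while the non-batch path asks for one ShuffleBlockId per (map id, reduce id) pair.

import org.apache.spark.storage.{ShuffleBlockBatchId, ShuffleBlockId}

// Request: shuffleId = 0, mapIds = [10, 11], reduceIds = [[0, 3], [0, 3]], batchFetchEnabled = true.
val batchLookups = Seq(ShuffleBlockBatchId(0, 10L, 0, 3), ShuffleBlockBatchId(0, 11L, 0, 3))

// The same partitions without batching are six separate lookups.
val singleLookups = for {
  mapId <- Seq(10L, 11L)
  reduceId <- 0 until 3
} yield ShuffleBlockId(0, mapId, reduceId)

assert(batchLookups.size == 2 && singleLookups.size == 6)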
*/ private[spark] class BlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], @@ -35,11 +35,33 @@ private[spark] class BlockStoreShuffleReader[K, C]( readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + shouldBatchFetch: Boolean) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency + private def fetchContinuousBlocksInBatch: Boolean = { + val conf = SparkEnv.get.conf + val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects + val compressed = conf.get(config.SHUFFLE_COMPRESS) + val codecConcatenation = if (compressed) { + CompressionCodec.supportsConcatenationOfSerializedStreams(CompressionCodec.createCodec(conf)) + } else { + true + } + + val doBatchFetch = shouldBatchFetch && serializerRelocatable && + (!compressed || codecConcatenation) + if (shouldBatchFetch && !doBatchFetch) { + logDebug("The feature tag of continuous shuffle block fetching is set to true, but " + + "we can not enable the feature because other conditions are not satisfied. " + + s"Shuffle compress: $compressed, serializer relocatable: $serializerRelocatable, " + + s"codec concatenation: $codecConcatenation.") + } + doBatchFetch + } + /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { val wrappedStreams = new ShuffleBlockFetcherIterator( @@ -55,7 +77,8 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT), SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT_MEMORY), - readMetrics).toCompletionIterator + readMetrics, + fetchContinuousBlocksInBatch).toCompletionIterator val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 332164a7be3e..8b3993e21f07 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -190,10 +190,18 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData(blockId: BlockId): ManagedBuffer = { + val (shuffleId, mapId, startReduceId, endReduceId) = blockId match { + case id: ShuffleBlockId => + (id.shuffleId, id.mapId, id.reduceId, id.reduceId + 1) + case batchId: ShuffleBlockBatchId => + (batchId.shuffleId, batchId.mapId, batchId.startReduceId, batchId.endReduceId) + case _ => + throw new IllegalArgumentException("unexpected shuffle block id format: " + blockId) + } // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + val indexFile = getIndexFile(shuffleId, mapId) // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code // which is incorrectly using our file descriptor then this code will fetch the wrong offsets @@ -202,22 +210,23 @@ private[spark] class IndexShuffleBlockResolver( // class of issue from 
re-occurring in the future which is why they are left here even though // SPARK-22982 is fixed. val channel = Files.newByteChannel(indexFile.toPath) - channel.position(blockId.reduceId * 8L) + channel.position(startReduceId * 8L) val in = new DataInputStream(Channels.newInputStream(channel)) try { - val offset = in.readLong() - val nextOffset = in.readLong() + val startOffset = in.readLong() + channel.position(endReduceId * 8L) + val endOffset = in.readLong() val actualPosition = channel.position() - val expectedPosition = blockId.reduceId * 8L + 16 + val expectedPosition = endReduceId * 8L + 8 if (actualPosition != expectedPosition) { throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " + s"expected $expectedPosition but actual position was $actualPosition.") } new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), - offset, - nextOffset - offset) + getDataFile(shuffleId, mapId), + startOffset, + endOffset - startOffset) } finally { in.close() } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index d1ecbc1bf017..c50789658d61 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -18,7 +18,7 @@ package org.apache.spark.shuffle import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.BlockId private[spark] /** @@ -34,7 +34,7 @@ trait ShuffleBlockResolver { * Retrieve the data for the specified block. If the data for that block is not available, * throws an unspecified exception. */ - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + def getBlockData(blockId: BlockId): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index a3529378a4d6..3cd04de0f741 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -127,7 +127,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( handle.shuffleId, startPartition, endPartition) new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) } override def getReaderForOneMapper[K, C]( @@ -140,7 +141,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByMapIndex( handle.shuffleId, mapIndex, startPartition, endPartition) new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics) + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics, + shouldBatchFetch = canUseBatchFetch(startPartition, endPartition, context)) } /** Get a writer for a given partition. Called on executors by map tasks. 
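The enabling condition added to BlockStoreShuffleReader can be hard to read inline with the logging, so here is a condensed restatement as a pure function with two illustrative checks; it is a sketch of the same rule, not additional behavior.

// Batch fetch is used only if it was requested, the serializer supports relocation of
// serialized objects, and (when shuffle compression is on) the codec supports
// concatenation of serialized streams.
def doBatchFetch(
    shouldBatchFetch: Boolean,
    serializerRelocatable: Boolean,
    compressed: Boolean,
    codecConcatenation: Boolean): Boolean = {
  shouldBatchFetch && serializerRelocatable && (!compressed || codecConcatenation)
}

assert(doBatchFetch(true, true, compressed = true, codecConcatenation = true))
assert(!doBatchFetch(true, serializerRelocatable = false, compressed = false, codecConcatenation = true))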
*/ @@ -201,10 +203,26 @@ private[spark] object SortShuffleManager extends Logging { * The maximum number of shuffle output partitions that SortShuffleManager supports when * buffering map outputs in a serialized form. This is an extreme defensive programming measure, * since it's extremely unlikely that a single shuffle produces over 16 million output partitions. - * */ + */ val MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE = PackedRecordPointer.MAXIMUM_PARTITION_ID + 1 + /** + * The local property key for continuous shuffle block fetching feature. + */ + val FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY = + "__fetch_continuous_blocks_in_batch_enabled" + + /** + * Helper method for determining whether a shuffle reader should fetch the continuous blocks + * in batch. + */ + def canUseBatchFetch(startPartition: Int, endPartition: Int, context: TaskContext): Boolean = { + val fetchMultiPartitions = endPartition - startPartition > 1 + fetchMultiPartitions && + context.getLocalProperty(FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY) == "true" + } + /** * Helper method for determining whether a shuffle should use an optimized serialized shuffle * path or whether it should fall back to the original path that operates on deserialized objects. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 9c5b7f64e7ab..68ed3aa5b062 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -38,7 +38,7 @@ sealed abstract class BlockId { // convenience methods def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD: Boolean = isInstanceOf[RDDBlockId] - def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] + def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId] def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name @@ -56,6 +56,18 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends Bl override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } +// The batch id of continuous shuffle blocks of same mapId in range [startReduceId, endReduceId). 
+@DeveloperApi +case class ShuffleBlockBatchId( + shuffleId: Int, + mapId: Long, + startReduceId: Int, + endReduceId: Int) extends BlockId { + override def name: String = { + "shuffle_" + shuffleId + "_" + mapId + "_" + startReduceId + "_" + endReduceId + } +} + @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" @@ -104,6 +116,7 @@ class UnrecognizedBlockId(name: String) object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r @@ -118,6 +131,8 @@ object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) + case SHUFFLE_BATCH(shuffleId, mapId, startReduceId, endReduceId) => + ShuffleBlockBatchId(shuffleId.toInt, mapId.toLong, startReduceId.toInt, endReduceId.toInt) case SHUFFLE_DATA(shuffleId, mapId, reduceId) => ShuffleDataBlockId(shuffleId.toInt, mapId.toLong, reduceId.toInt) case SHUFFLE_INDEX(shuffleId, mapId, reduceId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a7dfc20d15eb..c869a7078a1e 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -548,7 +548,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData(blockId) } else { getLocalBytes(blockId) match { case Some(blockData) => diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index dce5ebaebbae..f8aa97267cf1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -62,6 +62,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. * @param shuffleMetrics used to report shuffle metrics. + * @param doBatchFetch fetch continuous shuffle blocks from same executor in batch if the server + * side supports. */ private[spark] final class ShuffleBlockFetcherIterator( @@ -76,7 +78,8 @@ final class ShuffleBlockFetcherIterator( maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean, detectCorruptUseExtraMemory: Boolean, - shuffleMetrics: ShuffleReadMetricsReporter) + shuffleMetrics: ShuffleReadMetricsReporter, + doBatchFetch: Boolean) extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -292,9 +295,10 @@ final class ShuffleBlockFetcherIterator( throw new BlockException(blockId, "Zero-sized blocks should be excluded.") case None => // do nothing. 
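The new id type round-trips through its string name, which is the format carried in the fetcher's block id arrays and reported in fetch failures; a quick check with illustrative values.

import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId}

val id = ShuffleBlockBatchId(shuffleId = 2, mapId = 7L, startReduceId = 3, endReduceId = 6)
assert(id.name == "shuffle_2_7_3_6")       // matches the SHUFFLE_BATCH pattern
assert(BlockId("shuffle_2_7_3_6") == id)   // parsed back by BlockId.apply
assert(id.isShuffle)                       // now covered by BlockId.isShuffle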
} - localBlocks ++= blockInfos.map(info => (info._1, info._3)) - localBlockBytes += blockInfos.map(_._2).sum - numBlocksToFetch += localBlocks.size + val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( + blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)).to[ArrayBuffer]) + localBlocks ++= mergedBlockInfos.map(info => (info.blockId, info.mapIndex)) + localBlockBytes += mergedBlockInfos.map(_.size).sum } else { val iterator = blockInfos.iterator var curRequestSize = 0L @@ -308,23 +312,25 @@ final class ShuffleBlockFetcherIterator( throw new BlockException(blockId, "Zero-sized blocks should be excluded.") } else { curBlocks += FetchBlockInfo(blockId, size, mapIndex) - remoteBlocks += blockId - numBlocksToFetch += 1 curRequestSize += size } if (curRequestSize >= targetRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { // Add this FetchRequest - remoteRequests += new FetchRequest(address, curBlocks) + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks) + remoteBlocks ++= mergedBlocks.map(_.blockId) + remoteRequests += new FetchRequest(address, mergedBlocks) logDebug(s"Creating fetch request of $curRequestSize at $address " - + s"with ${curBlocks.size} blocks") + + s"with ${mergedBlocks.size} blocks") curBlocks = new ArrayBuffer[FetchBlockInfo] curRequestSize = 0 } } // Add in the final request if (curBlocks.nonEmpty) { - remoteRequests += new FetchRequest(address, curBlocks) + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks) + remoteBlocks ++= mergedBlocks.map(_.blockId) + remoteRequests += new FetchRequest(address, mergedBlocks) } } } @@ -335,6 +341,51 @@ final class ShuffleBlockFetcherIterator( remoteRequests } + private[this] def mergeContinuousShuffleBlockIdsIfNeeded( + blocks: ArrayBuffer[FetchBlockInfo]): ArrayBuffer[FetchBlockInfo] = { + + def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = { + val startBlockId = toBeMerged.head.blockId.asInstanceOf[ShuffleBlockId] + FetchBlockInfo( + ShuffleBlockBatchId( + startBlockId.shuffleId, + startBlockId.mapId, + startBlockId.reduceId, + toBeMerged.last.blockId.asInstanceOf[ShuffleBlockId].reduceId + 1), + toBeMerged.map(_.size).sum, + toBeMerged.head.mapIndex) + } + + val result = if (doBatchFetch) { + var curBlocks = new ArrayBuffer[FetchBlockInfo] + val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo] + val iter = blocks.iterator + + while (iter.hasNext) { + val info = iter.next() + val curBlockId = info.blockId.asInstanceOf[ShuffleBlockId] + if (curBlocks.isEmpty) { + curBlocks += info + } else { + if (curBlockId.mapId != curBlocks.head.blockId.asInstanceOf[ShuffleBlockId].mapId) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + curBlocks.clear() + } + curBlocks += info + } + } + if (curBlocks.nonEmpty) { + mergedBlockInfo += mergeFetchBlockInfo(curBlocks) + } + mergedBlockInfo + } else { + blocks + } + // update metrics + numBlocksToFetch += result.size + result + } + /** * Fetch the local blocks while we are fetching remote blocks. 
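A simplified Scala equivalent of mergeContinuousShuffleBlockIdsIfNeeded above, assuming, as the iterator does, that block ids for the same map id arrive adjacent and cover a continuous ascending reduce range; the ids are illustrative and the real code also carries sizes and map indices.

import org.apache.spark.storage.{BlockId, ShuffleBlockBatchId, ShuffleBlockId}

val blocks = Seq(
  ShuffleBlockId(0, 5L, 0), ShuffleBlockId(0, 5L, 1), ShuffleBlockId(0, 5L, 2),
  ShuffleBlockId(0, 6L, 0), ShuffleBlockId(0, 6L, 1), ShuffleBlockId(0, 6L, 2))

// One batch id per map id, spanning [first reduce id, last reduce id + 1).
val merged: Seq[BlockId] = blocks.groupBy(_.mapId).toSeq.sortBy(_._1).map { case (mapId, ids) =>
  ShuffleBlockBatchId(ids.head.shuffleId, mapId, ids.head.reduceId, ids.last.reduceId + 1)
}
assert(merged == Seq(ShuffleBlockBatchId(0, 5L, 0, 3), ShuffleBlockBatchId(0, 6L, 0, 3)))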
This is ok because * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we @@ -594,6 +645,8 @@ final class ShuffleBlockFetcherIterator( blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => throw new FetchFailedException(address, shufId, mapId, mapIndex, reduceId, e) + case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) => + throw new FetchFailedException(address, shuffleId, mapId, mapIndex, startReduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 3f9536e224de..67adf5fa5e18 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -138,7 +138,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext taskContext, metrics, serializerManager, - blockManager) + blockManager, + shouldBatchFetch = false) assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index 0f3767c4f8c8..ef7b13875540 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -64,6 +64,20 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("shuffle batch") { + val id = ShuffleBlockBatchId(1, 2, 3, 4) + assertSame(id, ShuffleBlockBatchId(1, 2, 3, 4)) + assertDifferent(id, ShuffleBlockBatchId(2, 2, 3, 4)) + assert(id.name === "shuffle_1_2_3_4") + assert(id.asRDDId === None) + assert(id.shuffleId === 1) + assert(id.mapId === 2) + assert(id.startReduceId === 3) + assert(id.endReduceId === 4) + assert(id.isShuffle) + assertSame(id, BlockId(id.toString)) + } + test("shuffle data") { val id = ShuffleDataBlockId(4, 5, 6) assertSame(id, ShuffleDataBlockId(4, 5, 6)) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 6f7469a9c2b4..85b1a865603a 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -117,7 +117,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, false, - metrics) + metrics, + false) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getBlockData(any()) @@ -148,6 +149,82 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) } + test("fetch continuous blocks in batch successful 3 local reads + 2 remote reads") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + // Make sure blockManager.getBlockData would return the merged block + val localBlocks = Seq[BlockId]( + ShuffleBlockId(0, 0, 0), + ShuffleBlockId(0, 0, 1), + ShuffleBlockId(0, 0, 2)) + val mergedLocalBlocks = Map[BlockId, ManagedBuffer]( + 
ShuffleBlockBatchId(0, 0, 0, 3) -> createMockManagedBuffer()) + mergedLocalBlocks.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getBlockData(meq(blockId)) + } + + // Make sure remote blocks would return the merged block + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Seq[BlockId]( + ShuffleBlockId(0, 3, 0), + ShuffleBlockId(0, 3, 1)) + val mergedRemoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer()) + val transfer = createMockTransfer(mergedRemoteBlocks) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (localBmId, localBlocks.map(blockId => (blockId, 1L, 0))), + (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1))) + ).toIterator + + val taskContext = TaskContext.empty() + val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() + val iterator = new ShuffleBlockFetcherIterator( + taskContext, + transfer, + blockManager, + blocksByAddress, + (_, in) => in, + 48 * 1024 * 1024, + Int.MaxValue, + Int.MaxValue, + Int.MaxValue, + true, + false, + metrics, + true) + + // 3 local blocks batch fetched in initialization + verify(blockManager, times(1)).getBlockData(any()) + + for (i <- 0 until 2) { + assert(iterator.hasNext, s"iterator should have 2 elements but actually has $i elements") + val (blockId, inputStream) = iterator.next() + + // Make sure we release buffers when a wrapped input stream is closed. + val mockBuf = mergedLocalBlocks.getOrElse(blockId, mergedRemoteBlocks(blockId)) + // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream + val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] + verify(mockBuf, times(0)).release() + val delegateAccess = PrivateMethod[InputStream]('delegate) + + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close() + wrappedInputStream.close() + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + wrappedInputStream.close() // close should be idempotent + verify(mockBuf, times(1)).release() + verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() + } + + // 2 remote blocks batch fetched + // (but from the same block manager so one call to fetchBlocks) + verify(blockManager, times(1)).getBlockData(any()) + verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + } + test("release current unexhausted buffer in case the task completes early") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) @@ -195,7 +272,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, false, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ -264,7 +342,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, false, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -353,7 +432,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, true, - 
taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -423,7 +503,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, true, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) // We'll get back the block which has corruption after maxBytesInFlight/3 because the other // block will detect corruption on first fetch, and then get added to the queue again for @@ -487,7 +568,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, true, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) val (id, st) = iterator.next() // Check that the test setup is correct -- make sure we have a concatenated stream. assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) @@ -549,7 +631,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, false, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -610,7 +693,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT maxReqSizeShuffleToMem = 200, detectCorrupt = true, false, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) } val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( @@ -658,7 +742,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT Int.MaxValue, true, false, - taskContext.taskMetrics.createTempShuffleReadMetrics()) + taskContext.taskMetrics.createTempShuffleReadMetrics(), + false) // All blocks fetched return zero length and should trigger a receive-side error: val e = intercept[FetchFailedException] { iterator.next() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 42e3beca2ad5..e2c1308cdc60 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -355,6 +355,16 @@ object SQLConf { .bytesConf(ByteUnit.BYTE) .createWithDefault(64 * 1024 * 1024) + + val FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED = + buildConf("spark.sql.adaptive.fetchShuffleBlocksInBatch.enabled") + .doc("Whether to fetch the continuous shuffle blocks in batch. Instead of fetching blocks " + + "one by one, fetching continuous shuffle blocks for the same map task in batch can " + + "reduce IO and improve performance. 
Note, this feature also depends on a relocatable " + + "serializer and the concatenation support codec in use.") + .booleanConf + .createWithDefault(true) + val ADAPTIVE_EXECUTION_ENABLED = buildConf("spark.sql.adaptive.enabled") .doc("When true, enable adaptive query execution.") .booleanConf @@ -2141,6 +2151,9 @@ class SQLConf extends Serializable with Logging { def targetPostShuffleInputSize: Long = getConf(SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) + def fetchShuffleBlocksInBatchEnabled: Boolean = + getConf(FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED) + def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) def nonEmptyPartitionRatioForBroadcastJoin: Double = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala index f5b0e761161d..4c19f95796d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ShuffledRowRDD.scala @@ -21,8 +21,10 @@ import java.util.Arrays import org.apache.spark._ import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} +import org.apache.spark.sql.internal.SQLConf /** * The [[Partition]] used by [[ShuffledRowRDD]]. A post-shuffle partition @@ -117,6 +119,11 @@ class ShuffledRowRDD( specifiedPartitionStartIndices: Option[Array[Int]] = None) extends RDD[InternalRow](dependency.rdd.context, Nil) { + if (SQLConf.get.fetchShuffleBlocksInBatchEnabled) { + dependency.rdd.context.setLocalProperty( + SortShuffleManager.FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY, "true") + } + private[this] val numPreShufflePartitions = dependency.partitioner.numPartitions private[this] val partitionStartIndices: Array[Int] = specifiedPartitionStartIndices match {
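End to end, the feature is driven by the new SQL conf: when spark.sql.adaptive.fetchShuffleBlocksInBatch.enabled is true (the default), ShuffledRowRDD sets the FETCH_SHUFFLE_BLOCKS_IN_BATCH_ENABLED_KEY local property, SortShuffleManager.canUseBatchFetch turns it into shouldBatchFetch for multi-partition reads, and BlockStoreShuffleReader applies the serializer and codec checks. A minimal sketch of toggling it per session; the session setup is illustrative only.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[2]")
  .appName("batch-shuffle-fetch-demo")
  .getOrCreate()

// Defaults to true; disable it to force the original one-block-per-reduce-id fetches.
spark.conf.set("spark.sql.adaptive.fetchShuffleBlocksInBatch.enabled", "false")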