diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
index 138fd5389c20..d55779954659 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/BlockFetchingListener.java
@@ -27,10 +27,10 @@ public interface BlockFetchingListener extends EventListener {
    * automatically. If the data will be passed to another thread, the receiver should retain()
    * and release() the buffer on their own, or copy the data to a new buffer.
    */
-  void onBlockFetchSuccess(String blockId, ManagedBuffer data);
+  void onBlockFetchSuccess(String[] blockIds, ManagedBuffer data);

   /**
    * Called at least once per block upon failures.
    */
-  void onBlockFetchFailure(String blockId, Throwable exception);
+  void onBlockFetchFailure(String[] blockIds, Throwable exception);
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
index b25e48a164e6..ee7175785581 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java
@@ -20,9 +20,7 @@
 import java.io.File;
 import java.io.IOException;
 import java.nio.ByteBuffer;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.Map;
+import java.util.*;

 import com.codahale.metrics.Gauge;
 import com.codahale.metrics.Meter;
@@ -91,16 +89,18 @@ protected void handleMessage(
       try {
         OpenBlocks msg = (OpenBlocks) msgObj;
         checkAuth(client, msg.appId);
-        long streamId = streamManager.registerStream(client.getClientId(),
-          new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel());
+        ManagedBufferIterator blocksIter = new ManagedBufferIterator(
+          msg.appId, msg.execId, msg.blockIds, msg.fetchContinuousShuffleBlocksInBatch);
+        long streamId = streamManager.registerStream(
+          client.getClientId(), blocksIter, client.getChannel());
         if (logger.isTraceEnabled()) {
           logger.trace("Registered streamId {} with {} buffers for client {} from host {}",
             streamId,
-            msg.blockIds.length,
+            blocksIter.getNumChunks(),
             client.getClientId(),
             getRemoteAddress(client.getChannel()));
         }
-        callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer());
+        callback.onSuccess(new StreamHandle(streamId, blocksIter.getNumChunks()).toByteBuffer());
       } finally {
         responseDelayContext.stop();
       }
@@ -211,42 +211,58 @@ private class ManagedBufferIterator implements Iterator<ManagedBuffer> {

     private final String appId;
     private final String execId;
     private final int shuffleId;
-    // An array containing mapId and reduceId pairs.
-    private final int[] mapIdAndReduceIds;
+    // An array containing (mapId, reduceId, numReducers) tuples.
+    private final int[] shuffleBlockBatches;

-    ManagedBufferIterator(String appId, String execId, String[] blockIds) {
+    ManagedBufferIterator(
+        String appId,
+        String execId,
+        String[] blockIds,
+        boolean fetchContinuousShuffleBlocksInBatch) {
       this.appId = appId;
       this.execId = execId;
       String[] blockId0Parts = blockIds[0].split("_");
-      if (blockId0Parts.length != 4 || !blockId0Parts[0].equals("shuffle")) {
+      if (!ExternalShuffleBlockResolver.isShuffleBlock(blockId0Parts)) {
         throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[0]);
       }
       this.shuffleId = Integer.parseInt(blockId0Parts[1]);
-      mapIdAndReduceIds = new int[2 * blockIds.length];
-      for (int i = 0; i < blockIds.length; i++) {
-        String[] blockIdParts = blockIds[i].split("_");
-        if (blockIdParts.length != 4 || !blockIdParts[0].equals("shuffle")) {
-          throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockIds[i]);
+      if (fetchContinuousShuffleBlocksInBatch) {
+        ArrayList<ArrayList<int[]>> arrayShuffleBlockIds =
+          ExternalShuffleBlockResolver.mergeContinuousShuffleBlockIds(blockIds);
+        shuffleBlockBatches = new int[arrayShuffleBlockIds.size() * 3];
+        for (int i = 0; i < arrayShuffleBlockIds.size(); i++) {
+          ArrayList<int[]> arrayShuffleBlockId = arrayShuffleBlockIds.get(i);
+          int[] startBlockId = arrayShuffleBlockId.get(0);
+          int[] endBlockId = arrayShuffleBlockId.get(arrayShuffleBlockId.size() - 1);
+          shuffleBlockBatches[3 * i] = startBlockId[0];
+          shuffleBlockBatches[3 * i + 1] = startBlockId[1];
+          shuffleBlockBatches[3 * i + 2] = endBlockId[1] - startBlockId[1] + 1;
         }
-        if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
-          throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
-            ", got:" + blockIds[i]);
+      } else {
+        shuffleBlockBatches = new int[3 * blockIds.length];
+        for (int i = 0; i < blockIds.length; i++) {
+          int[] blockIdParts = ExternalShuffleBlockResolver.getBlockIdParts(blockIds[i]);
+          shuffleBlockBatches[3 * i] = blockIdParts[0];
+          shuffleBlockBatches[3 * i + 1] = blockIdParts[1];
+          shuffleBlockBatches[3 * i + 2] = 1;
         }
-        mapIdAndReduceIds[2 * i] = Integer.parseInt(blockIdParts[2]);
-        mapIdAndReduceIds[2 * i + 1] = Integer.parseInt(blockIdParts[3]);
       }
     }

+    public int getNumChunks() {
+      return shuffleBlockBatches.length / 3;
+    }
+
     @Override
     public boolean hasNext() {
-      return index < mapIdAndReduceIds.length;
+      return index < shuffleBlockBatches.length;
     }

     @Override
     public ManagedBuffer next() {
       final ManagedBuffer block = blockManager.getBlockData(appId, execId, shuffleId,
-        mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
-      index += 2;
+        shuffleBlockBatches[index], shuffleBlockBatches[index + 1], shuffleBlockBatches[index + 2]);
+      index += 3;
       metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0);
       return block;
     }
   }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
index 0b7a27402369..40a53deb8d9c 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java
@@ -161,22 +161,69 @@ public void registerExecutor(
     executors.put(fullId, executorInfo);
   }

+  // For testing
+  public ManagedBuffer getBlockData(
+      String appId,
+      String execId,
+      int shuffleId,
+      int mapId,
+      int reduceId) {
+    return getBlockData(appId, execId, shuffleId, mapId, reduceId, 1);
+  }
+
   /**
-   * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId). We make assumptions
-   * about how the hash and sort based shuffles store their data.
+   * Obtains a FileSegmentManagedBuffer from (shuffleId, mapId, reduceId, numReducers). We make
+   * assumptions about how the hash and sort based shuffles store their data.
    */
   public ManagedBuffer getBlockData(
       String appId,
       String execId,
       int shuffleId,
       int mapId,
-      int reduceId) {
+      int reduceId,
+      int numReducers) {
     ExecutorShuffleInfo executor = executors.get(new AppExecId(appId, execId));
     if (executor == null) {
       throw new RuntimeException(
         String.format("Executor is not registered (appId=%s, execId=%s)", appId, execId));
     }
-    return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId);
+    return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId, numReducers);
+  }
+
+  public static boolean isShuffleBlock(String[] blockIdParts) {
+    return blockIdParts.length == 4 && blockIdParts[0].equals("shuffle");
+  }
+
+  public static int[] getBlockIdParts(String blockId) {
+    String[] blockIdParts = blockId.split("_");
+    if (!isShuffleBlock(blockIdParts)) {
+      throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId);
+    }
+    return new int[] { Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]) };
+  }
+
+  // Currently, for all input blockIds, we can assume that block ids with the same map id are
+  // consecutive in the map output file. Logically they might not be consecutive because of
+  // zero-sized blocks, but those have already been filtered out on the client side. For example,
+  // shuffle_0_1_0, shuffle_0_1_1 and shuffle_0_2_0 are merged into the two groups
+  // [[1, 0], [1, 1]] and [[2, 0]], where each int[] element is a (mapId, reduceId) pair.
+  public static ArrayList<ArrayList<int[]>> mergeContinuousShuffleBlockIds(String[] blockIds) {
+    ArrayList<int[]> shuffleBlockIds = new ArrayList<>();
+    ArrayList<ArrayList<int[]>> arrayShuffleBlockIds = new ArrayList<>();
+
+    for (String blockId: blockIds) {
+      int[] blockIdParts = getBlockIdParts(blockId);
+      if (shuffleBlockIds.size() == 0) {
+        shuffleBlockIds.add(blockIdParts);
+      } else {
+        if (blockIdParts[0] != shuffleBlockIds.get(0)[0]) {
+          arrayShuffleBlockIds.add(shuffleBlockIds);
+          shuffleBlockIds = new ArrayList<>();
+        }
+        shuffleBlockIds.add(blockIdParts);
+      }
+    }
+    arrayShuffleBlockIds.add(shuffleBlockIds);
+
+    return arrayShuffleBlockIds;
   }

   /**
@@ -280,13 +327,14 @@ public boolean accept(File dir, String name) {
    * and the block id format is from ShuffleDataBlockId and ShuffleIndexBlockId.
    */
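+  // numReducers is the number of continuous partitions requested; because the partitions of one
+  // map output are laid out back to back in a single file, the whole batch can be served as one
+  // contiguous FileSegmentManagedBuffer.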
   private ManagedBuffer getSortBasedShuffleBlockData(
-    ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId) {
+    ExecutorShuffleInfo executor, int shuffleId, int mapId, int reduceId, int numReducers) {
     File indexFile = getFile(executor.localDirs, executor.subDirsPerLocalDir,
       "shuffle_" + shuffleId + "_" + mapId + "_0.index");

     try {
       ShuffleIndexInformation shuffleIndexInformation = shuffleIndexCache.get(indexFile);
-      ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(reduceId);
+      ShuffleIndexRecord shuffleIndexRecord =
+        shuffleIndexInformation.getIndex(reduceId, numReducers);
       return new FileSegmentManagedBuffer(
         conf,
         getFile(executor.localDirs, executor.subDirsPerLocalDir,
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
index e49e27ab5aa7..a9eb2f1173c3 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java
@@ -91,15 +91,16 @@ public void fetchBlocks(
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      DownloadFileManager downloadFileManager) {
+      DownloadFileManager downloadFileManager,
+      boolean fetchContinuousShuffleBlocksInBatch) {
     checkInit();
     logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId);
     try {
       RetryingBlockFetcher.BlockFetchStarter blockFetchStarter =
         (blockIds1, listener1) -> {
           TransportClient client = clientFactory.createClient(host, port);
-          new OneForOneBlockFetcher(client, appId, execId,
-            blockIds1, listener1, conf, downloadFileManager).start();
+          new OneForOneBlockFetcher(client, appId, execId, blockIds1, listener1, conf,
+            downloadFileManager, fetchContinuousShuffleBlocksInBatch).start();
         };

       int maxRetries = conf.maxIORetries();
@@ -112,9 +113,7 @@ public void fetchBlocks(
       }
     } catch (Exception e) {
       logger.error("Exception while beginning fetchBlocks", e);
-      for (String blockId : blockIds) {
-        listener.onBlockFetchFailure(blockId, e);
-      }
+      listener.onBlockFetchFailure(blockIds, e);
     }
   }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
index 30587023877c..7ed392a22642 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.Arrays;
+import java.util.ArrayList;

 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -50,6 +51,10 @@ public class OneForOneBlockFetcher {
   private final TransportClient client;
   private final OpenBlocks openMessage;
   private final String[] blockIds;
+  // In adaptive execution, one returned chunk might contain data for several consecutive block
+  // ids. blockIdIndices records the mapping from each chunk to its block ids: chunk i covers
+  // blockIds[blockIdIndices[i]] until blockIds[blockIdIndices[i + 1]] (exclusive).
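+  // For example, if blockIds is [shuffle_0_1_0, shuffle_0_1_1, shuffle_0_2_0] and the server
+  // serves it as two chunks, then blockIdIndices is [0, 2, 3]: chunk 0 carries the first two
+  // ids and chunk 1 carries the last one.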
+  private int[] blockIdIndices = null;
   private final BlockFetchingListener listener;
   private final ChunkReceivedCallback chunkCallback;
   private final TransportConf transportConf;
@@ -64,7 +69,7 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf) {
-    this(client, appId, execId, blockIds, listener, transportConf, null);
+    this(client, appId, execId, blockIds, listener, transportConf, null, false);
   }

   public OneForOneBlockFetcher(
@@ -74,9 +79,10 @@ public OneForOneBlockFetcher(
       String[] blockIds,
       BlockFetchingListener listener,
       TransportConf transportConf,
-      DownloadFileManager downloadFileManager) {
+      DownloadFileManager downloadFileManager,
+      boolean fetchContinuousShuffleBlocksInBatch) {
     this.client = client;
-    this.openMessage = new OpenBlocks(appId, execId, blockIds);
+    this.openMessage =
+      new OpenBlocks(appId, execId, blockIds, fetchContinuousShuffleBlocksInBatch);
     this.blockIds = blockIds;
     this.listener = listener;
     this.chunkCallback = new ChunkCallback();
@@ -89,13 +95,15 @@ private class ChunkCallback implements ChunkReceivedCallback {
     @Override
     public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
       // On receipt of a chunk, pass it upwards as a block.
-      listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer);
+      listener.onBlockFetchSuccess(Arrays.copyOfRange(blockIds, blockIdIndices[chunkIndex],
+        blockIdIndices[chunkIndex + 1]), buffer);
     }

     @Override
     public void onFailure(int chunkIndex, Throwable e) {
       // On receipt of a failure, fail every block from chunkIndex onwards.
-      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, blockIdIndices[chunkIndex],
+        blockIds.length);
       failRemainingBlocks(remainingBlockIds, e);
     }
   }
@@ -117,6 +125,25 @@ public void onSuccess(ByteBuffer response) {
         streamHandle = (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response);
         logger.trace("Successfully opened blocks {}, preparing to fetch chunks.", streamHandle);

+        // Initialize blockIdIndices.
+        if (streamHandle.numChunks == blockIds.length) {
+          blockIdIndices = new int[streamHandle.numChunks + 1];
+          for (int i = 0; i < blockIdIndices.length; i++) {
+            blockIdIndices[i] = i;
+          }
+        } else {
+          // The server fetches continuous shuffle blocks in batch.
+          ArrayList<ArrayList<int[]>> arrayShuffleBlockIds =
+            ExternalShuffleBlockResolver.mergeContinuousShuffleBlockIds(blockIds);
+          assert(streamHandle.numChunks == arrayShuffleBlockIds.size());
+          blockIdIndices = new int[arrayShuffleBlockIds.size() + 1];
+          blockIdIndices[0] = 0;
+          for (int i = 1; i < blockIdIndices.length; i++) {
+            blockIdIndices[i] = blockIdIndices[i - 1] + arrayShuffleBlockIds.get(i - 1).size();
+          }
+        }
+        assert blockIdIndices[blockIdIndices.length - 1] == blockIds.length;
+
         // Immediately request all chunks -- we expect that the total size of the request is
         // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]].
         for (int i = 0; i < streamHandle.numChunks; i++) {
@@ -143,12 +170,10 @@ public void onFailure(Throwable e) {

-  /** Invokes the "onBlockFetchFailure" callback for every listed block id. */
+  /** Invokes the "onBlockFetchFailure" callback once for all the listed block ids. */
   private void failRemainingBlocks(String[] failedBlockIds, Throwable e) {
-    for (String blockId : failedBlockIds) {
-      try {
-        listener.onBlockFetchFailure(blockId, e);
-      } catch (Exception e2) {
-        logger.error("Error in block fetch failure callback", e2);
-      }
+    try {
+      listener.onBlockFetchFailure(failedBlockIds, e);
+    } catch (Exception e2) {
+      logger.error("Error in block fetch failure callback", e2);
     }
   }
@@ -173,7 +198,8 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {

     @Override
     public void onComplete(String streamId) throws IOException {
-      listener.onBlockFetchSuccess(blockIds[chunkIndex], channel.closeAndRead());
+      listener.onBlockFetchSuccess(Arrays.copyOfRange(blockIds, blockIdIndices[chunkIndex],
+        blockIdIndices[chunkIndex + 1]), channel.closeAndRead());
       if (!downloadFileManager.registerTempFileToClean(targetFile)) {
         targetFile.delete();
       }
@@ -183,7 +209,8 @@ public void onComplete(String streamId) throws IOException {
     public void onFailure(String streamId, Throwable cause) throws IOException {
       channel.close();
       // On receipt of a failure, fail every block from chunkIndex onwards.
-      String[] remainingBlockIds = Arrays.copyOfRange(blockIds, chunkIndex, blockIds.length);
+      String[] remainingBlockIds =
+        Arrays.copyOfRange(blockIds, blockIdIndices[chunkIndex], blockIds.length);
       failRemainingBlocks(remainingBlockIds, cause);
       targetFile.delete();
     }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
index 6bf3da94030d..fde8c190cdc2 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockFetcher.java
@@ -18,6 +18,7 @@
 package org.apache.spark.network.shuffle;

 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.LinkedHashSet;
 import java.util.concurrent.ExecutorService;
@@ -146,9 +147,7 @@ private void fetchAllOutstanding() {
       if (shouldRetry(e)) {
         initiateRetry();
       } else {
-        for (String bid : blockIdsToFetch) {
-          listener.onBlockFetchFailure(bid, e);
-        }
+        listener.onBlockFetchFailure(blockIdsToFetch, e);
       }
     }
   }
@@ -188,44 +187,60 @@ private synchronized boolean shouldRetry(Throwable e) {
    */
   private class RetryingBlockFetchListener implements BlockFetchingListener {
     @Override
-    public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+    public void onBlockFetchSuccess(String[] blockIds, ManagedBuffer data) {
       // We will only forward this success message to our parent listener if this block request is
       // outstanding and we are still the active listener.
-      boolean shouldForwardSuccess = false;
+      ArrayList<String> forwardBlockIds = new ArrayList<>(blockIds.length);
       synchronized (RetryingBlockFetcher.this) {
-        if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
-          outstandingBlocksIds.remove(blockId);
-          shouldForwardSuccess = true;
+        if (this == currentListener) {
+          for (String blockId : blockIds) {
+            if (outstandingBlocksIds.contains(blockId)) {
+              outstandingBlocksIds.remove(blockId);
+              forwardBlockIds.add(blockId);
+            }
+          }
         }
       }

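+      // A chunk either succeeds or fails as a unit, so while we are the active listener either
+      // every id in the chunk is still outstanding or none of them is; the assert below checks
+      // that assumption before forwarding the original blockIds array.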
       // Now actually invoke the parent listener, outside of the synchronized block.
-      if (shouldForwardSuccess) {
-        listener.onBlockFetchSuccess(blockId, data);
+      if (!forwardBlockIds.isEmpty()) {
+        assert forwardBlockIds.size() == blockIds.length;
+        listener.onBlockFetchSuccess(blockIds, data);
       }
     }

     @Override
-    public void onBlockFetchFailure(String blockId, Throwable exception) {
+    public void onBlockFetchFailure(String[] blockIds, Throwable exception) {
       // We will only forward this failure to our parent listener if this block request is
       // outstanding, we are still the active listener, AND we cannot retry the fetch.
       boolean shouldForwardFailure = false;
+      ArrayList<String> possibleForwardBlockIds = new ArrayList<>(blockIds.length);
       synchronized (RetryingBlockFetcher.this) {
-        if (this == currentListener && outstandingBlocksIds.contains(blockId)) {
-          if (shouldRetry(exception)) {
-            initiateRetry();
-          } else {
-            logger.error(String.format("Failed to fetch block %s, and will not retry (%s retries)",
-              blockId, retryCount), exception);
-            outstandingBlocksIds.remove(blockId);
-            shouldForwardFailure = true;
+        if (this == currentListener) {
+          for (String blockId : blockIds) {
+            if (outstandingBlocksIds.contains(blockId)) {
+              possibleForwardBlockIds.add(blockId);
+            }
+          }
+          if (!possibleForwardBlockIds.isEmpty()) {
+            if (shouldRetry(exception)) {
+              initiateRetry();
+            } else {
+              logger.error(String.format("Failed to fetch %s outstanding blocks, and will " +
+                "not retry (%s retries)", possibleForwardBlockIds.size(), retryCount), exception);
+              for (String blockId : possibleForwardBlockIds) {
+                outstandingBlocksIds.remove(blockId);
+              }
+              shouldForwardFailure = true;
+            }
           }
         }
       }

       // Now actually invoke the parent listener, outside of the synchronized block.
       if (shouldForwardFailure) {
-        listener.onBlockFetchFailure(blockId, exception);
+        listener.onBlockFetchFailure(
+          possibleForwardBlockIds.toArray(new String[possibleForwardBlockIds.size()]), exception);
       }
     }
   }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
index 62b99c40f61f..c87adc609ad5 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java
@@ -47,6 +47,8 @@ public void init(String appId) { }
    *                            If it's not null, the remote blocks will be streamed
    *                            into temp shuffle files to reduce the memory usage, otherwise,
    *                            they will be kept in memory.
+   * @param fetchContinuousShuffleBlocksInBatch fetch continuous shuffle blocks in batch if the
+   *                                            server side supports it.
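+   *                                            When it is true, a single onBlockFetchSuccess
+   *                                            callback may deliver one buffer holding the data
+   *                                            of several continuous shuffle blocks.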
    */
   public abstract void fetchBlocks(
       String host,
@@ -54,7 +56,8 @@ public abstract void fetchBlocks(
       String execId,
       String[] blockIds,
       BlockFetchingListener listener,
-      DownloadFileManager downloadFileManager);
+      DownloadFileManager downloadFileManager,
+      boolean fetchContinuousShuffleBlocksInBatch);

   /**
    * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
index 371149bef397..a280106565ca 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java
@@ -53,9 +53,9 @@ public int getSize() {

   /**
-   * Get index offset for a particular reducer.
+   * Get the index offset range covering numReducers continuous partitions starting at reduceId.
+   * The index file stores (numPartitions + 1) longs, so the returned range is
+   * [offsets(reduceId), offsets(reduceId + numReducers)).
    */
-  public ShuffleIndexRecord getIndex(int reduceId) {
+  public ShuffleIndexRecord getIndex(int reduceId, int numReducers) {
     long offset = offsets.get(reduceId);
-    long nextOffset = offsets.get(reduceId + 1);
+    long nextOffset = offsets.get(reduceId + numReducers);
     return new ShuffleIndexRecord(offset, nextOffset - offset);
   }
 }
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
index ce954b8a289e..0ea0267bc3e9 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/OpenBlocks.java
@@ -32,11 +32,25 @@ public class OpenBlocks extends BlockTransferMessage {
   public final String appId;
   public final String execId;
   public final String[] blockIds;
+  // Indicates whether the client wants the server to read continuous shuffle blocks in batch,
+  // to reduce IO. When it is true, OpenBlocks may contain ShuffleBlockIds only.
+  // This field is newly added in Spark 3.0, and will be encoded in the message only when it's
+  // true.
+  public final boolean fetchContinuousShuffleBlocksInBatch;

+  // This constructor is only used in tests.
   public OpenBlocks(String appId, String execId, String[] blockIds) {
+    this(appId, execId, blockIds, false);
+  }
+
+  public OpenBlocks(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean fetchContinuousShuffleBlocksInBatch) {
     this.appId = appId;
     this.execId = execId;
     this.blockIds = blockIds;
+    this.fetchContinuousShuffleBlocksInBatch = fetchContinuousShuffleBlocksInBatch;
   }

   @Override
@@ -44,7 +58,8 @@ public OpenBlocks(String appId, String execId, String[] blockIds) {

   @Override
   public int hashCode() {
-    return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+    return Objects.hashCode(appId, execId, fetchContinuousShuffleBlocksInBatch) * 41
+      + Arrays.hashCode(blockIds);
   }

   @Override
@@ -53,6 +68,7 @@ public String toString() {
       .add("appId", appId)
       .add("execId", execId)
       .add("blockIds", Arrays.toString(blockIds))
+      .add("fetchContinuousShuffleBlocksInBatch", fetchContinuousShuffleBlocksInBatch)
       .toString();
   }

@@ -62,7 +78,9 @@ public boolean equals(Object other) {
       OpenBlocks o = (OpenBlocks) other;
       return Objects.equal(appId, o.appId)
         && Objects.equal(execId, o.execId)
-        && Arrays.equals(blockIds, o.blockIds);
+        && Arrays.equals(blockIds, o.blockIds)
+        && fetchContinuousShuffleBlocksInBatch == o.fetchContinuousShuffleBlocksInBatch;
     }
     return false;
   }

@@ -71,7 +89,8 @@ public boolean equals(Object other) {
   public int encodedLength() {
     return Encoders.Strings.encodedLength(appId)
       + Encoders.Strings.encodedLength(execId)
-      + Encoders.StringArrays.encodedLength(blockIds);
+      + Encoders.StringArrays.encodedLength(blockIds)
+      + (fetchContinuousShuffleBlocksInBatch ? 1 : 0);
   }

   @Override
@@ -79,12 +98,23 @@ public void encode(ByteBuf buf) {
     Encoders.Strings.encode(buf, appId);
     Encoders.Strings.encode(buf, execId);
     Encoders.StringArrays.encode(buf, blockIds);
+    if (fetchContinuousShuffleBlocksInBatch) {
+      buf.writeBoolean(true);
+    }
   }

   public static OpenBlocks decode(ByteBuf buf) {
     String appId = Encoders.Strings.decode(buf);
     String execId = Encoders.Strings.decode(buf);
     String[] blockIds = Encoders.StringArrays.decode(buf);
-    return new OpenBlocks(appId, execId, blockIds);
+    boolean fetchContinuousShuffleBlocksInBatch = false;
+    if (buf.readableBytes() >= 1) {
+      // A sanity check. In `encode` we write true, so here we should read true.
+      assert buf.readBoolean();
+      fetchContinuousShuffleBlocksInBatch = true;
+    }
+    return new OpenBlocks(appId, execId, blockIds, fetchContinuousShuffleBlocksInBatch);
   }
 }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
index 02e6eb3a4467..8b711b0a3aec 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java
@@ -192,11 +192,11 @@ public void testAppIsolation() throws Exception {
       CountDownLatch blockFetchLatch = new CountDownLatch(1);
       BlockFetchingListener listener = new BlockFetchingListener() {
         @Override
-        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+        public void onBlockFetchSuccess(String[] blockIds, ManagedBuffer data) {
           blockFetchLatch.countDown();
         }
         @Override
-        public void onBlockFetchFailure(String blockId, Throwable t) {
+        public void onBlockFetchFailure(String[] blockIds, Throwable t) {
           exception.set(t);
           blockFetchLatch.countDown();
         }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
index 86c8609e7070..1bdede181113 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/BlockTransferMessagesSuite.java
@@ -17,17 +17,25 @@

 package org.apache.spark.network.shuffle;

+import java.nio.ByteBuffer;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
 import org.junit.Test;
 import static org.junit.Assert.*;

 import org.apache.spark.network.shuffle.protocol.*;

 /** Verifies that all BlockTransferMessages can be serialized correctly. */
 public class BlockTransferMessagesSuite {
   @Test
   public void serializeOpenShuffleBlocks() {
-    checkSerializeDeserialize(new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }));
+    checkSerializeDeserialize(
+      new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }, false));
+    checkSerializeDeserialize(
+      new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }, true));
     checkSerializeDeserialize(new RegisterExecutor("app-1", "exec-2", new ExecutorShuffleInfo(
       new String[] { "/local1", "/local2" }, 32, "MyShuffleManager")));
     checkSerializeDeserialize(new UploadBlock("app-1", "exec-2", "block-3", new byte[] { 1, 2 },
@@ -41,4 +49,34 @@ private void checkSerializeDeserialize(BlockTransferMessage msg) {
     assertEquals(msg.hashCode(), msg2.hashCode());
     assertEquals(msg.toString(), msg2.toString());
   }
+
+  private BlockTransferMessage fromByteBuffer(ByteBuffer msg) {
+    ByteBuf buf = Unpooled.wrappedBuffer(msg);
+    byte type = buf.readByte();
+    switch (type) {
+      case 0: return TestOpenBlocks.decode(buf);
+      default: throw new IllegalArgumentException("Unknown message type: " + type);
+    }
+  }
+
+  private void verifyOpenBlocks(OpenBlocks ob1, TestOpenBlocks ob2) {
+    assertEquals(ob1.appId, ob2.appId);
+    assertEquals(ob1.execId, ob2.execId);
+    assertArrayEquals(ob1.blockIds, ob2.blockIds);
+  }
+
+  @Test
+  public void checkOpenBlocksBackwardCompatibility() {
+    TestOpenBlocks testOpenBlocks =
+      new TestOpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" });
+    OpenBlocks openBlocks =
+      (OpenBlocks) BlockTransferMessage.Decoder.fromByteBuffer(testOpenBlocks.toByteBuffer());
+    verifyOpenBlocks(openBlocks, testOpenBlocks);
+    assertFalse(openBlocks.fetchContinuousShuffleBlocksInBatch);
+
+    openBlocks = new OpenBlocks("app-1", "exec-2", new String[] { "b1", "b2" }, true);
+    testOpenBlocks = (TestOpenBlocks) fromByteBuffer(openBlocks.toByteBuffer());
+    verifyOpenBlocks(openBlocks, testOpenBlocks);
+  }
 }
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
index 537c277cd26b..58819e8f4836 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java
@@ -85,8 +85,8 @@ public void testOpenShuffleBlocks() {
     ManagedBuffer block0Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[3]));
     ManagedBuffer block1Marker = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
-    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0)).thenReturn(block0Marker);
-    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1)).thenReturn(block1Marker);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 0, 1)).thenReturn(block0Marker);
+    when(blockResolver.getBlockData("app0", "exec1", 0, 0, 1, 1)).thenReturn(block1Marker);
     ByteBuffer openBlocks = new OpenBlocks("app0", "exec1",
       new String[] { "shuffle_0_0_0", "shuffle_0_0_1" })
       .toByteBuffer();
@@ -109,8 +109,8 @@ public void testOpenShuffleBlocks() {
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());
     assertFalse(buffers.hasNext());
-    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0);
-    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1);
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 0, 1);
+    verify(blockResolver, times(1)).getBlockData("app0", "exec1", 0, 0, 1, 1);

     // Verify open block request latency metrics
     Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler)
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
index 459629c5f05f..65d17e366818 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java
@@ -111,6 +111,13 @@ public void testSortShuffleBlocks() throws IOException {
         CharStreams.toString(new InputStreamReader(block1Stream, StandardCharsets.UTF_8));
       assertEquals(sortBlock1, block1);
     }
+
+    try (InputStream block01Stream = resolver.getBlockData(
+        "app0", "exec0", 0, 0, 0, 2).createInputStream()) {
+      String block01 =
+        CharStreams.toString(new InputStreamReader(block01Stream, StandardCharsets.UTF_8));
+      assertEquals(sortBlock0 + sortBlock1, block01);
+    }
   }

   @Test
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
index 526b96b36447..143ed11cf34b 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java
@@ -116,7 +116,15 @@ public void releaseBuffers() {

   // Fetch a set of blocks from a pre-registered executor.
   private FetchResult fetchBlocks(String execId, String[] blockIds) throws Exception {
-    return fetchBlocks(execId, blockIds, conf, server.getPort());
+    return fetchBlocks(execId, blockIds, conf, server.getPort(), false);
+  }
+
+  private FetchResult fetchBlocks(
+      String execId,
+      String[] blockIds,
+      boolean fetchContinuousShuffleBlocksInBatch) throws Exception {
+    return fetchBlocks(execId, blockIds, conf, server.getPort(),
+      fetchContinuousShuffleBlocksInBatch);
   }

   // Fetch a set of blocks from a pre-registered executor. Connects to the server on the given port,
@@ -125,7 +133,8 @@ private FetchResult fetchBlocks(
       String execId,
       String[] blockIds,
       TransportConf clientConf,
-      int port) throws Exception {
+      int port,
+      boolean fetchContinuousShuffleBlocksInBatch) throws Exception {
     final FetchResult res = new FetchResult();
     res.successBlocks = Collections.synchronizedSet(new HashSet<String>());
     res.failedBlocks = Collections.synchronizedSet(new HashSet<String>());
@@ -138,27 +147,31 @@ private FetchResult fetchBlocks(
     client.fetchBlocks(TestUtils.getLocalHost(), port, execId, blockIds,
       new BlockFetchingListener() {
         @Override
-        public void onBlockFetchSuccess(String blockId, ManagedBuffer data) {
+        public void onBlockFetchSuccess(String[] blockIds, ManagedBuffer data) {
           synchronized (this) {
-            if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
-              data.retain();
-              res.successBlocks.add(blockId);
-              res.buffers.add(data);
-              requestsRemaining.release();
+            data.retain();
+            res.buffers.add(data);
+            for (String blockId : blockIds) {
+              if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+                res.successBlocks.add(blockId);
+                requestsRemaining.release();
+              }
             }
           }
         }

         @Override
-        public void onBlockFetchFailure(String blockId, Throwable exception) {
+        public void onBlockFetchFailure(String[] blockIds, Throwable exception) {
           synchronized (this) {
-            if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
-              res.failedBlocks.add(blockId);
-              requestsRemaining.release();
+            for (String blockId : blockIds) {
+              if (!res.successBlocks.contains(blockId) && !res.failedBlocks.contains(blockId)) {
+                res.failedBlocks.add(blockId);
+                requestsRemaining.release();
+              }
             }
           }
         }
-      }, null);
+      }, null, fetchContinuousShuffleBlocksInBatch);

     if (!requestsRemaining.tryAcquire(blockIds.length, 5, TimeUnit.SECONDS)) {
       fail("Timeout getting response from the server");
@@ -235,11 +248,39 @@ public void testFetchNoServer() throws Exception {
       new MapConfigProvider(ImmutableMap.of("spark.shuffle.io.maxRetries", "0")));
     registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
     FetchResult execFetch = fetchBlocks("exec-0",
-      new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */);
+      new String[]{"shuffle_1_0_0", "shuffle_1_0_1"}, clientConf, 1 /* port */, false);
     assertTrue(execFetch.successBlocks.isEmpty());
     assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks);
   }

+  private byte[] mergeBlockData(byte[][] blocks) {
+    int totalLength = 0;
+    for (byte[] block : blocks) {
+      totalLength += block.length;
+    }
+
+    byte[] data = new byte[totalLength];
+    int pos = 0;
+    for (byte[] block : blocks) {
+      System.arraycopy(block, 0, data, pos, block.length);
+      pos += block.length;
+    }
+
+    return data;
+  }
+
+  @Test
+  public void testBatchFetchThreeSort() throws Exception {
+    registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER));
+    FetchResult exec0Fetch = fetchBlocks("exec-0",
+      new String[] { "shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2" }, true);
+    assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_0_0_1", "shuffle_0_0_2"),
+      exec0Fetch.successBlocks);
+    assertTrue(exec0Fetch.failedBlocks.isEmpty());
+    assertBufferListsEqual(exec0Fetch.buffers, Arrays.asList(mergeBlockData(exec0Blocks)));
+    exec0Fetch.releaseBuffers();
+  }
+
   private static void registerExecutor(String executorId, ExecutorShuffleInfo executorInfo)
       throws IOException, InterruptedException {
     ExternalShuffleClient client = new ExternalShuffleClient(conf, null, false, 5000);
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
index 95460637db89..6f61c24f8859 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockFetcherSuite.java
@@ -31,7 +31,7 @@
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyInt;
 import static org.mockito.ArgumentMatchers.anyLong;
-import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.AdditionalMatchers.aryEq;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
@@ -60,7 +60,8 @@ public void testFetchOne() {

     BlockFetchingListener listener = fetchBlocks(blocks);

-    verify(listener).onBlockFetchSuccess("shuffle_0_0_0", blocks.get("shuffle_0_0_0"));
+    verify(listener).onBlockFetchSuccess(new String[] { "shuffle_0_0_0" },
+      blocks.get("shuffle_0_0_0"));
   }

   @Test
@@ -73,7 +74,7 @@ public void testFetchThree() {
     BlockFetchingListener listener = fetchBlocks(blocks);

     for (int i = 0; i < 3; i ++) {
-      verify(listener, times(1)).onBlockFetchSuccess("b" + i, blocks.get("b" + i));
+      verify(listener, times(1)).onBlockFetchSuccess(new String[] { "b" + i }, blocks.get("b" + i));
     }
   }

@@ -87,9 +88,9 @@ public void testFailure() {
     BlockFetchingListener listener = fetchBlocks(blocks);

     // Each failure will cause a failure to be invoked in all remaining block fetches.
-    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
-    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
-    verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any());
+    verify(listener, times(1)).onBlockFetchSuccess(new String[] { "b0" }, blocks.get("b0"));
+    verify(listener, times(1)).onBlockFetchFailure(aryEq(new String[] { "b1", "b2" }), any());
+    verify(listener, times(1)).onBlockFetchFailure(aryEq(new String[] { "b2" }), any());
   }

   @Test
@@ -102,10 +103,9 @@ public void testFailureAndSuccess() {
     BlockFetchingListener listener = fetchBlocks(blocks);

     // We may call both success and failure for the same block.
-    verify(listener, times(1)).onBlockFetchSuccess("b0", blocks.get("b0"));
-    verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any());
-    verify(listener, times(1)).onBlockFetchSuccess("b2", blocks.get("b2"));
-    verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any());
+    verify(listener, times(1)).onBlockFetchSuccess(new String[] { "b0" }, blocks.get("b0"));
+    verify(listener, times(1)).onBlockFetchFailure(aryEq(new String[] { "b1", "b2" }), any());
+    verify(listener, times(1)).onBlockFetchSuccess(new String[] { "b2" }, blocks.get("b2"));
   }

   @Test
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
index a530e16734db..f25dc2c86969 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -32,6 +32,7 @@
 import org.mockito.stubbing.Stubber;

 import static org.junit.Assert.*;
+import static org.mockito.AdditionalMatchers.aryEq;
 import static org.mockito.Mockito.*;

 import org.apache.spark.network.buffer.ManagedBuffer;
@@ -64,8 +65,8 @@ public void testNoFailures() throws IOException, InterruptedException {

     performInteractions(interactions, listener);

-    verify(listener).onBlockFetchSuccess("b0", block0);
-    verify(listener).onBlockFetchSuccess("b1", block1);
+    verify(listener).onBlockFetchSuccess(new String[] { "b0" }, block0);
+    verify(listener).onBlockFetchSuccess(new String[] { "b1" }, block1);
     verifyNoMoreInteractions(listener);
   }

@@ -83,8 +84,8 @@ public void testUnrecoverableFailure() throws IOException, InterruptedException

     performInteractions(interactions, listener);

-    verify(listener).onBlockFetchFailure(eq("b0"), any());
-    verify(listener).onBlockFetchSuccess("b1", block1);
+    verify(listener).onBlockFetchFailure(aryEq(new String[] { "b0" }), any());
+    verify(listener).onBlockFetchSuccess(new String[] { "b1" }, block1);
     verifyNoMoreInteractions(listener);
   }

@@ -106,8 +107,8 @@ public void testSingleIOExceptionOnFirst() throws IOException, InterruptedExcept

     performInteractions(interactions, listener);

-    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
-    verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+    verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b0" }, block0);
+    verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b1" }, block1);
     verifyNoMoreInteractions(listener);
   }

@@ -128,8 +129,8 @@ public void testSingleIOExceptionOnSecond() throws IOException, InterruptedExcep

     performInteractions(interactions, listener);

-    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
-    verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+    verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b0" }, block0);
+    verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b1" }, block1);
     verifyNoMoreInteractions(listener);
   }

@@ -156,8 +157,8 @@ public void testTwoIOExceptions() throws IOException, InterruptedException {

     performInteractions(interactions, listener);

-    verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0);
-    verify(listener, timeout(5000)).onBlockFetchSuccess("b1", block1);
+    verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b0" }, block0);
+    verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b1" }, block1);
"b1" }, block1); verifyNoMoreInteractions(listener); } @@ -188,8 +189,8 @@ public void testThreeIOExceptions() throws IOException, InterruptedException { performInteractions(interactions, listener); - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), any()); + verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b0" }, block0); + verify(listener, timeout(5000)).onBlockFetchFailure(aryEq(new String[] { "b1" }), any()); verifyNoMoreInteractions(listener); } @@ -218,9 +219,9 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException performInteractions(interactions, listener); - verify(listener, timeout(5000)).onBlockFetchSuccess("b0", block0); - verify(listener, timeout(5000)).onBlockFetchFailure(eq("b1"), any()); - verify(listener, timeout(5000)).onBlockFetchSuccess("b2", block2); + verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b0" }, block0); + verify(listener, timeout(5000)).onBlockFetchFailure(aryEq(new String[] { "b1" }), any()); + verify(listener, timeout(5000)).onBlockFetchSuccess(new String[] { "b2" }, block2); verifyNoMoreInteractions(listener); } @@ -268,9 +269,10 @@ private static void performInteractions(List> inte Object blockValue = block.getValue(); if (blockValue instanceof ManagedBuffer) { - retryListener.onBlockFetchSuccess(blockId, (ManagedBuffer) blockValue); + retryListener.onBlockFetchSuccess(new String[] { blockId }, + (ManagedBuffer) blockValue); } else if (blockValue instanceof Exception) { - retryListener.onBlockFetchFailure(blockId, (Exception) blockValue); + retryListener.onBlockFetchFailure(new String[] { blockId }, (Exception) blockValue); } else { fail("Can only handle ManagedBuffers and Exceptions, got " + blockValue); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/TestOpenBlocks.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/TestOpenBlocks.java new file mode 100644 index 000000000000..89f2a4f76947 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/TestOpenBlocks.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. 
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+/** TestOpenBlocks is used to test OpenBlocks backward compatibility only. */
+public class TestOpenBlocks extends BlockTransferMessage {
+  public final String appId;
+  public final String execId;
+  public final String[] blockIds;
+
+  public TestOpenBlocks(String appId, String execId, String[] blockIds) {
+    this.appId = appId;
+    this.execId = execId;
+    this.blockIds = blockIds;
+  }
+
+  @Override
+  protected Type type() { return Type.OPEN_BLOCKS; }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(appId, execId) * 41 + Arrays.hashCode(blockIds);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("appId", appId)
+      .add("execId", execId)
+      .add("blockIds", Arrays.toString(blockIds))
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other != null && other instanceof OpenBlocks) {
+      OpenBlocks o = (OpenBlocks) other;
+      return Objects.equal(appId, o.appId)
+        && Objects.equal(execId, o.execId)
+        && Arrays.equals(blockIds, o.blockIds);
+    }
+    return false;
+  }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(appId)
+      + Encoders.Strings.encodedLength(execId)
+      + Encoders.StringArrays.encodedLength(blockIds);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, appId);
+    Encoders.Strings.encode(buf, execId);
+    Encoders.StringArrays.encode(buf, blockIds);
+  }
+
+  public static TestOpenBlocks decode(ByteBuf buf) {
+    String appId = Encoders.Strings.decode(buf);
+    String execId = Encoders.Strings.decode(buf);
+    String[] blockIds = Encoders.StringArrays.decode(buf);
+    return new TestOpenBlocks(appId, execId, blockIds);
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index a58c8fa2e763..6f747f14f201 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -68,7 +68,8 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tempFileManager: DownloadFileManager): Unit
+      tempFileManager: DownloadFileManager,
+      fetchContinuousShuffleBlocksInBatch: Boolean): Unit

   /**
    * Upload a single block to a remote node, available only after [[init]] is invoked.
@@ -97,10 +98,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
     val result = Promise[ManagedBuffer]()
     fetchBlocks(host, port, execId, Array(blockId),
       new BlockFetchingListener {
-        override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+        override def onBlockFetchFailure(blockIds: Array[String], exception: Throwable): Unit = {
           result.failure(exception)
         }
-        override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+        override def onBlockFetchSuccess(blockIds: Array[String], data: ManagedBuffer): Unit = {
           data match {
             case f: FileSegmentManagedBuffer =>
               result.success(f)
@@ -113,7 +114,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
             result.success(new NioManagedBuffer(ret))
           }
         }
-      }, tempFileManager)
+      }, tempFileManager, false)

     ThreadUtils.awaitResult(result.future, Duration.Inf)
   }
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 27f4f94ea55f..e520752c871b 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -30,7 +30,8 @@ import org.apache.spark.network.client.{RpcResponseCallback, StreamCallbackWithI
 import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
 import org.apache.spark.network.shuffle.protocol._
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.storage._

 /**
- * Serves requests to open blocks by simply registering one chunk per block requested.
+ * Serves requests to open blocks by registering one chunk per block requested, or one chunk per
+ * group of continuous blocks when batch fetching is enabled.
@@ -56,11 +57,17 @@ class NettyBlockRpcServer(
     message match {
       case openBlocks: OpenBlocks =>
-        val blocksNum = openBlocks.blockIds.length
+        val blockIds = if (openBlocks.fetchContinuousShuffleBlocksInBatch) {
+          BlockManager.mergeContinuousShuffleBlockIds(
+            openBlocks.blockIds.iterator.map(BlockId.apply))
+        } else {
+          openBlocks.blockIds.map(BlockId.apply)
+        }
+        val blocksNum = blockIds.length
         val blocks = for (i <- (0 until blocksNum).view)
-          yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
-        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
-          client.getChannel)
+          yield blockManager.getBlockData(blockIds(i))
+        val streamId = streamManager.registerStream(
+          appId, blocks.iterator.asJava, client.getChannel)
         logTrace(s"Registered streamId $streamId with $blocksNum buffers")
         responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index dc55685b1e7b..2f92aff0bcb0 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -106,14 +106,15 @@ private[spark] class NettyBlockTransferService(
       execId: String,
       blockIds: Array[String],
       listener: BlockFetchingListener,
-      tempFileManager: DownloadFileManager): Unit = {
+      tempFileManager: DownloadFileManager,
+      fetchContinuousShuffleBlocksInBatch: Boolean): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
       val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
         override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
           val client = clientFactory.createClient(host, port)
           new OneForOneBlockFetcher(client, appId, execId, blockIds, listener,
-            transportConf, tempFileManager).start()
+            transportConf, tempFileManager, fetchContinuousShuffleBlocksInBatch).start()
         }
       }

@@ -128,7 +129,7 @@ private[spark] class NettyBlockTransferService(
     } catch {
       case e: Exception =>
        logError("Exception while beginning fetchBlocks", e)
-        blockIds.foreach(listener.onBlockFetchFailure(_, e))
+        listener.onBlockFetchFailure(blockIds, e)
     }
   }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index daafe305c8f8..8667f5ae7a6a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle

 import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.io.CompressionCodec
 import org.apache.spark.serializer.SerializerManager
 import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
@@ -41,8 +42,24 @@ private[spark] class BlockStoreShuffleReader[K, C](

   private val dep = handle.dependency

+  private def supportsConcatenationOfSerializedStreams(conf: SparkConf): Boolean = {
+    val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
+    CompressionCodec.supportsConcatenationOfSerializedStreams(compressionCodec)
+  }
+
+  private def shouldFetchContinuousShuffleBlocksInBatch: Boolean = {
+    val conf = SparkEnv.get.conf
+    val serializerRelocatable = dep.serializer.supportsRelocationOfSerializedObjects
+    val compressed = conf.getBoolean("spark.shuffle.compress", true)
+
+    endPartition - startPartition > 1 &&
+      serializerRelocatable &&
+      (!compressed || supportsConcatenationOfSerializedStreams(conf))
+  }
+
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val fetchContinuousShuffleBlocksInBatch = shouldFetchContinuousShuffleBlocksInBatch
     val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.shuffleClient,
@@ -55,7 +72,8 @@ private[spark] class BlockStoreShuffleReader[K, C](
       SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
       SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
       SparkEnv.get.conf.get(config.SHUFFLE_DETECT_CORRUPT),
-      readMetrics)
+      readMetrics,
+      fetchContinuousShuffleBlocksInBatch)

     val serializerInstance = dep.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d3f1c7ec1bbe..46a89c8290ee 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -190,10 +190,25 @@ private[spark] class IndexShuffleBlockResolver(
     }
   }

-  override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+  override def getBlockData(blockId: BlockId): ManagedBuffer = {
+    blockId match {
+      case ArrayShuffleBlockId(blockIds) =>
+        val startReducer = blockIds.head
+        val endReducer = blockIds.last
+        getBlockData(startReducer.shuffleId, startReducer.mapId, startReducer.reduceId,
+          endReducer.reduceId - startReducer.reduceId + 1)
+      case ShuffleBlockId(shuffleId, mapId, reduceId) =>
+        getBlockData(shuffleId, mapId, reduceId, 1)
+      case _ =>
+        throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId)
+    }
+  }
+
+  private def getBlockData(shuffleId: Int, mapId: Int, reduceId: Int, numReducers: Int)
+    : ManagedBuffer = {
     // The block is actually going to be a range of a single map output file for this map, so
     // find out the consolidated file, then the offset within that from our index
-    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
+    val indexFile = getIndexFile(shuffleId, mapId)

     // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
     // which is incorrectly using our file descriptor then this code will fetch the wrong offsets
@@ -202,20 +217,22 @@ private[spark] class IndexShuffleBlockResolver(
     // class of issue from re-occurring in the future which is why they are left here even though
     // SPARK-22982 is fixed.
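+    // The index file stores (numPartitions + 1) longs, so the byte range for the continuous
+    // partitions [reduceId, reduceId + numReducers) is offsets(reduceId) until
+    // offsets(reduceId + numReducers). For example, with hypothetical offsets [0, 10, 30, 60],
+    // reduceId = 0 and numReducers = 2 select bytes [0, 30) of the data file.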
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index d3f1c7ec1bbe..46a89c8290ee 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -190,10 +190,25 @@ private[spark] class IndexShuffleBlockResolver(
     }
   }
 
-  override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+  override def getBlockData(blockId: BlockId): ManagedBuffer = {
+    blockId match {
+      case ArrayShuffleBlockId(blockIds) =>
+        val startReducer = blockIds.head
+        val endReducer = blockIds.last
+        getBlockData(startReducer.shuffleId, startReducer.mapId, startReducer.reduceId,
+          endReducer.reduceId - startReducer.reduceId + 1)
+      case ShuffleBlockId(shuffleId, mapId, reduceId) =>
+        getBlockData(shuffleId, mapId, reduceId, 1)
+      case _ =>
+        throw new IllegalArgumentException("Unexpected shuffle block id format: " + blockId)
+    }
+  }
+
+  private def getBlockData(shuffleId: Int, mapId: Int, reduceId: Int, numReducers: Int)
+    : ManagedBuffer = {
     // The block is actually going to be a range of a single map output file for this map, so
     // find out the consolidated file, then the offset within that from our index
-    val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
+    val indexFile = getIndexFile(shuffleId, mapId)
 
     // SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
     // which is incorrectly using our file descriptor then this code will fetch the wrong offsets
     // class of issue from re-occurring in the future which is why they are left here even though
     // SPARK-22982 is fixed.
     val channel = Files.newByteChannel(indexFile.toPath)
-    channel.position(blockId.reduceId * 8L)
+    channel.position(reduceId * 8L)
     val in = new DataInputStream(Channels.newInputStream(channel))
     try {
       val offset = in.readLong()
+      val endReduceId = reduceId + numReducers
+      channel.position(endReduceId * 8L)
       val nextOffset = in.readLong()
       val actualPosition = channel.position()
-      val expectedPosition = blockId.reduceId * 8L + 16
+      val expectedPosition = endReduceId * 8L + 8
       if (actualPosition != expectedPosition) {
         throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
           s"expected $expectedPosition but actual position was $actualPosition.")
       }
       new FileSegmentManagedBuffer(
         transportConf,
-        getDataFile(blockId.shuffleId, blockId.mapId),
+        getDataFile(shuffleId, mapId),
         offset,
         nextOffset - offset)
     } finally {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index d1ecbc1bf017..c50789658d61 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -18,7 +18,7 @@ package org.apache.spark.shuffle
 
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.BlockId
 
 private[spark]
 /**
@@ -34,7 +34,7 @@ trait ShuffleBlockResolver {
    * Retrieve the data for the specified block. If the data for that block is not available,
    * throws an unspecified exception.
    */
-  def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
+  def getBlockData(blockId: BlockId): ManagedBuffer
 
   def stop(): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 7ac2c71c18eb..399fa64d4293 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -38,7 +38,7 @@ sealed abstract class BlockId {
   // convenience methods
   def asRDDId: Option[RDDBlockId] = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
   def isRDD: Boolean = isInstanceOf[RDDBlockId]
-  def isShuffle: Boolean = isInstanceOf[ShuffleBlockId]
+  def isShuffle: Boolean = isInstanceOf[ShuffleBlockId] || isInstanceOf[ArrayShuffleBlockId]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
 
   override def toString: String = name
@@ -56,6 +56,17 @@ case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends Blo
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
+@DeveloperApi
+case class ArrayShuffleBlockId(blockIds: Seq[ShuffleBlockId]) extends BlockId {
+  override def name: String = {
+    if (blockIds.length == 1) {
+      blockIds.head.toString
+    } else {
+      "array_shuffle_" + blockIds.head + "-" + blockIds.last
+    }
+  }
+}
+
 @DeveloperApi
 case class ShuffleDataBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
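An ArrayShuffleBlockId deliberately reuses the plain shuffle block name when it wraps a single block, so a single-block batch keeps the familiar shuffle_… name on the wire. A small illustration of the naming, assuming the definitions above:

  // Sketch: naming behaviour of ArrayShuffleBlockId.
  val single = ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 1, 2)))
  assert(single.name == "shuffle_0_1_2")

  val batch = ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 1, 3)))
  assert(batch.name == "array_shuffle_shuffle_0_1_2-shuffle_0_1_3")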
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 8f993bfbf08a..a7fbfb45a557 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -25,7 +25,7 @@ import java.util.Collections
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.mutable
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.concurrent.{ExecutionContext, Future}
 import scala.concurrent.duration._
 import scala.reflect.ClassTag
@@ -377,7 +377,7 @@ private[spark] class BlockManager(
    */
   override def getBlockData(blockId: BlockId): ManagedBuffer = {
     if (blockId.isShuffle) {
-      shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+      shuffleManager.shuffleBlockResolver.getBlockData(blockId)
     } else {
       getLocalBytes(blockId) match {
         case Some(blockData) =>
@@ -637,7 +637,7 @@ private[spark] class BlockManager(
       // TODO: This should gracefully handle case where local block is not available. Currently
       // downstream code will throw an exception.
       val buf = new ChunkedByteBuffer(
-        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+        shuffleBlockResolver.getBlockData(blockId).nioByteBuffer())
       Some(new ByteBufferBlockData(buf, true))
     } else {
       blockInfoManager.lockForReading(blockId).map { info => doGetLocalBytes(blockId, info) }
@@ -1640,6 +1640,27 @@ private[spark] class BlockManager(
 private[spark] object BlockManager {
   private val ID_GENERATOR = new IdGenerator
 
+  def mergeContinuousShuffleBlockIds(iter: Iterator[BlockId]): Array[ArrayShuffleBlockId] = {
+    var shuffleBlockIds = new ArrayBuffer[ShuffleBlockId]
+    val arrayShuffleBlockIds = new ArrayBuffer[ArrayShuffleBlockId]
+
+    while (iter.hasNext) {
+      val blockId = iter.next().asInstanceOf[ShuffleBlockId]
+      if (shuffleBlockIds.isEmpty) {
+        shuffleBlockIds += blockId
+      } else {
+        if (blockId.mapId != shuffleBlockIds.head.mapId ||
+            blockId.reduceId != shuffleBlockIds.last.reduceId + 1) {
+          arrayShuffleBlockIds += ArrayShuffleBlockId(shuffleBlockIds)
+          shuffleBlockIds = new ArrayBuffer[ShuffleBlockId]
+        }
+        shuffleBlockIds += blockId
+      }
+    }
+    if (shuffleBlockIds.nonEmpty) {
+      arrayShuffleBlockIds += ArrayShuffleBlockId(shuffleBlockIds)
+    }
+
+    arrayShuffleBlockIds.toArray
+  }
+
   def blockIdsToLocations(
       blockIds: Array[BlockId],
       env: SparkEnv,
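mergeContinuousShuffleBlockIds groups a run of blocks that share a map output (and have consecutive reduce ids) into one batched id, starting a new group whenever the run breaks. A worked example of the expected grouping, assuming the merge conditions above:

  // Sketch: reduce partitions 0 and 1 of map 0 merge; map 4 starts a new group.
  val ids: Iterator[BlockId] = Iterator(
    ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 0, 1), ShuffleBlockId(0, 4, 0))
  val merged = BlockManager.mergeContinuousShuffleBlockIds(ids)
  // merged == Array(
  //   ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 0, 1))),
  //   ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 4, 0))))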
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 86f7c08eddcb..18c1454139a0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -60,6 +60,8 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
  * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
  * @param detectCorrupt whether to detect any corruption in fetched blocks.
  * @param shuffleMetrics used to report shuffle metrics.
+ * @param fetchContinuousShuffleBlocksInBatch fetch continuous shuffle blocks in batch if the
+ *                                            server side supports it.
  */
 private[spark]
 final class ShuffleBlockFetcherIterator(
@@ -73,7 +75,8 @@ final class ShuffleBlockFetcherIterator(
     maxBlocksInFlightPerAddress: Int,
     maxReqSizeShuffleToMem: Long,
     detectCorrupt: Boolean,
-    shuffleMetrics: ShuffleReadMetricsReporter)
+    shuffleMetrics: ShuffleReadMetricsReporter,
+    fetchContinuousShuffleBlocksInBatch: Boolean)
   extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {
 
   import ShuffleBlockFetcherIterator._
@@ -223,11 +226,17 @@ final class ShuffleBlockFetcherIterator(
     // so we can look up the size of each blockID
     val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
     val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
-    val blockIds = req.blocks.map(_._1.toString)
+    val blockIds = if (!req.retry) {
+      req.blocks.map(_._1.toString)
+    } else {
+      req.blocks.head._1.asInstanceOf[ArrayShuffleBlockId].blockIds.map(_.toString)
+    }
     val address = req.address
 
     val blockFetchingListener = new BlockFetchingListener {
-      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+      override def onBlockFetchSuccess(blockIds: Array[String], buf: ManagedBuffer): Unit = {
+        val blockId =
+          ArrayShuffleBlockId(blockIds.map(BlockId.apply(_).asInstanceOf[ShuffleBlockId]))
         // Only add the buffer to results queue if the iterator is not zombie,
         // i.e. cleanup() has not been called yet.
         ShuffleBlockFetcherIterator.this.synchronized {
@@ -235,8 +244,13 @@ final class ShuffleBlockFetcherIterator(
             // Increment the ref count because we need to pass this to a different thread.
             // This needs to be released after use.
             buf.retain()
-            remainingBlocks -= blockId
-            results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
+            remainingBlocks --= blockIds
+            val size = if (!req.retry) {
+              blockIds.map(sizeMap).sum
+            } else {
+              sizeMap(blockId.toString)
+            }
+            results.put(new SuccessFetchResult(blockId, address, size, buf,
               remainingBlocks.isEmpty))
             logDebug("remainingBlocks: " + remainingBlocks)
           }
@@ -244,9 +258,11 @@ final class ShuffleBlockFetcherIterator(
         logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
       }
 
-      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+      override def onBlockFetchFailure(blockIds: Array[String], e: Throwable): Unit = {
         logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
-        results.put(new FailureFetchResult(BlockId(blockId), address, e))
+        val blockId =
+          ArrayShuffleBlockId(blockIds.map(BlockId.apply(_).asInstanceOf[ShuffleBlockId]))
+        results.put(new FailureFetchResult(blockId, address, e))
       }
     }
@@ -255,10 +271,10 @@ final class ShuffleBlockFetcherIterator(
     // the data and write it to file directly.
     if (req.size > maxReqSizeShuffleToMem) {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, this)
+        blockFetchingListener, this, fetchContinuousShuffleBlocksInBatch)
     } else {
       shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
-        blockFetchingListener, null)
+        blockFetchingListener, null, fetchContinuousShuffleBlocksInBatch)
     }
   }
@@ -322,6 +338,25 @@ final class ShuffleBlockFetcherIterator(
     remoteRequests
   }
 
+  private[this] def fetchBlockData(arrayBlockId: ArrayShuffleBlockId): Boolean = {
+    var success = true
+    try {
+      val buf = blockManager.getBlockData(arrayBlockId)
+      shuffleMetrics.incLocalBlocksFetched(arrayBlockId.blockIds.length)
+      shuffleMetrics.incLocalBytesRead(buf.size)
+      buf.retain()
+      results.put(new SuccessFetchResult(arrayBlockId, blockManager.blockManagerId,
+        buf.size(), buf, false))
+    } catch {
+      case e: Exception =>
+        // If we see an exception, stop immediately.
+        logError(s"Error occurred while fetching local blocks", e)
+        results.put(new FailureFetchResult(arrayBlockId, blockManager.blockManagerId, e))
+        success = false
+    }
+    success
+  }
+
   /**
    * Fetch the local blocks while we are fetching remote blocks. This is ok because
    * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
@@ -330,21 +365,14 @@ final class ShuffleBlockFetcherIterator(
   private[this] def fetchLocalBlocks() {
     logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
-    while (iter.hasNext) {
-      val blockId = iter.next()
-      try {
-        val buf = blockManager.getBlockData(blockId)
-        shuffleMetrics.incLocalBlocksFetched(1)
-        shuffleMetrics.incLocalBytesRead(buf.size)
-        buf.retain()
-        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
-          buf.size(), buf, false))
-      } catch {
-        case e: Exception =>
-          // If we see an exception, stop immediately.
- logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) - return + if (fetchContinuousShuffleBlocksInBatch) { + BlockManager.mergeContinuousShuffleBlockIds(iter).foreach { arrayShuffleBlockId => + if (!fetchBlockData(arrayShuffleBlockId)) return + } + } else { + while (iter.hasNext) { + val blockId = iter.next().asInstanceOf[ShuffleBlockId] + if (!fetchBlockData(ArrayShuffleBlockId(Array(blockId)))) return } } } @@ -387,8 +415,6 @@ final class ShuffleBlockFetcherIterator( throw new NoSuchElementException } - numBlocksProcessed += 1 - var result: FetchResult = null var input: InputStream = null // Take the next fetched result and try to decompress it to detect data corruption, @@ -402,16 +428,16 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => + case _ @ SuccessFetchResult(arrayBlockId, address, size, buf, isNetworkReqDone) => + val numBlocksFetched = arrayBlockId.blockIds.length + numBlocksProcessed += numBlocksFetched if (address != blockManager.blockManagerId) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 shuffleMetrics.incRemoteBytesRead(buf.size) if (buf.isInstanceOf[FileSegmentManagedBuffer]) { shuffleMetrics.incRemoteBytesReadToDisk(buf.size) } - shuffleMetrics.incRemoteBlocksFetched(1) - } - if (!localBlocks.contains(blockId)) { + shuffleMetrics.incRemoteBlocksFetched(numBlocksFetched) bytesInFlight -= size } if (isNetworkReqDone) { @@ -433,9 +459,9 @@ final class ShuffleBlockFetcherIterator( // uses are shared by the UnsafeShuffleWriter (both writers use DiskBlockObjectWriter // which returns a zero-size from commitAndGet() in case no records were written // since the last call. 
- val msg = s"Received a zero-size buffer for block $blockId from $address " + + val msg = s"Received a zero-size buffer for block $arrayBlockId from $address " + s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)" - throwFetchFailedException(blockId, address, new IOException(msg)) + throwFetchFailedException(arrayBlockId, address, new IOException(msg)) } val in = try { @@ -446,11 +472,11 @@ final class ShuffleBlockFetcherIterator( assert(buf.isInstanceOf[FileSegmentManagedBuffer]) logError("Failed to create input stream from local block", e) buf.release() - throwFetchFailedException(blockId, address, e) + throwFetchFailedException(arrayBlockId, address, e) } var isStreamCopied: Boolean = false try { - input = streamWrapper(blockId, in) + input = streamWrapper(arrayBlockId.blockIds.head, in) // Only copy the stream if it's wrapped by compression or encryption, also the size of // block is small (the decompressed block is smaller than maxBytesInFlight) if (detectCorrupt && !input.eq(in) && size < maxBytesInFlight / 3) { @@ -465,15 +491,17 @@ final class ShuffleBlockFetcherIterator( } catch { case e: IOException => buf.release() - if (buf.isInstanceOf[FileSegmentManagedBuffer] - || corruptedBlocks.contains(blockId)) { - throwFetchFailedException(blockId, address, e) - } else { - logWarning(s"got an corrupted block $blockId from $address, fetch again", e) - corruptedBlocks += blockId - fetchRequests += FetchRequest(address, Array((blockId, size))) - result = null + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + throwFetchFailedException(arrayBlockId, address, e) + } + if (corruptedBlocks.contains(arrayBlockId)) { + throwFetchFailedException(arrayBlockId, address, e) } + logWarning(s"got an corrupted blocks from $address, fetch again", e) + corruptedBlocks += arrayBlockId + fetchRequests += FetchRequest(address, Array((arrayBlockId, size)), retry = true) + numBlocksToFetch += numBlocksFetched + result = null } finally { // TODO: release the buf here to free memory earlier if (isStreamCopied) { @@ -481,8 +509,9 @@ final class ShuffleBlockFetcherIterator( } } - case FailureFetchResult(blockId, address, e) => - throwFetchFailedException(blockId, address, e) + case FailureFetchResult(arrayBlockId, address, e) => + numBlocksProcessed += arrayBlockId.blockIds.length + throwFetchFailedException(arrayBlockId, address, e) } // Send fetch requests up to maxBytesInFlight @@ -490,7 +519,7 @@ final class ShuffleBlockFetcherIterator( } currentResult = result.asInstanceOf[SuccessFetchResult] - (currentResult.blockId, new BufferReleasingInputStream(input, this)) + (currentResult.arrayBlockId, new BufferReleasingInputStream(input, this)) } private def fetchUpToMaxBytes(): Unit = { @@ -552,6 +581,9 @@ final class ShuffleBlockFetcherIterator( blockId match { case ShuffleBlockId(shufId, mapId, reduceId) => throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e) + case ArrayShuffleBlockId(blockIds) => + val bid = blockIds.head + throw new FetchFailedException(address, bid.shuffleId, bid.mapId, bid.reduceId, e) case _ => throw new SparkException( "Failed to get block " + blockId + ", which is not a shuffle block", e) @@ -602,7 +634,10 @@ object ShuffleBlockFetcherIterator { * @param blocks Sequence of tuple, where the first element is the block id, * and the second element is the estimated size, used to calculate bytesInFlight. 
    */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
+  case class FetchRequest(
+      address: BlockManagerId,
+      blocks: Seq[(BlockId, Long)],
+      retry: Boolean = false) {
     val size = blocks.map(_._2).sum
   }
@@ -610,21 +645,21 @@ object ShuffleBlockFetcherIterator {
    * Result of a fetch from a remote block.
    */
   private[storage] sealed trait FetchResult {
-    val blockId: BlockId
+    val arrayBlockId: ArrayShuffleBlockId
     val address: BlockManagerId
   }
 
   /**
    * Result of a fetch from a remote block successfully.
-   * @param blockId block id
-   * @param address BlockManager that the block was fetched from.
-   * @param size estimated size of the block. Note that this is NOT the exact bytes.
-   *             Size of remote block is used to calculate bytesInFlight.
+   * @param arrayBlockId array shuffle block id
+   * @param address BlockManager that the block(s) was fetched from.
+   * @param size estimated size of the block(s). Note that this is NOT the exact bytes.
+   *             Size of remote block(s) is used to calculate bytesInFlight.
    * @param buf `ManagedBuffer` for the content.
    * @param isNetworkReqDone Is this the last network request for this host in this fetch request.
    */
   private[storage] case class SuccessFetchResult(
-      blockId: BlockId,
+      arrayBlockId: ArrayShuffleBlockId,
       address: BlockManagerId,
       size: Long,
       buf: ManagedBuffer,
@@ -640,7 +675,7 @@ object ShuffleBlockFetcherIterator {
    * @param e the failure exception
    */
   private[storage] case class FailureFetchResult(
-      blockId: BlockId,
+      arrayBlockId: ArrayShuffleBlockId,
       address: BlockManagerId,
       e: Throwable)
     extends FetchResult
diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
index f1cf14de1f87..9d8cfd5f9d34 100644
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
+++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala
@@ -159,14 +159,14 @@ class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar wi
       self.fetchBlocks(from.hostName, from.port, execId, Array(blockId.toString),
         new BlockFetchingListener {
-          override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+          override def onBlockFetchFailure(blockIds: Array[String], exception: Throwable): Unit = {
             promise.failure(exception)
           }
 
-          override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+          override def onBlockFetchSuccess(blockIds: Array[String], data: ManagedBuffer): Unit = {
             promise.success(data.retain())
           }
-        }, null)
+        }, null, false)
 
       ThreadUtils.awaitReady(promise.future, FiniteDuration(10, TimeUnit.SECONDS))
       promise.future.value.get
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 6d2ef17a7a79..f8c4834f6ee0 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark._
 import org.apache.spark.internal.config
 import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
 import org.apache.spark.serializer.{JavaSerializer, SerializerManager}
-import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+import org.apache.spark.storage.{ArrayShuffleBlockId, BlockManager, BlockManagerId, ShuffleBlockId}
 
 /**
  * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
  */
@@ -94,15 +94,16 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
       // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
       // fetch shuffle data.
-      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+      val arrayShuffleBlockId = ArrayShuffleBlockId(Seq(ShuffleBlockId(shuffleId, mapId, reduceId)))
+      when(blockManager.getBlockData(arrayShuffleBlockId)).thenReturn(managedBuffer)
       managedBuffer
     }
 
     // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
     // shuffle data to read.
     val mapOutputTracker = mock(classOf[MapOutputTracker])
-    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)).thenReturn {
+    when(mapOutputTracker.getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1))
+      .thenReturn {
       // Test a scenario where all data is local, to avoid creating a bunch of additional mocks
       // for the code to read data over the network.
       val shuffleBlockIdsAndSizes = (0 until numMaps).map { mapId =>
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index ff4755833a91..2e4ad2eb0ae3 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -51,7 +51,7 @@ class BlockIdSuite extends SparkFunSuite {
     assertSame(id, BlockId(id.toString))
   }
 
-  test("shuffle") {
+  test("shuffle block") {
     val id = ShuffleBlockId(1, 2, 3)
     assertSame(id, ShuffleBlockId(1, 2, 3))
     assertDifferent(id, ShuffleBlockId(3, 2, 3))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 04de0e41a341..ae95d07bec18 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -1432,8 +1432,10 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE
         execId: String,
         blockIds: Array[String],
         listener: BlockFetchingListener,
-        tempFileManager: DownloadFileManager): Unit = {
-      listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1)))
+        tempFileManager: DownloadFileManager,
+        fetchContinuousShuffleBlocksInBatch: Boolean): Unit = {
+      listener.onBlockFetchSuccess(
+        Array("mockBlockId"), new NioManagedBuffer(ByteBuffer.allocate(1)))
     }
 
     override def close(): Unit = {}
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 6b83243fe496..57f8dde2351e 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -49,7 +49,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
   /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
   private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
@@ -57,9 +57,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
         for (blockId <- blocks) {
           if (data.contains(BlockId(blockId))) {
-            listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+            listener.onBlockFetchSuccess(Array(blockId), data(BlockId(blockId)))
           } else {
-            listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
+            listener.onBlockFetchFailure(Array(blockId), new BlockNotFoundException(blockId))
           }
         }
       }
@@ -89,7 +89,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
-      doReturn(buf).when(blockManager).getBlockData(meq(blockId))
+      val arrayBlockId = ArrayShuffleBlockId(Seq(blockId.asInstanceOf[ShuffleBlockId]))
+      doReturn(buf).when(blockManager).getBlockData(meq(arrayBlockId))
     }
 
     // Make sure remote blocks would return
@@ -118,14 +119,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       true,
-      metrics)
+      metrics,
+      false)
 
     // 3 local blocks fetched in initialization
     verify(blockManager, times(3)).getBlockData(any())
 
     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
-      val (blockId, inputStream) = iterator.next()
+      val (arrayBlockId, inputStream) = iterator.next()
+      val blockId = arrayBlockId.asInstanceOf[ArrayShuffleBlockId].blockIds.head
 
       // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
@@ -146,7 +149,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // 3 local blocks, and 2 remote blocks
     // (but from the same block manager so one call to fetchBlocks)
     verify(blockManager, times(3)).getBlockData(any())
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any())
   }
 
   test("release current unexhausted buffer in case the task completes early") {
@@ -165,19 +168,19 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
           // Return the first two blocks, and wait till task completion before returning the 3rd one
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+            Array(ShuffleBlockId(0, 0, 0).toString), blocks(ShuffleBlockId(0, 0, 0)))
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+            Array(ShuffleBlockId(0, 1, 0).toString), blocks(ShuffleBlockId(0, 1, 0)))
           sem.acquire()
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+            Array(ShuffleBlockId(0, 2, 0).toString), blocks(ShuffleBlockId(0, 2, 0)))
         }
       }
     })
@@ -197,7 +200,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
+      taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false)
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
     iterator.next()._2.close() // close() first block's input stream
@@ -234,18 +238,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
           // Return the first block, and then fail.
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+            Array(ShuffleBlockId(0, 0, 0).toString), blocks(ShuffleBlockId(0, 0, 0)))
           listener.onBlockFetchFailure(
-            ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah"))
+            Array(ShuffleBlockId(0, 1, 0).toString), new BlockNotFoundException("blah"))
           listener.onBlockFetchFailure(
-            ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah"))
+            Array(ShuffleBlockId(0, 2, 0).toString), new BlockNotFoundException("blah"))
           sem.release()
         }
       }
@@ -266,7 +270,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
+      taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false)
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
@@ -305,18 +310,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
           // Return the first block, and then fail.
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+            Array(ShuffleBlockId(0, 0, 0).toString), blocks(ShuffleBlockId(0, 0, 0)))
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+            Array(ShuffleBlockId(0, 1, 0).toString), mockCorruptBuffer())
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+            Array(ShuffleBlockId(0, 2, 0).toString), corruptLocalBuffer)
           sem.release()
         }
       }
@@ -337,23 +342,25 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
+      taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false)
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
 
     // The first block should be returned without an exception
-    val (id1, _) = iterator.next()
+    val (arrayBlockId, _) = iterator.next()
+    val id1 = arrayBlockId.asInstanceOf[ArrayShuffleBlockId].blockIds.head
     assert(id1 === ShuffleBlockId(0, 0, 0))
 
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
           // Return the first block, and then fail.
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+            Array(ShuffleBlockId(0, 1, 0).toString), mockCorruptBuffer())
           sem.release()
         }
       }
@@ -372,7 +379,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
-    doReturn(corruptBuffer).when(blockManager).getBlockData(ShuffleBlockId(0, 0, 0))
+    doReturn(corruptBuffer).when(blockManager)
+      .getBlockData(ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 0, 0))))
     val localBlockLengths = Seq[Tuple2[BlockId, Long]](
       ShuffleBlockId(0, 0, 0) -> corruptBuffer.size()
     )
@@ -402,10 +410,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
+      taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false)
 
     // Blocks should be returned without exceptions.
-    assert(Set(iterator.next()._1, iterator.next()._1) ===
-      Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
+    val id1 = iterator.next()._1.asInstanceOf[ArrayShuffleBlockId].blockIds.head
+    val id2 = iterator.next()._1.asInstanceOf[ArrayShuffleBlockId].blockIds.head
+    assert(Set(id1, id2) === Set(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 1, 0)))
   }
 
   test("retry corrupt blocks (disabled)") {
@@ -425,18 +435,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
 
     val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
      .thenAnswer(new Answer[Unit] {
       override def answer(invocation: InvocationOnMock): Unit = {
         val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
         Future {
           // Return the first block, and then fail.
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+            Array(ShuffleBlockId(0, 0, 0).toString), blocks(ShuffleBlockId(0, 0, 0)))
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+            Array(ShuffleBlockId(0, 1, 0).toString), mockCorruptBuffer())
           listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
+            Array(ShuffleBlockId(0, 2, 0).toString), mockCorruptBuffer())
           sem.release()
         }
       }
@@ -457,17 +467,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       false,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
+      taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false)
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
 
     // The first block should be returned without an exception
-    val (id1, _) = iterator.next()
+    val id1 = iterator.next()._1.asInstanceOf[ArrayShuffleBlockId].blockIds.head
     assert(id1 === ShuffleBlockId(0, 0, 0))
-    val (id2, _) = iterator.next()
+    val id2 = iterator.next()._1.asInstanceOf[ArrayShuffleBlockId].blockIds.head
     assert(id2 === ShuffleBlockId(0, 1, 0))
-    val (id3, _) = iterator.next()
+    val id3 = iterator.next()._1.asInstanceOf[ArrayShuffleBlockId].blockIds.head
     assert(id3 === ShuffleBlockId(0, 2, 0))
   }
@@ -490,14 +501,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
 
     val transfer = mock(classOf[BlockTransferService])
     var tempFileManager: DownloadFileManager = null
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
       .thenAnswer(new Answer[Unit] {
         override def answer(invocation: InvocationOnMock): Unit = {
           val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
           tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
           Future {
             listener.onBlockFetchSuccess(
-              ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
+              Array(ShuffleBlockId(0, 0, 0).toString), remoteBlocks(ShuffleBlockId(0, 0, 0)))
           }
         }
       })
@@ -519,7 +530,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         maxBlocksInFlightPerAddress = Int.MaxValue,
         maxReqSizeShuffleToMem = 200,
         detectCorrupt = true,
-        taskContext.taskMetrics.createTempShuffleReadMetrics())
+        taskContext.taskMetrics.createTempShuffleReadMetrics(),
+        false)
     }
 
     val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
@@ -566,10 +578,96 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       Int.MaxValue,
       Int.MaxValue,
       true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics())
+      taskContext.taskMetrics.createTempShuffleReadMetrics(),
+      false)
 
     // All blocks fetched return zero length and should trigger a receive-side error:
     val e = intercept[FetchFailedException] { iterator.next() }
     assert(e.getMessage.contains("Received a zero-size buffer"))
   }
+
+  test("adaptive execution: successful 2 local blocks + 3 remote blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure blockManager.getBlockData would return the blocks
+    val localBlocks = Map[BlockId, ManagedBuffer](
+      ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 0, 1))) ->
+        createMockManagedBuffer(2))
+    localBlocks.foreach { case (blockId, buf) =>
+      doReturn(buf).when(blockManager).getBlockData(meq(blockId))
+    }
+
+    // Make sure remote blocks would return
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val remoteBlocks = Map[BlockId, ManagedBuffer](
+      ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 3, 0), ShuffleBlockId(0, 3, 1))) ->
+        createMockManagedBuffer(2),
+      ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 4, 0))) -> createMockManagedBuffer())
+    val transfer = mock(classOf[BlockTransferService])
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any(), any()))
+      .thenAnswer(new Answer[Unit] {
+        override def answer(invocation: InvocationOnMock): Unit = {
+          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+          listener.onBlockFetchSuccess(Array("shuffle_0_3_0", "shuffle_0_3_1"), remoteBlocks(
+            ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 3, 0), ShuffleBlockId(0, 3, 1)))))
+          listener.onBlockFetchSuccess(Array("shuffle_0_4_0"),
+            remoteBlocks(ArrayShuffleBlockId(Seq(ShuffleBlockId(0, 4, 0)))))
+        }
+      })
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+      (localBmId,
+        Seq(ShuffleBlockId(0, 0, 0), ShuffleBlockId(0, 0, 1)).map((_, 1.asInstanceOf[Long]))),
+      (remoteBmId, Seq(ShuffleBlockId(0, 3, 0), ShuffleBlockId(0, 3, 1),
+        ShuffleBlockId(0, 4, 0)).map((_, 1.asInstanceOf[Long])))
+    ).toIterator
+
+    val taskContext = TaskContext.empty()
+    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => in,
+      48 * 1024 * 1024,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true,
+      metrics,
+      true)
+
+    // The 2 local blocks are merged into 1 batch and fetched in initialization,
+    // hence a single getBlockData call.
+    verify(blockManager, times(1)).getBlockData(any())
+
+    for (i <- 0 until 3) {
+      assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements")
+      val (blockId, inputStream) = iterator.next()
+
+      // Make sure we release buffers when a wrapped input stream is closed.
+      val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+      // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
+      val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
+      verify(mockBuf, times(0)).release()
+      val delegateAccess = PrivateMethod[InputStream]('delegate)
+
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(0)).close()
+      wrappedInputStream.close()
+      verify(mockBuf, times(1)).release()
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+      wrappedInputStream.close() // close should be idempotent
+      verify(mockBuf, times(1)).release()
+      verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close()
+    }
+    assert(!iterator.hasNext, "iterator should have exactly 3 elements")
+
+    // 2 local blocks and 3 remote blocks, merged into 1 local batch and 2 remote batches
+    // (all remote blocks come from the same block manager, so one call to fetchBlocks)
+    verify(blockManager, times(1)).getBlockData(any())
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any(), any())
+  }
 }
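For reference, the batched read in IndexShuffleBlockResolver works because the index file stores numReducers + 1 cumulative offsets per map output, so a contiguous range of reduce partitions is a single byte range of the data file. A hedged sketch of the arithmetic (plain longs, no I/O; the offsets array is hypothetical example data):

  // Sketch: offsets(i) is the start of reduce partition i within the data file.
  val offsets = Array(0L, 100L, 250L, 400L)  // example index file contents
  def byteRange(startReduceId: Int, numReducers: Int): (Long, Long) = {
    val start = offsets(startReduceId)
    val end = offsets(startReduceId + numReducers)
    (start, end - start)  // (offset, length), as FileSegmentManagedBuffer expects
  }
  // Reduce partitions 1 and 2 of this map output form one 300-byte segment at offset 100.
  assert(byteRange(1, 2) == (100L, 300L))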