From df971dad356126e5c23fb1bff03a0ae0594fdaab Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Sat, 9 May 2020 16:08:41 -0700
Subject: [PATCH 01/27] Magnet shuffle service fetch block protocol

---
 .../server/TransportRequestHandler.java       | 56 +++++++++++--
 .../protocol/FetchMergedBlocksMeta.java       | 83 +++++++++++++++++++
 .../shuffle/protocol/MergedBlocksMeta.java    | 75 +++++++++++++++++
 3 files changed, 209 insertions(+), 5 deletions(-)
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java
 create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java

diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index ab2deac20fcd..8407f2f441af 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -32,6 +32,7 @@
 import org.apache.spark.network.buffer.NioManagedBuffer;
 import org.apache.spark.network.client.*;
 import org.apache.spark.network.protocol.*;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.TransportFrameDecoder;
 
 import static org.apache.spark.network.util.NettyUtils.getRemoteAddress;
@@ -185,6 +186,18 @@ public void onFailure(Throwable e) {
   private void processStreamUpload(final UploadStream req) {
     assert (req.body() == null);
     try {
+      // Retain the original metadata buffer, since it will be used during the invocation of
+      // this method. It will be released later.
+      req.meta.retain();
+      // Make a copy of the original metadata buffer. In benchmarks, we noticed that
+      // we cannot send the original metadata buffer back to the client; otherwise, when
+      // multiple concurrent shuffles are present, the wrong metadata might be sent back
+      // to the client. This is related to the eager release of the metadata buffer,
+      // i.e., we always release the original buffer by the time the invocation of this
+      // method ends, instead of when we respond to the client. The eager release is
+      // necessary; otherwise we start seeing memory issues very quickly in benchmarks.
+      // TODO check if the way the metadata buffer is handled can be further improved
+      ByteBuffer meta = cloneBuffer(req.meta.nioByteBuffer());
       RpcResponseCallback callback = new RpcResponseCallback() {
         @Override
         public void onSuccess(ByteBuffer response) {
@@ -193,13 +206,17 @@ public void onSuccess(ByteBuffer response) {
 
         @Override
         public void onFailure(Throwable e) {
-          respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+          // Piggyback the request metadata as part of the exception error String, so we can
+          // return the metadata upon a failure without changing the existing protocol.
+          respond(new RpcFailure(req.requestId,
+            JavaUtils.encodeHeaderIntoErrorString(meta.duplicate(), e)));
+          req.meta.release();
         }
       };
       TransportFrameDecoder frameDecoder = (TransportFrameDecoder)
        channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME);
-      ByteBuffer meta = req.meta.nioByteBuffer();
-      StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback);
+      StreamCallbackWithID streamHandler =
+        rpcHandler.receiveStream(reverseClient, meta.duplicate(), callback);
       if (streamHandler == null) {
         throw new NullPointerException("rpcHandler returned a null streamHandler");
       }
@@ -213,12 +230,17 @@ public void onData(String streamId, ByteBuffer buf) throws IOException {
         public void onComplete(String streamId) throws IOException {
           try {
             streamHandler.onComplete(streamId);
-            callback.onSuccess(ByteBuffer.allocate(0));
+            callback.onSuccess(meta.duplicate());
           } catch (Exception ex) {
             IOException ioExc = new IOException("Failure post-processing complete stream;" +
              " failing this rpc and leaving channel active", ex);
+            // req.meta will be released once inside callback.onFailure. Retain it one more
+            // time to be released in the finally block.
+            req.meta.retain();
             callback.onFailure(ioExc);
             streamHandler.onFailure(streamId, ioExc);
+          } finally {
+            req.meta.release();
           }
         }
 
@@ -242,12 +264,26 @@ public String getID() {
       }
     } catch (Exception e) {
       logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e);
-      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+      try {
+        // It's OK to send the original metadata buffer back here, because this is still inside
+        // the invocation of this method.
+        respond(new RpcFailure(req.requestId,
+          JavaUtils.encodeHeaderIntoErrorString(req.meta.nioByteBuffer(), e)));
+      } catch (IOException ioe) {
+        // No exception will be thrown here. req.meta.nioByteBuffer will not throw IOException
+        // because it's a NettyManagedBuffer. This try-catch block is to make the compiler happy.
+        logger.error("Error in handling failure while invoking RpcHandler#receive() on RPC id "
+          + req.requestId, e);
+      } finally {
+        req.meta.release();
+      }
       // We choose to totally fail the channel, rather than trying to recover as we do in other
       // cases. We don't know how many bytes of the stream the client has already sent for the
       // stream, it's not worth trying to recover.
       channel.pipeline().fireExceptionCaught(e);
     } finally {
+      // Make sure we always release the original metadata buffer by the time we exit the
+      // invocation of this method. Otherwise, we see memory issues fairly quickly in benchmarks.
       req.meta.release();
     }
   }
@@ -286,6 +322,16 @@ public void onFailure(Throwable e) {
     }
   }
 
+  /**
+   * Make a full copy of a nio ByteBuffer.
+   */
+  private ByteBuffer cloneBuffer(ByteBuffer buf) {
+    ByteBuffer clone = ByteBuffer.allocate(buf.capacity());
+    clone.put(buf.duplicate());
+    clone.flip();
+    return clone;
+  }
+
   /**
    * Responds to a single message with some Encodable object. If a failure occurs while sending,
    * it will be logged and the channel closed.
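[Editor's note on the piggybacking trick above: the exact wire format used by
JavaUtils.encodeHeaderIntoErrorString is not shown in this patch. As a rough sketch of the idea
only, with the class name, separator, and Base64 encoding all being assumptions rather than the
actual implementation, such a helper could prepend an encoded copy of the metadata header to the
stack trace, and the client could split it back out:]

// Hypothetical sketch of the piggybacking idea; NOT the actual JavaUtils code.
import java.nio.ByteBuffer;
import java.util.Base64;
import com.google.common.base.Throwables;

final class ErrorHeaderCodec {
  // Separator assumed here; any character that cannot occur in Base64 output works.
  private static final String SEPARATOR = "\n";

  // Server side: prepend a Base64 copy of the metadata header to the error string,
  // so an RpcFailure can carry the header back without a protocol change.
  static String encodeHeaderIntoErrorString(ByteBuffer header, Throwable error) {
    byte[] bytes = new byte[header.remaining()];
    header.duplicate().get(bytes);
    return Base64.getEncoder().encodeToString(bytes) + SEPARATOR
        + Throwables.getStackTraceAsString(error);
  }

  // Client side: split the header back out of the error string.
  static ByteBuffer decodeHeaderFromErrorString(String errorString) {
    int sep = errorString.indexOf(SEPARATOR);
    return ByteBuffer.wrap(Base64.getDecoder().decode(errorString.substring(0, sep)));
  }
}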
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java
new file mode 100644
index 000000000000..863aa1bfae77
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Request to find the meta information for the specified merged blocks. The meta information
+ * currently contains only the number of chunks in each merged block.
+ */
+public class FetchMergedBlocksMeta extends BlockTransferMessage {
+  public final String appId;
+  public final String[] blockIds;
+
+  public FetchMergedBlocksMeta(String appId, String[] blockIds) {
+    this.appId = appId;
+    this.blockIds = blockIds;
+  }
+
+  @Override
+  protected Type type() { return Type.FETCH_MERGED_BLOCKS_META; }
+
+  @Override
+  public int hashCode() {
+    return appId.hashCode() * 41 + Arrays.hashCode(blockIds);
+  }
+
+  @Override
+  public String toString() {
+    return Objects.toStringHelper(this)
+      .add("appId", appId)
+      .add("blockIds", Arrays.toString(blockIds))
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof FetchMergedBlocksMeta) {
+      FetchMergedBlocksMeta o = (FetchMergedBlocksMeta) other;
+      return Objects.equal(appId, o.appId)
+        && Arrays.equals(blockIds, o.blockIds);
+    }
+    return false;
+  }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.Strings.encodedLength(appId)
+      + Encoders.StringArrays.encodedLength(blockIds);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.Strings.encode(buf, appId);
+    Encoders.StringArrays.encode(buf, blockIds);
+  }
+
+  public static FetchMergedBlocksMeta decode(ByteBuf buf) {
+    String appId = Encoders.Strings.decode(buf);
+    String[] blockIds = Encoders.StringArrays.decode(buf);
+    return new FetchMergedBlocksMeta(appId, blockIds);
+  }
+}

diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java
new file mode 100644
index 000000000000..94c3e616491f
--- /dev/null
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+import javax.annotation.Nonnull;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+/**
+ * Response of {@link FetchMergedBlocksMeta}.
+ */
+public class MergedBlocksMeta extends BlockTransferMessage {
+
+  public final int[] numChunks;
+
+  public MergedBlocksMeta(@Nonnull int[] numChunks) {
+    this.numChunks = numChunks;
+  }
+
+  @Override
+  protected Type type() { return Type.MERGED_BLOCKS_META; }
+
+  @Override
+  public int hashCode() {
+    return Arrays.hashCode(numChunks);
+  }
+
+  @Override
+  public String toString() {
+    Objects.ToStringHelper helper = Objects.toStringHelper(this);
+    return helper.add("numChunks", Arrays.toString(numChunks)).toString();
+  }
+
+  @Override
+  public boolean equals(Object other) {
+    if (other instanceof MergedBlocksMeta) {
+      MergedBlocksMeta o = (MergedBlocksMeta) other;
+      return Arrays.equals(numChunks, o.numChunks);
+    }
+    return false;
+  }
+
+  @Override
+  public int encodedLength() {
+    return Encoders.IntArrays.encodedLength(numChunks);
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    Encoders.IntArrays.encode(buf, numChunks);
+  }
+
+  public static MergedBlocksMeta decode(ByteBuf buf) {
+    return new MergedBlocksMeta(Encoders.IntArrays.decode(buf));
+  }
+}

From 041ca7054180d55cbb297beb6107fe8bd3d09964 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Mon, 11 May 2020 13:37:07 -0700
Subject: [PATCH 02/27] LIHADOOP-53321 Magnet: Merge client shuffle block
 fetcher related changes to li-2.3.0

RB=2095445
BUG=LIHADOOP-53321
G=spark-reviewers
R=chsingh,mshen
A=chsingh,mshen
---
 .../spark/network/BlockDataManager.scala      |   2 +
 .../spark/serializer/SerializerManager.scala  |   1 +
 .../shuffle/BlockStoreShuffleReader.scala     |   1 +
 .../org/apache/spark/storage/BlockId.scala    |  15 +-
 .../storage/ShuffleBlockFetcherIterator.scala | 663 ++++++++++++++++--
 .../apache/spark/storage/BlockIdSuite.scala   |  13 +
 .../ShuffleBlockFetcherIteratorSuite.scala    | 555 ++++++++++++++-
 7 files changed, 1171 insertions(+), 79 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index cafb39ea82ad..4bf979e41b6c 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -71,4 +71,6 @@ trait BlockDataManager {
   * Release locks acquired by [[putBlockData()]] and [[getLocalBlockData()]].
   */
  def releaseLock(blockId: BlockId, taskContext: Option[TaskContext]): Unit
+
+  def getMergedBlockData(blockId: ShuffleBlockId): Seq[ManagedBuffer]
 }

diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 623db9d00ab5..640396a69526 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -110,6 +110,7 @@ private[spark] class SerializerManager(
   private def shouldCompress(blockId: BlockId): Boolean = {
     blockId match {
       case _: ShuffleBlockId => compressShuffle
+      case _: ShuffleBlockChunkId => compressShuffle
       case _: BroadcastBlockId => compressBroadcast
       case _: RDDBlockId => compressRdds
       case _: TempLocalBlockId => compressShuffleSpill

diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 6782c748aff7..7f90eabf34bd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -71,6 +71,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       context,
       blockManager.blockStoreClient,
       blockManager,
+      mapOutputTracker,
       blocksByAddress,
       serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility

diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 47c1b9664103..dc70a9af7e9c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -43,6 +43,7 @@ sealed abstract class BlockId {
     (isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId] ||
      isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId])
   }
+  def isShuffleChunk: Boolean = isInstanceOf[ShuffleBlockChunkId]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
 
   override def toString: String = name
@@ -72,6 +73,15 @@ case class ShuffleBlockBatchId(
   }
 }
 
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleBlockChunkId(
+    shuffleId: Int,
+    reduceId: Int,
+    chunkId: Int) extends BlockId {
+  override def name: String = "shuffleChunk_" + shuffleId + "_" + reduceId + "_" + chunkId
+}
+
 @DeveloperApi
 case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data"
@@ -152,7 +162,7 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r
   val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
@@ -160,6 +170,7 @@ object BlockId {
   val SHUFFLE_MERGED_DATA = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_MERGED_INDEX = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r
   val SHUFFLE_MERGED_META = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r
+  val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -186,6 +197,8 @@ object BlockId {
       ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt)
     case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) =>
       ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt)
+    case SHUFFLE_CHUNK(shuffleId, reduceId, chunkId) =>
+      ShuffleBlockChunkId(shuffleId.toInt, reduceId.toInt, chunkId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 4465d76e3127..0bc0b11400b5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -24,13 +24,14 @@ import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 import scala.util.{Failure, Success}
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.commons.io.IOUtils
+import org.roaringbitmap.RoaringBitmap
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
@@ -57,6 +58,8 @@
  * block, which indicate the index in the map stage.
  * Note that zero-sized blocks are already excluded, which happened in
  * [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching unmerged blocks if
+ *                         we fail to fetch merged block chunks when push-based shuffle is enabled.
 * @param streamWrapper A function to wrap the returned input stream.
 * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
 * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
@@ -75,6 +78,7 @@ final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: BlockStoreClient,
     blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker,
     blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
@@ -108,13 +112,6 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val startTimeNs = System.nanoTime()
 
-  /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
-
-  /** Host local blockIds to fetch by executors, excluding zero-sized blocks. */
-  private[this] val hostLocalBlocksByExecutor =
-    LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
-
   /** Host local blocks to fetch, excluding zero-sized blocks. */
   private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
@@ -179,6 +176,13 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val onCompleteCallback = new ShuffleFetchCompletionListener(this)
 
+  /** A map for storing the merged block shuffle chunk bitmaps */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  private[this] val localShuffleMergerBlockMgrId = BlockManagerId(
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
   initialize()
 
   // Decrements the buffer reference count.
@@ -248,6 +252,10 @@ final class ShuffleBlockFetcherIterator(
   private[this] def sendRequest(req: FetchRequest): Unit = {
     logDebug("Sending request for %d blocks (%s) from %s".format(
       req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
+    if (req.hasMergedBlocks) {
+      sendFetchMergedStatusRequest(req)
+      return
+    }
     bytesInFlight += req.size
     reqsInFlight += 1
 
@@ -329,7 +337,14 @@ final class ShuffleBlockFetcherIterator(
             }
 
           case _ =>
-            results.put(FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
+            val block = BlockId(blockId)
+            if (block.isShuffleChunk) {
+              remainingBlocks -= blockId
+              results.put(
+                IgnoreFetchResult(block, address, infoMap(blockId)._1, remainingBlocks.isEmpty))
+            } else {
+              results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
+            }
         }
       }
     }
@@ -347,19 +362,173 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  private[this] def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          results.put(MergedBlocksMetaFetchResult(shuffleId, reduceId, sizeMap(shuffleId, reduceId),
+            meta.getNumChunks, meta.readChunkBitmaps(), address))
+        } catch {
+          case _: Throwable =>
+            results.put(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        results.put(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach(block => {
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    })
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that failed
+   * to fetch.
+   * It calls out to the map output tracker to get the list of original blocks for the
+   * given merged blocks, splits them into remote and local blocks, and processes them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating a shuffle block chunk from a local merged shuffle
+   *    block. See fetchLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing a SuccessFetchResult that is for a shuffle chunk
+   *    (local or remote).
+   */
+  private[this] def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Unit = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed, since we will process another block in the next iteration of
+    // the while loop.
+    numBlocksProcessed += 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        if (isRemote(address)) {
+          // Fall back for all the pending fetch requests as well
+          val pendingShuffleChunks = removePendingChunks(shuffleChunkId, address)
+          if (pendingShuffleChunks.nonEmpty) {
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logWarning(s"Falling back immediately for merged block $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).orNull
+              assert(bitmapOfPendingChunk != null)
+              chunkBitmap.or(bitmapOfPendingChunk)
+            }
+            // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed
+            numBlocksProcessed += pendingShuffleChunks.size
+          }
+        }
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
+      }
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged " +
+      s"block $blockId")
+    // If any of the fallback blocks are local blocks, we fetch them here. The original
+    // invocation of fetchLocalBlocks might have already returned by this time, so we need
+    // to invoke it again here.
+    fetchLocalBlocks(fallbackLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(fallbackMergedLocalBlocks.isEmpty, "There should be zero merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+  }
+
+  private[this] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleBlockChunk(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll(req => {
+        val firstBlock = req.blocks.head
+        firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleBlockChunk(firstBlock.blockId)
+      })
+      fetchRequestsToRemove.foreach(req => {
+        removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])
+      })
+    }
+
+    filterRequests(fetchRequests)
+    val defRequests = deferredFetchRequests.remove(address).orNull
+    if (defRequests != null) {
+      filterRequests(defRequests)
+      if (defRequests.nonEmpty) {
+        deferredFetchRequests(address) = defRequests
+      }
+    }
+    removedChunkIds
+  }
+
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " +
       s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, merged-local, and remote (includes merged-remote) blocks.
+    // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order
+    // to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]()
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var mergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      if (isMergedShuffleBlockAddress(address)) {
+        // These are push-based merged blocks or chunks of these merged blocks.
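+        // Sketch of the flow: if the merger is on this host, the merged data is read directly
+        // from the local merged shuffle files; otherwise a merged-blocks meta request is sent
+        // first to learn the chunk bitmaps before the chunks are fetched remotely.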
+        if (address.host == blockManager.blockManagerId.host) {
+          checkBlockSizes(blockInfos)
+          val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+            blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false)
+          numBlocksToFetch += pushMergedBlockInfos.size
+          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
+          mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum
+          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
+            s"of size $mergedLocalBlockBytes")
+        } else {
+          // Remote merged blocks; their bytes are accounted for through the collected
+          // remote requests below.
+          collectFetchReqsFromMergedBlocks(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (
+        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
         checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
@@ -375,7 +544,7 @@ final class ShuffleBlockFetcherIterator(
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3))
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         val (_, timeCost) = Utils.timeTakenMs[Unit] {
@@ -386,40 +555,54 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-      s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-      s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      mergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal " +
+      s"the number of local blocks ${localBlocks.size} + " +
+      s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} + " +
+      s"the number of merged-local blocks ${mergedLocalBlocks.size} + " +
+      s"the number of remote blocks ${numRemoteBlocks}.")
+    logInfo(s"[${context.taskAttemptId()}] Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " +
+      s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    if (hostLocalBlocksCurrentIteration.nonEmpty) {
+      this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
+    }
     collectedRemoteRequests
   }
 
   private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
-      address: BlockManagerId): FetchRequest = {
+      address: BlockManagerId,
+      areMergedBlocks: Boolean = false): FetchRequest = {
     logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address "
       + s"with ${blocks.size} blocks")
-    FetchRequest(address, blocks)
+    FetchRequest(address, blocks, areMergedBlocks)
   }
 
   private def createFetchRequests(
       curBlocks: Seq[FetchBlockInfo],
       address: BlockManagerId,
       isLast: Boolean,
-      collectedRemoteRequests: ArrayBuffer[FetchRequest]): ArrayBuffer[FetchBlockInfo] = {
-    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch)
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      areMergedBlocks: Boolean = false): ArrayBuffer[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch)
     numBlocksToFetch += mergedBlocks.size
     val retBlocks = new ArrayBuffer[FetchBlockInfo]
     if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
-      collectedRemoteRequests += createFetchRequest(mergedBlocks, address)
+      collectedRemoteRequests += createFetchRequest(mergedBlocks, address, areMergedBlocks)
     } else {
       mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
         if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
-          collectedRemoteRequests += createFetchRequest(blocks, address)
+          collectedRemoteRequests += createFetchRequest(blocks, address, areMergedBlocks)
         } else {
           // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
           // to `curBlocks`.
@@ -431,6 +614,67 @@ final class ShuffleBlockFetcherIterator(
     retBlocks
   }
 
+  /**
+   * Collect the FetchRequests for push-based merged blocks from remote shuffle services.
+   * This method is called either to initialize the FetchRequests with the original
+   * blocks, or during the fallback when the fetch of merged shuffle blocks/chunks fails.
+   *
+   * @param address remote shuffle service address
+   * @param blockInfos shuffle block information
+   * @param collectedRemoteRequests queue of FetchRequests
+   */
+  private[this] def collectFetchReqsFromMergedBlocks(
+      address: BlockManagerId,
+      blockInfos: Seq[(BlockId, Long, Int)],
+      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
+    val iterator = blockInfos.iterator
+    var curRequestSize = 0L
+    var curBlocks = Seq.empty[FetchBlockInfo]
+    var mergedRequestsSize = 0L
+    var mergedBlocks = Seq.empty[FetchBlockInfo]
+    while (iterator.hasNext) {
+      val (blockId, size, _) = iterator.next()
+      assertPositiveBlockSize(blockId, size)
+      blockId match {
+        case ShuffleBlockId(_, mapId, _) =>
+          assert(mapId == -1)
+          mergedBlocks = mergedBlocks ++ Seq(FetchBlockInfo(blockId, size, -1))
+          mergedRequestsSize += size
+          if (mergedRequestsSize >= targetRemoteRequestSize ||
+            mergedBlocks.size >= maxBlocksInFlightPerAddress) {
+            mergedBlocks = createFetchRequests(mergedBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, areMergedBlocks = true)
+            mergedRequestsSize = mergedBlocks.map(_.size).sum
+          }
+        case ShuffleBlockChunkId(_, _, _) =>
+          curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, -1))
+          curRequestSize += size
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case _ =>
+          throw new SparkException(
+            "Failed to match block " + blockId + ", which is not a merged shuffle block or chunk"
+          )
+      }
+    }
+    // Add in the final requests
+    if (curBlocks.nonEmpty) {
+      curBlocks = createFetchRequests(curBlocks, address, isLast = true,
+        collectedRemoteRequests, enableBatchFetch = false)
+      curRequestSize = curBlocks.map(_.size).sum
+    }
+    if (mergedBlocks.nonEmpty) {
+      mergedBlocks = createFetchRequests(mergedBlocks, address, isLast = true,
+        collectedRemoteRequests, enableBatchFetch = false, areMergedBlocks = true)
+      mergedRequestsSize = mergedBlocks.map(_.size).sum
+    }
+  }
+
   private def collectFetchRequests(
       address: BlockManagerId,
       blockInfos: Seq[(BlockId, Long, Int)],
@@ -448,13 +692,14 @@ final class ShuffleBlockFetcherIterator(
       val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
       if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
         curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
-          collectedRemoteRequests)
+          collectedRemoteRequests, doBatchFetch)
         curRequestSize = curBlocks.map(_.size).sum
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
-      createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests)
+      createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests,
+        doBatchFetch)
     }
   }
 
@@ -475,7 +720,8 @@ final class ShuffleBlockFetcherIterator(
    * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
    * track in-memory are the ManagedBuffer references themselves.
   */
-  private[this] def fetchLocalBlocks(): Unit = {
+  private[this] def fetchLocalBlocks(
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = {
     logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
     while (iter.hasNext) {
@@ -529,7 +775,10 @@ final class ShuffleBlockFetcherIterator(
    * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we
    * track in-memory are the ManagedBuffer references themselves.
    */
-  private[this] def fetchHostLocalBlocks(hostLocalDirManager: HostLocalDirManager): Unit = {
+  private[this] def fetchHostLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
+    Unit = {
     val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs
     val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = {
       val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) =>
@@ -599,12 +848,100 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
+  /**
+   * Fetch a single local merged block that was generated by push-based shuffle.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs local directories where the merged shuffle files are stored
+   * @param blockManagerId BlockManagerId of the local shuffle merger
+   * @return whether the fetch was successful
+   */
+  private[this] def fetchMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs).readChunkBitmaps()
+      // Fetch local merged shuffle block data as multiple chunks
+      val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs)
+      // Update total number of blocks to fetch, reflecting the multiple local chunks
+      numBlocksToFetch += bufs.size - 1
+      for (chunkId <- bufs.indices) {
+        val buf = bufs(chunkId)
+        buf.retain()
+        val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, shuffleBlockId.reduceId,
+          chunkId)
+        results.put(SuccessFetchResult(shuffleChunkId, -1, blockManagerId, buf.size(), buf,
+          isNetworkReqDone = false))
+        chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId))
+      }
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception while reading a local merged block, we fall back to
+        // fetching the original unmerged blocks. We do not report a block fetch failure
+        // and will continue reading the remaining local blocks.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"prepare to fetch the original blocks", e)
+        results.put(IgnoreFetchResult(blockId, blockManagerId, 0, false))
+        false
+    }
+  }
+
+  /**
+   * Fetch the merged local blocks while we are fetching remote blocks.
+   */
+  private[this] def fetchMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get(
+      BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedMergerDirs.isDefined) {
+      logDebug(s"Fetching local merged blocks with cached executors dir: " +
+        s"${cachedMergerDirs.get.mkString(", ")}")
+      mergedLocalBlocks.foreach(blockId =>
+        fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId))
+    } else {
+      logDebug(s"Asynchronously fetching local merged blocks without cached executors dir")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          mergedLocalBlocks.takeWhile {
+            case blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchMergedLocalBlock(blockId, dirs(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+          logDebug(s"Got local merged blocks (without cached executors' dir) in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+        case Failure(throwable) =>
+          // If we see an exception while getting the local dirs for the local merged blocks,
+          // we fall back to fetching the original unmerged blocks. We do not report a block
+          // fetch failure.
+          logWarning(s"Error occurred while getting the local dirs for local merged " +
+            s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          mergedLocalBlocks.foreach(
+            blockId => results.put(IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0,
+              isNetworkReqDone = false))
+          )
+      }
+    }
+  }
+
   private[this] def initialize(): Unit = {
     // Add a task completion callback (called in both success case and failure case) to cleanup.
     context.addTaskCompletionListener(onCompleteCallback)
-
-    // Partition blocks by the different fetch modes: local, host-local and remote blocks.
-    val remoteRequests = partitionBlocksByFetchMode()
+    // Local blocks to fetch, excluding zero-sized blocks.
+    val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val hostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val mergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    // Partition blocks by the different fetch modes: local, host-local, merged-local and remote
+    // blocks.
+    val remoteRequests = partitionBlocksByFetchMode(
+      blocksByAddress, localBlocks, hostLocalBlocksByExecutor, mergedLocalBlocks)
     // Add the remote requests into our queue in a random order
     fetchRequests ++= Utils.randomize(remoteRequests)
     assert ((0 == reqsInFlight) == (0 == bytesInFlight),
@@ -620,11 +957,25 @@ final class ShuffleBlockFetcherIterator(
       (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" else ""))
 
     // Get Local Blocks
-    fetchLocalBlocks()
+    fetchLocalBlocks(localBlocks)
     logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
+    // Get host local blocks if any
+    fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
+    fetchAllMergedLocalBlocks(mergedLocalBlocks)
+  }
+
+  private def fetchAllHostLocalBlocks(
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
+    Unit = {
+    if (hostLocalBlocksByExecutor.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor))
+    }
+  }
 
-    if (hostLocalBlocks.nonEmpty) {
-      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks)
+  // Fetch all outstanding merged local blocks
+  private def fetchAllMergedLocalBlocks(mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (mergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks))
     }
   }
 
@@ -661,16 +1012,29 @@ final class ShuffleBlockFetcherIterator(
       result match {
        case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
          if (address != blockManager.blockManagerId) {
-           if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+           if (isMergedShuffleBlockAddress(address)
+             && address.host == blockManager.blockManagerId.host) {
+             // It is a local merged block chunk
+             assert(blockId.isShuffleChunk)
+             shuffleMetrics.incLocalBlocksFetched(
+               getNumberOfBlocksInChunk(blockId.asInstanceOf[ShuffleBlockChunkId]))
+             shuffleMetrics.incLocalBytesRead(buf.size)
+           } else if (hostLocalBlocks.contains(blockId -> mapIndex)) {
              shuffleMetrics.incLocalBlocksFetched(1)
              shuffleMetrics.incLocalBytesRead(buf.size)
            } else {
+             // Could be a remote merged block chunk or a remote block
              numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
             shuffleMetrics.incRemoteBytesRead(buf.size)
             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
               shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
             }
-            shuffleMetrics.incRemoteBlocksFetched(1)
+            if (blockId.isShuffleChunk) {
+              shuffleMetrics.incRemoteBlocksFetched(
+                getNumberOfBlocksInChunk(blockId.asInstanceOf[ShuffleBlockChunkId]))
+            } else {
+              shuffleMetrics.incRemoteBlocksFetched(1)
+            }
             bytesInFlight -= size
           }
         }
@@ -712,38 +1076,64 @@ final class ShuffleBlockFetcherIterator(
            case e: IOException => logError("Failed to create input stream from local block", e)
          }
          buf.release()
-         throwFetchFailedException(blockId, mapIndex, address, e)
-        }
-        try {
-          input = streamWrapper(blockId, in)
-          // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-          // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-          // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-          // the corruption is later, we'll still detect the corruption later in the stream.
-          streamCompressedOrEncrypted = !input.eq(in)
-          if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-            // TODO: manage the memory used here, and spill it into disk in case of OOM.
-            input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-          }
-        } catch {
-          case e: IOException =>
-            buf.release()
-            if (buf.isInstanceOf[FileSegmentManagedBuffer]
-              || corruptedBlocks.contains(blockId)) {
-              throwFetchFailedException(blockId, mapIndex, address, e)
-            } else {
-              logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-              corruptedBlocks += blockId
-              fetchRequests += FetchRequest(
-                address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+         if (blockId.isShuffleChunk) {
+           initiateFallbackBlockFetchForMergedBlock(blockId, address)
+           // Set result to null to trigger another iteration of the while loop to get either
+           // a SuccessFetchResult or a FailureFetchResult.
              result = null
+           null
+         } else {
+           throwFetchFailedException(blockId, mapIndex, address, e)
+         }
+        }
+        if (in != null) {
+          try {
+            input = streamWrapper(blockId, in)
+            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+            // the corruption is later, we'll still detect the corruption later in the stream.
+            streamCompressedOrEncrypted = !input.eq(in)
+            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+            }
+          } catch {
+            case e: IOException =>
+              buf.release()
+              if (blockId.isShuffleChunk) {
+                // Retrying a corrupt block may result again in a corrupt block. For merged
+                // block chunks, we opt to fall back to the original shuffle blocks
+                // that belong to that corrupt merged block chunk immediately instead of
+                // retrying to fetch the corrupt chunk. This also makes the code simpler
+                // because the chunkMeta corresponding to a shuffle chunk is always removed
+                // from chunksMetaMap whenever a shuffle block chunk gets processed.
+                // If we try to re-fetch a corrupt shuffle chunk, then it has to be added
+                // back to the chunksMetaMap.
+                initiateFallbackBlockFetchForMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop.
+                result = null
+              } else {
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                  || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                } else {
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(
+                    address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+                  result = null
+                }
+              }
+          } finally {
+            if (blockId.isShuffleChunk) {
+              chunksMetaMap.remove(blockId.asInstanceOf[ShuffleBlockChunkId])
+            }
+            // TODO: release the buf here to free memory earlier
+            if (input == null) {
+              // Close the underlying stream if there was an issue in wrapping the stream using
+              // streamWrapper
+              in.close()
+            }
-        } finally {
-          // TODO: release the buf here to free memory earlier
-          if (input == null) {
-            // Close the underlying stream if there was an issue in wrapping the stream using
-            // streamWrapper
-            in.close()
           }
         }
 
@@ -767,6 +1157,48 @@ final class ShuffleBlockFetcherIterator(
           deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+          if (isRemote(address)) {
+            numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          initiateFallbackBlockFetchForMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
+          address, _) =>
+          val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] =
+            new ArrayBuffer[(BlockId, Long, Int)]()
+          // Remove the original block from numBlocksToFetch. When collectFetchReqsFromMergedBlocks
+          // is called below to add the new chunks, numBlocksToFetch is updated to account for the
+          // chunk requests.
+          numBlocksToFetch -= 1
+          val approxChunkSize = blockSize / numChunks
+          for (i <- 0 until numChunks) {
+            val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+            chunksMetaMap.put(blockChunkId, bitmaps(i))
+            logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+            blocksToRequest += ((blockChunkId, approxChunkSize, -1))
+          }
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchReqsFromMergedBlocks(address, blocksToRequest.toSeq, additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null
+
+        case MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address, _) =>
+          // If we fail to fetch the merged status of a merged block, we fall back to fetching the
+          // unmerged blocks.
+          initiateFallbackBlockFetchForMergedBlock(ShuffleBlockId(shuffleId, -1, reduceId), address)
+          // Set result to null to force another iteration.
+          result = null
       }
 
      // Send fetch requests up to maxBytesInFlight
@@ -790,6 +1222,26 @@ final class ShuffleBlockFetcherIterator(
       onCompleteCallback.onComplete(context))
   }
 
+  private def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId)
+  }
+
+  /**
+   * Returns true if the address is that of a remote block manager or a host-local block manager;
+   * false otherwise.
+   */
+  private def isRemote(address: BlockManagerId): Boolean = {
+    // If the executor id is the shuffle merger identifier, then this is a merged block address,
+    // so we compare hosts to check whether it's remote. Otherwise, compare the full
+    // BlockManagerId.
+    (BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) &&
+      address.host != blockManager.blockManagerId.host) ||
+      (!BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) &&
+        address != blockManager.blockManagerId)
+  }
+
+  private def getNumberOfBlocksInChunk(blockId: ShuffleBlockChunkId): Int = {
+    chunksMetaMap(blockId).getCardinality
+  }
+
   private def fetchUpToMaxBytes(): Unit = {
     if (isNettyOOMOnShuffle.get()) {
       if (reqsInFlight > 0) {
@@ -835,8 +1287,11 @@ final class ShuffleBlockFetcherIterator(
 
     def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
       sendRequest(request)
-      numBlocksInFlightPerAddress(remoteAddress) =
-        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
+      if (!request.hasMergedBlocks) {
+        // Merged-block requests only fetch chunk metadata, so they are not counted
+        // against the number of blocks in flight.
+        numBlocksInFlightPerAddress(remoteAddress) =
+          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
+      }
     }
 
     def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
@@ -978,6 +1433,12 @@ object ShuffleBlockFetcherIterator {
     }
   }
 
+  /**
+   * Dummy shuffle block id to fill into [[MergedBlocksMetaFetchResult]] and
+   * [[MergedBlocksMetaFailedFetchResult]], to match the [[FetchResult]] trait.
+   */
+  private val DUMMY_SHUFFLE_BLOCK_ID = ShuffleBlockId(-1, -1, -1)
+
   /**
    * This function is used to merged blocks when doBatchFetch is true. Blocks which have the
    * same `mapId` can be merged into one block batch. The block batch is specified by a range
@@ -1074,8 +1535,13 @@ object ShuffleBlockFetcherIterator {
   * A request to fetch blocks from a remote BlockManager.
   * @param address remote BlockManager to fetch from.
   * @param blocks Sequence of the information for blocks to fetch from the same address.
+   * @param hasMergedBlocks true if this request contains merged blocks; false if it contains
+   *                        regular blocks or shuffle block chunks.
   */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) {
+  case class FetchRequest(
+      address: BlockManagerId,
+      blocks: Seq[FetchBlockInfo],
+      hasMergedBlocks: Boolean = false) {
     val size = blocks.map(_.size).sum
   }
 
@@ -1124,4 +1590,51 @@ object ShuffleBlockFetcherIterator {
   */
  private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest)
    extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of a remote merged block.
+   * Instead of treating this as a FailureFetchResult, we ignore this failure
+   * and fall back to fetching the original unmerged blocks.
+   * @param blockId block id
+   * @param address BlockManager that the merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param blockSize size of the merged block.
+   * @param numChunks number of chunks in the merged block.
+   * @param bitmaps bitmaps for every chunk.
+   * @param address BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param address BlockManager that the merged status was fetched from.
+   */
+  private[storage] case class MergedBlocksMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId,
+      blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult
 }

diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index b3138d7f126d..fc48dd85b22c 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -210,4 +210,17 @@ class BlockIdSuite extends SparkFunSuite {
     assert(!id.isShuffle)
     assertSame(id, BlockId(id.toString))
   }
+
+  test("shuffle chunk") {
+    val id = ShuffleBlockChunkId(1, 1, 0)
+    assertSame(id, ShuffleBlockChunkId(1, 1, 0))
+    assertDifferent(id, ShuffleBlockChunkId(1, 1, 1))
+    assert(id.name === "shuffleChunk_1_1_0")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.reduceId === 1)
+    assert(id.chunkId === 0)
+    assertSame(id, BlockId(id.toString))
+  }
+
 }

diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 9b633479cb3a..64d5bc354e49 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -22,20 +22,23 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.log4j.Level
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
 import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
@@ -45,9 +48,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
 
   private var transfer: BlockTransferService = _
BlockTransferService = _ + private var mapOutputTracker: MapOutputTracker = _ override def beforeEach(): Unit = { transfer = mock(classOf[BlockTransferService]) + mapOutputTracker = mock(classOf[MapOutputTracker]) + when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any())) + .thenReturn(Seq.empty.iterator) } private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) @@ -178,6 +185,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT tContext, transfer, blockManager.getOrElse(createMockBlockManager()), + mapOutputTracker, blocksByAddress.toIterator, (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in), maxBytesInFlight, @@ -1017,4 +1025,545 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } assert(e.getMessage.contains("fetch failed after 10 retries due to Netty OOM")) } + + /** + * Prepares the transfer to trigger success for all the blocks present in blockChunks. It will + * trigger failure of block which is not part of blockChunks. + */ + private def configureMockTransferForPushShuffle( + blocksSem: Semaphore, + blockChunks: Map[BlockId, ManagedBuffer]): Unit = { + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val regularBlocks = invocation.getArguments()(3).asInstanceOf[Array[String]] + val blockFetchListener = + invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + Future { + regularBlocks.foreach(blockId => { + val shuffleBlock = BlockId(blockId) + if (!blockChunks.contains(shuffleBlock)) { + // force failure + blockFetchListener.onBlockFetchFailure( + blockId, new RuntimeException("failed to fetch")) + } else { + blockFetchListener.onBlockFetchSuccess(blockId, blockChunks(shuffleBlock)) + } + blocksSem.release() + }) + } + }) + } + + test("fetch merged blocks meta") { + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1)), + (BlockManagerId("remote-client-1", "remote-host-1", 1), + toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)) + ) + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(), + ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), + ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer() + ) + val blocksSem = new Semaphore(0) + configureMockTransferForPushShuffle(blocksSem, blockChunks) + + val metaSem = new Semaphore(0) + val mergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(mergedBlockMeta.getNumChunks).thenReturn(2) + when(mergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) + val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap) + when(mergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + Future { + val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] + val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " + + s"port = ${invocation.getArguments()(1)}, " + + s"shuffleId = $shuffleId, reduceId = $reduceId") + 
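          // The mocked meta response parks on the semaphore below until the test thread
+          // calls metaSem.release(), so the merged-block chunks only become fetchable
+          // after the regular shuffle blocks have been consumed.
+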
metaSem.acquire() + metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + } + }) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress) + blocksSem.acquire(2) + // The first block should be returned without an exception + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 2)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 3, 2)) + metaSem.release() + val (id3, _) = iterator.next() + blocksSem.acquire() + assert(id3 === ShuffleBlockChunkId(0, 2, 0)) + val (id4, _) = iterator.next() + blocksSem.acquire() + assert(id4 === ShuffleBlockChunkId(0, 2, 1)) + assert(!iterator.hasNext) + } + + test("failed to fetch merged blocks meta") { + val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1) + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1)), + (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))) + + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer() + ) + when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( + Seq((remoteBmId, toBlockList( + Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) + val blocksSem = new Semaphore(0) + configureMockTransferForPushShuffle(blocksSem, blockChunks) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] + val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + Future { + metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error")) + } + }) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress) + blocksSem.acquire(2) + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 2)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 3, 2)) + val (id3, _) = iterator.next() + blocksSem.acquire(2) + assert(id3 === ShuffleBlockId(0, 1, 2)) + val (id4, _) = iterator.next() + assert(id4 === ShuffleBlockId(0, 2, 2)) + assert(!iterator.hasNext) + } + + private def createMockMergedBlockMeta( + numChunks: Int, + bitmaps: Array[RoaringBitmap]): MergedBlockMeta = { + val mergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(mergedBlockMeta.getNumChunks).thenReturn(numChunks) + if (bitmaps == null) { + when(mergedBlockMeta.readChunkBitmaps()).thenThrow(new IOException("forced error")) + } else { + when(mergedBlockMeta.readChunkBitmaps()).thenReturn(bitmaps) + } + doReturn(createMockManagedBuffer()).when(mergedBlockMeta).getChunksBitmapBuffer + mergedBlockMeta + } + + private def prepareBlocksForFallbackWhenBlocksAreLocal( + blockManager: BlockManager, + localDirsMap : Map[String, Array[String]], + failReadingLocalChunksMeta: Boolean = false): + Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = { + val localBmId = BlockManagerId("test-client", "test-local-host", 1) + doReturn(localBmId).when(blockManager).blockManagerId + initHostLocalDirManager(blockManager, localDirsMap) + + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 2) -> 
createMockManagedBuffer(), + ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer() + ) + + doReturn(blockChunks(ShuffleBlockId(0, 0, 2))).when(blockManager) + .getLocalBlockData(ShuffleBlockId(0, 0, 2)) + doReturn(blockChunks(ShuffleBlockId(0, 1, 2))).when(blockManager) + .getLocalBlockData(ShuffleBlockId(0, 1, 2)) + doReturn(blockChunks(ShuffleBlockId(0, 2, 2))).when(blockManager) + .getLocalBlockData(ShuffleBlockId(0, 2, 2)) + doReturn(blockChunks(ShuffleBlockId(0, 3, 2))).when(blockManager) + .getLocalBlockData(ShuffleBlockId(0, 3, 2)) + + val dirsForMergedData = localDirsMap(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER) + doReturn(Seq(createMockManagedBuffer(2))).when(blockManager) + .getMergedBlockData(ShuffleBlockId(0, -1, 2), dirsForMergedData) + + // Get a valid chunk meta for this test + val bitmaps = Array(new RoaringBitmap) + bitmaps(0).add(1) // chunk 0 has mapId 1 + bitmaps(0).add(2) // chunk 0 has mapId 2 + val mergedBlockMeta: MergedBlockMeta = if (failReadingLocalChunksMeta) { + createMockMergedBlockMeta(bitmaps.length, null) + } else { + createMockMergedBlockMeta(bitmaps.length, bitmaps) + } + when(blockManager.getMergedBlockMeta(ShuffleBlockId(0, -1, 2), dirsForMergedData)) + .thenReturn(mergedBlockMeta) + when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( + Seq((localBmId, + toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) + when(mapOutputTracker.getMapSizesForMergeResult(0, 2, bitmaps(0))) + .thenReturn(Seq((localBmId, + toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) + val mergedBmId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1) + Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)), + (mergedBmId, toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1))) + } + + private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = { + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 2)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 3, 2)) + val (id3, _) = iterator.next() + assert(id3 === ShuffleBlockId(0, 1, 2)) + val (id4, _) = iterator.next() + assert(id4 === ShuffleBlockId(0, 2, 2)) + assert(!iterator.hasNext) + } + + test("failed to fetch local merged blocks then fallback to fetch original shuffle blocks") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("testPath1", "testPath2") + val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + doThrow(new RuntimeException("Forced error")).when(blockManager) + .getMergedBlockData(ShuffleBlockId(0, -1, 2), localDirs) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager)) + verifyLocalBlocksFromFallback(iterator) + } + + test("failed to fetch local merged blocks then fallback to fetch original shuffle " + + "blocks which contains host-local blocks") { + val blockManager = mock(classOf[BlockManager]) + // BlockManagerId from another executor on the same host + val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1) + val hostLocalDirs = Map("test-client-1" -> Array("local-dir"), + BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir")) + val blocksByAddress = 
      prepareBlocksForFallbackWhenBlocksAreLocal(blockManager, hostLocalDirs)
+
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getMergedBlockData(ShuffleBlockId(0, -1, 2), Array("local-dir"))
+    // host-local read for a shuffle block
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer(
+      (_: InvocationOnMock) => {
+        Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1)),
+          (hostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("initialization and fallback with host-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    // BlockManagerId of another executor on the same host
+    val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1)
+    val fallbackHostLocalBmId = BlockManagerId("test-client-2", "test-local-host", 1)
+    val hostLocalDirs = Map(hostLocalBmId.executorId -> Array("local-dir"),
+      BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"),
+      fallbackHostLocalBmId.executorId -> Array("local-dir"))
+
+    val hostLocalBlocks = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (hostLocalBmId, Seq((ShuffleBlockId(0, 5, 2), 1L, 1))))
+
+    val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(
+      blockManager, hostLocalDirs) ++ hostLocalBlocks
+
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getMergedBlockData(ShuffleBlockId(0, -1, 2), Array("local-dir"))
+    // host-local read for this original shuffle block
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 5, 2), Array("local-dir"))
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer(
+      (_: InvocationOnMock) => {
+        Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1)),
+          (fallbackHostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 5, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    val (id5, _) = iterator.next()
+    assert(id5 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("failure while reading shuffle chunks should fall back to original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(
+      blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    // This will throw an IOException when an input stream is created from the ManagedBuffer
+    doReturn(Seq({
+      new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
+    })).when(blockManager).getMergedBlockData(ShuffleBlockId(0, -1, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
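+    // Opening a stream over the non-existent file above throws an IOException, which the
+    // iterator should treat as an ignorable merged-block failure and answer by falling
+    // back to the original shuffle blocks verified below.
+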
verifyLocalBlocksFromFallback(iterator) + } + + test("fallback to original shuffle block when a merged block chunk is corrupt") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("local-dir") + val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + val corruptBuffer = createMockManagedBuffer(2) + doReturn(Seq({corruptBuffer})).when(blockManager) + .getMergedBlockData(ShuffleBlockId(0, -1, 2), localDirs) + val corruptStream = mock(classOf[InputStream]) + when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) + doReturn(corruptStream).when(corruptBuffer).createInputStream() + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) + verifyLocalBlocksFromFallback(iterator) + } + + test("failure when reading chunkBitmaps of local merged block should fallback to " + + "original shuffle blocks") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("local-dir") + val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs), + failReadingLocalChunksMeta = true) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) + verifyLocalBlocksFromFallback(iterator) + } + + test("fallback to original blocks when failed to fetch remote shuffle chunk") { + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer() + ) + val blocksSem = new Semaphore(0) + configureMockTransferForPushShuffle(blocksSem, blockChunks) + val bitmaps = Array(new RoaringBitmap, new RoaringBitmap) + bitmaps(1).add(3) + bitmaps(1).add(4) + bitmaps(1).add(5) + val mergedBlockMeta = createMockMergedBlockMeta(2, bitmaps) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] + val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + Future { + metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + } + }) + val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (BlockManagerId("remote-client", "remote-host-2", 1), + toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2), + ShuffleBlockId(0, 5, 2)), 4L, 1))) + when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any())) + .thenReturn(fallbackBlocksByAddr.iterator) + val iterator = createShuffleBlockIteratorWithDefaults(Map( + BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) -> + toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 12L, -1))) + val (id1, _) = iterator.next() + blocksSem.acquire(1) + assert(id1 === ShuffleBlockChunkId(0, 2, 0)) + val (id3, _) = iterator.next() + blocksSem.acquire(3) + assert(id3 === ShuffleBlockId(0, 3, 2)) + val (id4, _) = iterator.next() + assert(id4 === ShuffleBlockId(0, 4, 2)) + val (id5, _) = iterator.next() + assert(id5 === ShuffleBlockId(0, 5, 2)) + assert(!iterator.hasNext) + } + + test("fallback to original blocks when failed to 
parse remote merged block meta") { + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer() + ) + when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( + Seq((BlockManagerId("remote-client-1", "remote-host-1", 1), + toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator) + val blocksSem = new Semaphore(0) + configureMockTransferForPushShuffle(blocksSem, blockChunks) + val mergedBlockMeta = createMockMergedBlockMeta(2, null) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] + val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + Future { + metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + } + }) + val remoteMergedBlockMgrId = BlockManagerId( + BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteMergedBlockMgrId -> toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1))) + val (id1, _) = iterator.next() + blocksSem.acquire(2) + assert(id1 === ShuffleBlockId(0, 0, 2)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 1, 2)) + assert(!iterator.hasNext) + } + + test("failure to fetch a remote merged block chunk initiates the fallback of" + + " deferred shuffle chunks immediately") { + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), + // ShuffleBlockChunk(0, 2, 1) will cause a failure as it is not in block-chunks. + ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(), + ShuffleBlockChunkId(0, 2, 3) -> createMockManagedBuffer(), + ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 6, 2) -> createMockManagedBuffer() + ) + val blocksSem = new Semaphore(0) + configureMockTransferForPushShuffle(blocksSem, blockChunks) + + val metaSem = new Semaphore(0) + val mergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(mergedBlockMeta.getNumChunks).thenReturn(4) + when(mergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) + val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap) + when(mergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] + val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + Future { + logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " + + s"port = ${invocation.getArguments()(1)}, " + + s"shuffleId = $shuffleId, reduceId = $reduceId") + metaSem.release() + metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + } + }) + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( + (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2), + ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)), 1L, 1))) + 
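    // Wire the tracker so that any failed chunk of this merged block resolves to the four
+    // original shuffle blocks listed above.
+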
when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
+        toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 16L, -1)),
+      maxBytesInFlight = 4)
+    metaSem.acquire(1)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    val regularBlocks = new mutable.HashSet[BlockId]()
+    val (id2, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id2)
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id3)
+    val (id4, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id4)
+    val (id5, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id5)
+    assert(!iterator.hasNext)
+    assert(regularBlocks === Set(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+      ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)))
+  }
+
+  test("failure to fetch a remote merged block chunk immediately initiates the fallback" +
+    " of the shuffle chunks which got deferred") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      // ShuffleBlockChunkId(0, 2, 3) will cause a failure as it is not in blockChunks
+      ShuffleBlockChunkId(0, 2, 4) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 5) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 6, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val metaSem = new Semaphore(0)
+    val mergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(mergedBlockMeta.getNumChunks).thenReturn(6)
+    when(mergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap)
+    when(mergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.release()
+          metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta)
+        }
+      })
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+        ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)), 1L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2),
+        Seq((ShuffleBlockId(0, -1, 2),
+          24L, -1))))
+    val iterator =
createShuffleBlockIteratorWithDefaults(Map( + BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> + toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 24L, -1)), + maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1) + metaSem.acquire(1) + val (id1, _) = iterator.next() + blocksSem.acquire(2) + assert(id1 === ShuffleBlockChunkId(0, 2, 0)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockChunkId(0, 2, 1)) + val (id3, _) = iterator.next() + blocksSem.acquire(1) + assert(id3 === ShuffleBlockChunkId(0, 2, 2)) + val regularBlocks = new mutable.HashSet[BlockId]() + val (id4, _) = iterator.next() + blocksSem.acquire(1) + regularBlocks.add(id4) + val (id5, _) = iterator.next() + blocksSem.acquire(1) + regularBlocks.add(id5) + val (id6, _) = iterator.next() + blocksSem.acquire(1) + regularBlocks.add(id6) + val (id7, _) = iterator.next() + blocksSem.acquire(1) + regularBlocks.add(id7) + assert(!iterator.hasNext) + assert(regularBlocks === Set[ShuffleBlockId](ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2), + ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2))) + } + } From 6762de3e5c21cad7025412c16293d7ed694efe36 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Thu, 23 Jul 2020 17:43:50 -0700 Subject: [PATCH 03/27] LIHADOOP-52494 Magnet fallback to origin shuffle blocks when fetch of a shuffle chunk fails RB=2203642 BUG=LIHADOOP-52494 G=spark-reviewers R=yzhou,mshen,vsowrira A=mshen --- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../protocol/FetchMergedBlocksMeta.java | 83 ------------------- .../shuffle/protocol/MergedBlocksMeta.java | 75 ----------------- .../spark/shuffle/ShuffleBlockResolver.scala | 5 ++ .../org/apache/spark/storage/BlockId.scala | 9 ++ .../apache/spark/storage/BlockManager.scala | 7 ++ 6 files changed, 25 insertions(+), 160 deletions(-) delete mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java delete mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 32c9acd32721..a63f9072561f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -357,8 +357,10 @@ public void testRpcHandlerDelegate() throws Exception { public void testDelegates() throws Exception { Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); for (Method m : rpcHandlerMethods) { - Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes()); - assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class); + if (!m.getName().contains("lambda")) { + Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes()); + assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class); + } } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java deleted file mode 100644 index 863aa1bfae77..000000000000 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchMergedBlocksMeta.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) 
under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import java.util.Arrays; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -/** - * Request to find the meta information for the specified merged blocks. The meta information - * currently contains only the number of chunks in each merged blocks. - */ -public class FetchMergedBlocksMeta extends BlockTransferMessage { - public final String appId; - public final String[] blockIds; - - public FetchMergedBlocksMeta(String appId, String[] blockIds) { - this.appId = appId; - this.blockIds = blockIds; - } - - @Override - protected Type type() { return Type.FETCH_MERGED_BLOCKS_META; } - - @Override - public int hashCode() { - return appId.hashCode() * 41 + Arrays.hashCode(blockIds); - } - - @Override - public String toString() { - return Objects.toStringHelper(this) - .add("appId", appId) - .add("blockIds", Arrays.toString(blockIds)) - .toString(); - } - - @Override - public boolean equals(Object other) { - if (other instanceof FetchMergedBlocksMeta) { - FetchMergedBlocksMeta o = (FetchMergedBlocksMeta) other; - return Objects.equal(appId, o.appId) - && Arrays.equals(blockIds, o.blockIds); - } - return false; - } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.StringArrays.encodedLength(blockIds); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.StringArrays.encode(buf, blockIds); - } - - public static FetchMergedBlocksMeta decode(ByteBuf buf) { - String appId = Encoders.Strings.decode(buf); - String[] blockIds = Encoders.StringArrays.decode(buf); - return new FetchMergedBlocksMeta(appId, blockIds); - } -} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java deleted file mode 100644 index 94c3e616491f..000000000000 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergedBlocksMeta.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import java.util.Arrays; -import javax.annotation.Nonnull; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.spark.network.protocol.Encoders; - -/** - * Response of {@link FetchMergedBlocksMeta}. - */ -public class MergedBlocksMeta extends BlockTransferMessage { - - public final int[] numChunks; - - public MergedBlocksMeta(@Nonnull int[] numChunks) { - this.numChunks = numChunks; - } - - @Override - protected Type type() { return Type.MERGED_BLOCKS_META; } - - @Override - public int hashCode() { - return Arrays.hashCode(numChunks); - } - - @Override - public String toString() { - Objects.ToStringHelper helper = Objects.toStringHelper(this); - return helper.add("numChunks", Arrays.toString(numChunks)).toString(); - } - - @Override - public boolean equals(Object other) { - if (other instanceof MergedBlocksMeta) { - MergedBlocksMeta o = (MergedBlocksMeta) other; - return Arrays.equals(numChunks, o.numChunks); - } - return false; - } - - @Override - public int encodedLength() { - return Encoders.IntArrays.encodedLength(numChunks); - } - - @Override - public void encode(ByteBuf buf) { - Encoders.IntArrays.encode(buf, numChunks); - } - - public static MergedBlocksMeta decode(ByteBuf buf) { - return new MergedBlocksMeta(Encoders.IntArrays.decode(buf)); - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 49e59298cc0c..2d839254e1a7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -51,5 +51,10 @@ trait ShuffleBlockResolver { */ def getMergedBlockMeta(blockId: ShuffleBlockId, dirs: Option[Array[String]]): MergedBlockMeta + /** + * Retrieve the meta data for the specified merged shuffle block. 
+ */ + def getMergedBlockMeta(blockId: ShuffleBlockId): MergedBlockMeta + def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index dc70a9af7e9c..1dd5aff7cec4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -125,6 +125,15 @@ case class ShuffleMergedMetaBlockId( appId + "_" + shuffleId + "_" + reduceId + ".meta" } +@DeveloperApi +case class ShuffleMergedMetaBlockId( + appId: String, + shuffleId: Int, + reduceId: Int) extends BlockId { + override def name: String = + "mergedShuffle_" + appId + "_" + shuffleId + "_" + reduceId + ".meta" +} + @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index df449fba24e9..a7718184701f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -750,6 +750,13 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId, Some(dirs)) } + /** + * Get the local merged shuffle block metada data for the given block ID. + */ + def getMergedBlockMeta(blockId: ShuffleBlockId): MergedBlockMeta = { + shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId) + } + /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing. From 5bbe466bc9dbab343624e829e696889f31ee414b Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Sun, 13 Dec 2020 12:37:13 -0800 Subject: [PATCH 04/27] Changed the MergedBlockMetaRequest --- .../server/TransportRequestHandler.java | 56 ++----------------- .../spark/network/sasl/SparkSaslSuite.java | 6 +- .../spark/network/BlockDataManager.scala | 15 ++++- .../org/apache/spark/storage/BlockId.scala | 9 --- 4 files changed, 21 insertions(+), 65 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 8407f2f441af..ab2deac20fcd 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -32,7 +32,6 @@ import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.*; import org.apache.spark.network.protocol.*; -import org.apache.spark.network.util.JavaUtils; import org.apache.spark.network.util.TransportFrameDecoder; import static org.apache.spark.network.util.NettyUtils.getRemoteAddress; @@ -186,18 +185,6 @@ public void onFailure(Throwable e) { private void processStreamUpload(final UploadStream req) { assert (req.body() == null); try { - // Retain the original metadata buffer, since it will be used during the invocation of - // this method. Will be released later. - req.meta.retain(); - // Make a copy of the original metadata buffer. 
In benchmark, we noticed that - // we cannot respond the original metadata buffer back to the client, otherwise - // in cases where multiple concurrent shuffles are present, a wrong metadata might - // be sent back to client. This is related to the eager release of the metadata buffer, - // i.e., we always release the original buffer by the time the invocation of this - // method ends, instead of by the time we respond it to the client. This is necessary, - // otherwise we start seeing memory issues very quickly in benchmarks. - // TODO check if the way metadata buffer is handled can be further improved - ByteBuffer meta = cloneBuffer(req.meta.nioByteBuffer()); RpcResponseCallback callback = new RpcResponseCallback() { @Override public void onSuccess(ByteBuffer response) { @@ -206,17 +193,13 @@ public void onSuccess(ByteBuffer response) { @Override public void onFailure(Throwable e) { - // Piggyback request metadata as part of the exception error String, so we can - // respond the metadata upon a failure without changing the existing protocol. - respond(new RpcFailure(req.requestId, - JavaUtils.encodeHeaderIntoErrorString(meta.duplicate(), e))); - req.meta.release(); + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); } }; TransportFrameDecoder frameDecoder = (TransportFrameDecoder) channel.pipeline().get(TransportFrameDecoder.HANDLER_NAME); - StreamCallbackWithID streamHandler = - rpcHandler.receiveStream(reverseClient, meta.duplicate(), callback); + ByteBuffer meta = req.meta.nioByteBuffer(); + StreamCallbackWithID streamHandler = rpcHandler.receiveStream(reverseClient, meta, callback); if (streamHandler == null) { throw new NullPointerException("rpcHandler returned a null streamHandler"); } @@ -230,17 +213,12 @@ public void onData(String streamId, ByteBuffer buf) throws IOException { public void onComplete(String streamId) throws IOException { try { streamHandler.onComplete(streamId); - callback.onSuccess(meta.duplicate()); + callback.onSuccess(ByteBuffer.allocate(0)); } catch (Exception ex) { IOException ioExc = new IOException("Failure post-processing complete stream;" + " failing this rpc and leaving channel active", ex); - // req.meta will be released once inside callback.onFailure. Retain it one more - // time to be released in the finally block. - req.meta.retain(); callback.onFailure(ioExc); streamHandler.onFailure(streamId, ioExc); - } finally { - req.meta.release(); } } @@ -264,26 +242,12 @@ public String getID() { } } catch (Exception e) { logger.error("Error while invoking RpcHandler#receive() on RPC id " + req.requestId, e); - try { - // It's OK to respond the original metadata buffer here, because this is still inside - // the invocation of this method. - respond(new RpcFailure(req.requestId, - JavaUtils.encodeHeaderIntoErrorString(req.meta.nioByteBuffer(), e))); - } catch (IOException ioe) { - // No exception will be thrown here. req.meta.nioByteBuffer will not throw IOException - // because it's a NettyManagedBuffer. This try-catch block is to make compiler happy. - logger.error("Error in handling failure while invoking RpcHandler#receive() on RPC id " - + req.requestId, e); - } finally { - req.meta.release(); - } + respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); // We choose to totally fail the channel, rather than trying to recover as we do in other // cases. We don't know how many bytes of the stream the client has already sent for the // stream, it's not worth trying to recover. 
channel.pipeline().fireExceptionCaught(e); } finally { - // Make sure we always release the original metadata buffer by the time we exit the - // invocation of this method. Otherwise, we see memory issues fairly quickly in benchmarks. req.meta.release(); } } @@ -322,16 +286,6 @@ public void onFailure(Throwable e) { } } - /** - * Make a full copy of a nio ByteBuffer. - */ - private ByteBuffer cloneBuffer(ByteBuffer buf) { - ByteBuffer clone = ByteBuffer.allocate(buf.capacity()); - clone.put(buf.duplicate()); - clone.flip(); - return clone; - } - /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index a63f9072561f..32c9acd32721 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -357,10 +357,8 @@ public void testRpcHandlerDelegate() throws Exception { public void testDelegates() throws Exception { Method[] rpcHandlerMethods = RpcHandler.class.getDeclaredMethods(); for (Method m : rpcHandlerMethods) { - if (!m.getName().contains("lambda")) { - Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes()); - assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class); - } + Method delegate = SaslRpcHandler.class.getMethod(m.getName(), m.getParameterTypes()); + assertNotEquals(delegate.getDeclaringClass(), RpcHandler.class); } } diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 4bf979e41b6c..205addb2097c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,6 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID +import org.apache.spark.network.shuffle.MergedBlockMeta import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] @@ -72,5 +73,17 @@ trait BlockDataManager { */ def releaseLock(blockId: BlockId, taskContext: Option[TaskContext]): Unit - def getMergedBlockData(blockId: ShuffleBlockId): Seq[ManagedBuffer] + /** + * Interface to get merged shuffle block data. Throws an exception if the block cannot be found + * or cannot be read successfully. + */ + // PART OF SPARK-33350 + def getMergedBlockData(blockId: BlockId, dirs: Array[String]): Seq[ManagedBuffer] + + /** + * Interface to get merged shuffle block meta. Throws an exception if the meta cannot be found + * or cannot be read successfully. 
+ */ + // PART OF SPARK-33350 + def getMergedBlockMeta(blockId: BlockId, dirs: Array[String]): MergedBlockMeta } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 1dd5aff7cec4..dc70a9af7e9c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -125,15 +125,6 @@ case class ShuffleMergedMetaBlockId( appId + "_" + shuffleId + "_" + reduceId + ".meta" } -@DeveloperApi -case class ShuffleMergedMetaBlockId( - appId: String, - shuffleId: Int, - reduceId: Int) extends BlockId { - override def name: String = - "mergedShuffle_" + appId + "_" + shuffleId + "_" + reduceId + ".meta" -} - @DeveloperApi case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { override def name: String = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) From ff80579769b09660da5751b922c9ce2df9ae4204 Mon Sep 17 00:00:00 2001 From: Min Shen Date: Mon, 14 Dec 2020 12:04:33 -0800 Subject: [PATCH 05/27] empty commit from Min From ef464aa725de450d757b27c0c08eac78d8b2cc42 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Wed, 26 May 2021 17:11:58 -0700 Subject: [PATCH 06/27] Created PushBasedFetchHelper that encapsulates all pushbased functionality --- .../org/apache/spark/MapOutputTracker.scala | 3 +- .../shuffle/BlockStoreShuffleReader.scala | 1 + .../storage/ShuffleBlockFetcherIterator.scala | 724 ++++++++++-------- .../ShuffleBlockFetcherIteratorSuite.scala | 39 +- 4 files changed, 408 insertions(+), 359 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 2b06c4980515..e605eea80229 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -1453,7 +1453,8 @@ private[spark] object MapOutputTracker extends Logging { // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is // a merged shuffle block. 
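        // (SHUFFLE_PUSH_MAP_ID is the -1 sentinel, which is also why the tests above
        // address merged blocks as ShuffleBlockId(shuffleId, -1, reduceId).)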
splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1)) + ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, + SHUFFLE_PUSH_MAP_ID)) // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper // shuffle partition blocks, fetch the original map produced shuffle partition blocks val mapStatusesWithIndex = mapStatuses.zipWithIndex diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 7f90eabf34bd..818aa2ef75a9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -35,6 +35,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0bc0b11400b5..69f392456de5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -32,11 +32,13 @@ import org.apache.commons.io.IOUtils import org.roaringbitmap.RoaringBitmap import org.apache.spark.{MapOutputTracker, SparkException, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.storage.ShuffleBlockFetcherIterator.{FetchBlockInfo, FetchRequest, IgnoreFetchResult, MergedBlocksMetaFailedFetchResult, MergedBlocksMetaFetchResult, SuccessFetchResult} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} /** @@ -176,12 +178,8 @@ final class ShuffleBlockFetcherIterator( private[this] val onCompleteCallback = new ShuffleFetchCompletionListener(this) - /** A map for storing merged block shuffle chunk bitmap */ - private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() - - private[this] val localShuffleMergerBlockMgrId = BlockManagerId( - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, - blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) + private[this] val pushBasedFetchHelper = new PushBasedFetchHelper( + this, shuffleClient, blockManager, mapOutputTracker) initialize() @@ -253,7 +251,7 @@ final class ShuffleBlockFetcherIterator( logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) if (req.hasMergedBlocks) { - sendFetchMergedStatusRequest(req) + pushBasedFetchHelper.sendFetchMergedStatusRequest(req) return } bytesInFlight += req.size @@ -362,136 +360,10 @@ final class ShuffleBlockFetcherIterator( } } - private[this] def 
sendFetchMergedStatusRequest(req: FetchRequest): Unit = { - val sizeMap = req.blocks.map { - case FetchBlockInfo(blockId, size, _) => - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap - val address = req.address - val mergedBlocksMetaListener = new MergedBlocksMetaListener { - override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { - logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " + - s"from ${req.address.host}:${req.address.port}") - try { - results.put(MergedBlocksMetaFetchResult(shuffleId, reduceId, sizeMap(shuffleId, reduceId), - meta.getNumChunks, meta.readChunkBitmaps(), address)) - } catch { - case _: Throwable => - results.put(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) - } - } - - override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { - logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " + - s"from ${req.address.host}:${req.address.port}", exception) - results.put(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) - } - } - req.blocks.foreach(block => { - val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId] - shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId, - shuffleBlockId.reduceId, mergedBlocksMetaListener) - }) - } - /** - * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed - * to fetch. - * It calls out to map output tracker to get the list of original blocks for the - * given merged blocks, split them into remote and local blocks, and process them - * accordingly. - * The fallback happens when: - * 1. There is an exception while creating shuffle block chunk from local merged shuffle block. - * See fetchLocalBlock. - * 2. There is a failure when fetching remote shuffle block chunks. - * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk (local or - * remote). + * This is called from initialize and also from the fallback which is triggered from + * [[PushBasedFetchHelper]]. */ - private[this] def initiateFallbackBlockFetchForMergedBlock( - blockId: BlockId, - address: BlockManagerId): Unit = { - logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId") - // Increase the blocks processed, since we will process another block in the next iteration of - // the while loop. 
- numBlocksProcessed += 1 - val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = - if (blockId.isShuffle) { - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - mapOutputTracker.getMapSizesForMergeResult( - shuffleBlockId.shuffleId, shuffleBlockId.reduceId) - } else { - val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] - val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull - if (isRemote(address)) { - // Fallback for all the pending fetch requests - val pendingShuffleChunks = removePendingChunks(shuffleChunkId, address) - if (pendingShuffleChunks.nonEmpty) { - pendingShuffleChunks.foreach { pendingBlockId => - logWarning(s"Falling back immediately for merged block $pendingBlockId") - val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).orNull - assert(bitmapOfPendingChunk != null) - chunkBitmap.or(bitmapOfPendingChunk) - } - // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed - numBlocksProcessed += pendingShuffleChunks.size - } - } - mapOutputTracker.getMapSizesForMergeResult( - shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap) - } - val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() - val fallbackHostLocalBlocksByExecutor = - mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() - val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() - val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr, - fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks) - // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(fallbackRemoteReqs) - logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged " + - s"block $blockId") - // If there is any fall back block that's a local block, we get them here. The original - // invocation to fetchLocalBlocks might have already returned by this time, so we need - // to invoke it again here. 
- fetchLocalBlocks(fallbackLocalBlocks) - // Merged local blocks should be empty during fallback - assert(fallbackMergedLocalBlocks.isEmpty, "There should be zero merged blocks during fallback") - // Some of the fallback local blocks could be host local blocks - fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor) - } - - private[this] def removePendingChunks( - failedBlockId: ShuffleBlockChunkId, - address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { - val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() - - def sameShuffleBlockChunk(block: BlockId): Boolean = { - val chunkId = block.asInstanceOf[ShuffleBlockChunkId] - chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId - } - - def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { - val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() - fetchRequestsToRemove ++= queue.dequeueAll(req => { - val firstBlock = req.blocks.head - firstBlock.blockId.isShuffleChunk && req.address.equals(address) && - sameShuffleBlockChunk(firstBlock.blockId) - }) - fetchRequestsToRemove.foreach(req => { - removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]) - }) - } - - filterRequests(fetchRequests) - val defRequests = deferredFetchRequests.remove(address).orNull - if (defRequests != null) { - filterRequests(defRequests) - if (defRequests.nonEmpty) { - deferredFetchRequests(address) = defRequests - } - } - removedChunkIds - } - private[this] def partitionBlocksByFetchMode( blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)], @@ -512,7 +384,7 @@ final class ShuffleBlockFetcherIterator( val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId for ((address, blockInfos) <- blocksByAddress) { - if (isMergedShuffleBlockAddress(address)) { + if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) { // These are push-based merged blocks or chunks of these merged blocks. if (address.host == blockManager.blockManagerId.host) { checkBlockSizes(blockInfos) @@ -525,7 +397,7 @@ final class ShuffleBlockFetcherIterator( s"of size $mergedLocalBlockBytes") } else { remoteBlockBytes += blockInfos.map(_._2).sum - collectFetchReqsFromMergedBlocks(address, blockInfos, collectedRemoteRequests) + collectFetchRequests(address, blockInfos, collectedRemoteRequests) } } else if ( Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) { @@ -613,93 +485,59 @@ final class ShuffleBlockFetcherIterator( } retBlocks } - - /** - * Collect the FetchRequests for push-based merged blocks from remote shuffle services. - * This method will be called to either initialize the FetchRequests with the original - * blocks or during the fallback when the fetch of merged shuffle blocks/chunks fail. 
- *
- * @param address remote shuffle service address
- * @param blockInfos shuffle block information
- * @param collectedRemoteRequests queue of FetchRequests
- * @return Number of blocks put into remoteRequests during this invocation
- */
-  private[this] def collectFetchReqsFromMergedBlocks(
+
+  private def collectFetchRequests(
       address: BlockManagerId,
       blockInfos: Seq[(BlockId, Long, Int)],
       collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
-    var curBlocks = Seq.empty[FetchBlockInfo]
-    var mergedRequestsSize = 0L
-    var mergedBlocks = Seq.empty[FetchBlockInfo]
+    var curBlocks = new ArrayBuffer[FetchBlockInfo]()
+
     while (iterator.hasNext) {
-      val (blockId, size, _) = iterator.next()
+      val (blockId, size, mapIndex) = iterator.next()
       assertPositiveBlockSize(blockId, size)
+      curBlocks += FetchBlockInfo(blockId, size, mapIndex)
+      curRequestSize += size
       blockId match {
-        case ShuffleBlockId(_, mapId, _) =>
-          assert(mapId == -1)
-          mergedBlocks = mergedBlocks ++ Seq(FetchBlockInfo(blockId, size, -1))
-          mergedRequestsSize += size
-          if (mergedRequestsSize >= targetRemoteRequestSize ||
-            mergedBlocks.size >= maxBlocksInFlightPerAddress) {
-            mergedBlocks = createFetchRequests(mergedBlocks, address, isLast = false,
-              collectedRemoteRequests, enableBatchFetch = false, areMergedBlocks = true)
-            mergedRequestsSize = mergedBlocks.map(_.size).sum
-          }
+        // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create FetchRequests with
+        // areMergedBlocks set.
         case ShuffleBlockChunkId(_, _, _) =>
-          curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, -1))
-          curRequestSize += size
           if (curRequestSize >= targetRemoteRequestSize ||
             curBlocks.size >= maxBlocksInFlightPerAddress) {
             curBlocks = createFetchRequests(curBlocks, address, isLast = false,
               collectedRemoteRequests, enableBatchFetch = false)
             curRequestSize = curBlocks.map(_.size).sum
           }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, areMergedBlocks = true)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
         case _ =>
-          throw new SparkException(
-            "Failed to match block " + blockId + ", which is not a merged shuffle block or chunk"
-          )
+          // For batch fetch, the actual block in flight should count for merged block.
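Both the chunk case and the merged-block case above flush curBlocks into a new request once the accumulated bytes reach targetRemoteRequestSize or the block count reaches maxBlocksInFlightPerAddress. A standalone sketch of that grouping rule, with hypothetical names and thresholds (the real createFetchRequests can also hand leftover blocks back to the caller, which this sketch skips):

import scala.collection.mutable.ArrayBuffer

object RequestBatchingSketch {
  case class Block(id: String, size: Long)
  case class Request(blocks: Seq[Block]) {
    def totalSize: Long = blocks.map(_.size).sum
  }

  /** Group blocks into requests capped by total size and block count. */
  def batch(blocks: Seq[Block], maxBytes: Long, maxBlocks: Int): Seq[Request] = {
    val requests = ArrayBuffer.empty[Request]
    var cur = ArrayBuffer.empty[Block]
    var curSize = 0L
    for (b <- blocks) {
      cur += b
      curSize += b.size
      if (curSize >= maxBytes || cur.size >= maxBlocks) {
        requests += Request(cur.toSeq)
        cur = ArrayBuffer.empty[Block]
        curSize = 0L
      }
    }
    if (cur.nonEmpty) requests += Request(cur.toSeq) // the isLast = true flush
    requests.toSeq
  }

  def main(args: Array[String]): Unit = {
    val blocks = (1 to 7).map(i => Block(s"chunk_$i", size = 40L))
    // With a 100-byte target and at most 4 blocks per request: 3 + 3 + 1 blocks.
    batch(blocks, maxBytes = 100L, maxBlocks = 4).foreach(r =>
      println(s"${r.blocks.size} blocks, ${r.totalSize} bytes"))
  }
}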
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } } } // Add in the final request if (curBlocks.nonEmpty) { - curBlocks = createFetchRequests(curBlocks, address, isLast = true, - collectedRemoteRequests, enableBatchFetch = false) - curRequestSize = curBlocks.map(_.size).sum - } - if (mergedBlocks.nonEmpty) { - mergedBlocks = createFetchRequests(mergedBlocks, address, isLast = true, - collectedRemoteRequests, enableBatchFetch = false, areMergedBlocks = true) - mergedRequestsSize = mergedBlocks.map(_.size).sum - } - } - - private def collectFetchRequests( - address: BlockManagerId, - blockInfos: Seq[(BlockId, Long, Int)], - collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = { - val iterator = blockInfos.iterator - var curRequestSize = 0L - var curBlocks = new ArrayBuffer[FetchBlockInfo]() - - while (iterator.hasNext) { - val (blockId, size, mapIndex) = iterator.next() - assertPositiveBlockSize(blockId, size) - curBlocks += FetchBlockInfo(blockId, size, mapIndex) - curRequestSize += size - // For batch fetch, the actual block in flight should count for merged block. - val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress - if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { - curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, - collectedRemoteRequests, doBatchFetch) - curRequestSize = curBlocks.map(_.size).sum + val (enableBatchFetch, areMergedBlocks) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _) => (false, false) + case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true) + case _ => (doBatchFetch, false) + } } - } - // Add in the final request - if (curBlocks.nonEmpty) { createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests, - doBatchFetch) + enableBatchFetch = enableBatchFetch, areMergedBlocks = areMergedBlocks) } } @@ -778,7 +616,7 @@ final class ShuffleBlockFetcherIterator( private[this] def fetchHostLocalBlocks( hostLocalDirManager: HostLocalDirManager, hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]): - Unit = { + Unit = { val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) => @@ -848,88 +686,6 @@ final class ShuffleBlockFetcherIterator( } } - /** - * Fetch a single local merged block generated. 
- * @param blockId ShuffleBlockId to be fetched - * @param localDirs Local directories where the merged shuffle files are stored - * @param blockManagerId BlockManagerId - * @return Boolean represents successful or failed fetch - */ - private[this] def fetchMergedLocalBlock( - blockId: BlockId, - localDirs: Array[String], - blockManagerId: BlockManagerId): Boolean = { - try { - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs).readChunkBitmaps() - // Fetch local merged shuffle block data as multiple chunks - val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs) - // Update total number of blocks to fetch, reflecting the multiple local chunks - numBlocksToFetch += bufs.size - 1 - for (chunkId <- bufs.indices) { - val buf = bufs(chunkId) - buf.retain() - val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, shuffleBlockId.reduceId, - chunkId) - results.put(SuccessFetchResult(shuffleChunkId, -1, blockManagerId, buf.size(), buf, - isNetworkReqDone = false)) - chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) - } - true - } catch { - case e: Exception => - // If we see an exception with reading a local merged block, we fallback to - // fetch the original unmerged blocks. We do not report block fetch failure - // and will continue with the remaining local block read. - logWarning(s"Error occurred while fetching local merged block, " + - s"prepare to fetch the original blocks", e) - results.put(IgnoreFetchResult(blockId, blockManagerId, 0, false)) - false - } - } - - /** - * Fetch the merged local blocks while we are fetching remote blocks. - */ - private[this] def fetchMergedLocalBlocks( - hostLocalDirManager: HostLocalDirManager, - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { - val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER) - if (cachedMergerDirs.isDefined) { - logDebug(s"Fetching local merged blocks with cached executors dir: " + - s"${cachedMergerDirs.get.mkString(", ")}") - mergedLocalBlocks.foreach(blockId => - fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId)) - } else { - logDebug(s"Asynchronous fetching local merged blocks without cached executors dir") - hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, - localShuffleMergerBlockMgrId.port, Array(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)) { - case Success(dirs) => - mergedLocalBlocks.takeWhile { - case blockId => - logDebug(s"Successfully fetched local dirs: " + - s"${dirs.get(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") - fetchMergedLocalBlock(blockId, dirs(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER), - localShuffleMergerBlockMgrId) - } - logDebug(s"Got local merged blocks (without cached executors' dir) in " + - s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") - case Failure(throwable) => - // If we see an exception with getting the local dirs for local merged blocks, - // we fallback to fetch the original unmerged blocks. We do not report block fetch - // failure. - logWarning(s"Error occurred while getting the local dirs for local merged " + - s"blocks: ${mergedLocalBlocks.mkString(", ")}. 
Fetch the original blocks instead", - throwable) - mergedLocalBlocks.foreach( - blockId => results.put(IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, - isNetworkReqDone = false)) - ) - } - } - } - private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. context.addTaskCompletionListener(onCompleteCallback) @@ -961,24 +717,17 @@ final class ShuffleBlockFetcherIterator( logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") // Get host local blocks if any fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) - fetchAllMergedLocalBlocks(mergedLocalBlocks) + pushBasedFetchHelper.fetchAllMergedLocalBlocks(mergedLocalBlocks) } private def fetchAllHostLocalBlocks( hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]): - Unit = { + Unit = { if (hostLocalBlocksByExecutor.nonEmpty) { blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor)) } } - // Fetch all outstanding merged local blocks - private def fetchAllMergedLocalBlocks(mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { - if (mergedLocalBlocks.nonEmpty) { - blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) - } - } - override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch /** @@ -1012,12 +761,11 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { - if (isMergedShuffleBlockAddress(address) - && address.host == blockManager.blockManagerId.host) { + if (pushBasedFetchHelper.isMergedLocal(address)) { // It is a local merged block chunk assert(blockId.isShuffleChunk) - shuffleMetrics.incLocalBlocksFetched( - getNumberOfBlocksInChunk(blockId.asInstanceOf[ShuffleBlockChunkId])) + shuffleMetrics.incLocalBlocksFetched(pushBasedFetchHelper.getNumberOfBlocksInChunk( + blockId.asInstanceOf[ShuffleBlockChunkId])) shuffleMetrics.incLocalBytesRead(buf.size) } else if (hostLocalBlocks.contains(blockId -> mapIndex)) { shuffleMetrics.incLocalBlocksFetched(1) @@ -1031,7 +779,8 @@ final class ShuffleBlockFetcherIterator( } if (blockId.isShuffleChunk) { shuffleMetrics.incRemoteBlocksFetched( - getNumberOfBlocksInChunk(blockId.asInstanceOf[ShuffleBlockChunkId])) + pushBasedFetchHelper.getNumberOfBlocksInChunk( + blockId.asInstanceOf[ShuffleBlockChunkId])) } else { shuffleMetrics.incRemoteBlocksFetched(1) } @@ -1077,7 +826,8 @@ final class ShuffleBlockFetcherIterator( } buf.release() if (blockId.isShuffleChunk) { - initiateFallbackBlockFetchForMergedBlock(blockId, address) + numBlocksProcessed += pushBasedFetchHelper + .initiateFallbackBlockFetchForMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop to get either. result = null null @@ -1109,7 +859,8 @@ final class ShuffleBlockFetcherIterator( // from chunksMetaMap whenever a Shuffle Block Chunk gets processed. // If we try to re-fetch a corrupt shuffle chunk, then it has to be added // back to the chunksMetaMap. - initiateFallbackBlockFetchForMergedBlock(blockId, address) + numBlocksProcessed += pushBasedFetchHelper + .initiateFallbackBlockFetchForMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop. 
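Throughout these cases the iterator keeps two counters in step: numBlocksToFetch grows when a meta response or a fallback introduces new blocks to wait for, numBlocksProcessed grows when a block is delivered or written off, and hasNext simply compares the two. A simplified model of that accounting, with hypothetical method names:

object BlockAccountingSketch {
  // Minimal model of the iterator's bookkeeping: it is exhausted once every expected
  // block has either been delivered or explicitly written off by a fallback.
  final class Accounting(var numBlocksToFetch: Int) {
    var numBlocksProcessed = 0
    def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

    // A merged-block meta response retires one logical block and enqueues N chunks.
    def onMetaResponse(numChunks: Int): Unit = numBlocksToFetch += numChunks - 1

    // A fallback retires the failed block plus any pending chunks it cancelled, and
    // adds the original unmerged blocks that will be fetched instead.
    def onFallback(cancelledPendingChunks: Int, originalBlocks: Int): Unit = {
      numBlocksProcessed += 1 + cancelledPendingChunks
      numBlocksToFetch += originalBlocks
    }
  }

  def main(args: Array[String]): Unit = {
    val acc = new Accounting(numBlocksToFetch = 1) // one merged block expected
    acc.onMetaResponse(numChunks = 3)              // meta arrives: expect 3 chunks instead
    acc.numBlocksProcessed += 2                    // two chunks are delivered
    acc.onFallback(cancelledPendingChunks = 0, originalBlocks = 2) // third chunk fails over
    acc.numBlocksProcessed += 2                    // the two original blocks are delivered
    println(acc.hasNext)                           // false: 5 processed, 5 expected
  }
}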
result = null } else { @@ -1126,7 +877,7 @@ final class ShuffleBlockFetcherIterator( } } finally { if (blockId.isShuffleChunk) { - chunksMetaMap.remove(blockId.asInstanceOf[ShuffleBlockChunkId]) + pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId]) } // TODO: release the buf here to free memory earlier if (input == null) { @@ -1159,7 +910,7 @@ final class ShuffleBlockFetcherIterator( result = null case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) => - if (isRemote(address)) { + if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 bytesInFlight -= size } @@ -1167,28 +918,22 @@ final class ShuffleBlockFetcherIterator( reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } - initiateFallbackBlockFetchForMergedBlock(blockId, address) + numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock( + blockId, address) // Set result to null to trigger another iteration of the while loop to get either // a SuccessFetchResult or a FailureFetchResult. result = null case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, address, _) => - val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] = - new ArrayBuffer[(BlockId, Long, Int)]() - // Remove the original block from numBlocksToFetch. When #splitLocalRemoteBlocks - // is called to add new chunks, numBlocksToFetch is updated to account for the - // chunk requests. + // The original meta request is processed so we decrease numBlocksToFetch by 1. We will + // collect new chunks request and the count of this is added to numBlocksToFetch in + // collectFetchReqsFromMergedBlocks. numBlocksToFetch -= 1 - val approxChunkSize = blockSize / numChunks - for (i <- 0 until numChunks) { - val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i) - chunksMetaMap.put(blockChunkId, bitmaps(i)) - logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") - blocksToRequest += ((blockChunkId, approxChunkSize, -1)) - } + val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + shuffleId, reduceId, blockSize, numChunks, bitmaps) val additionalRemoteReqs = new ArrayBuffer[FetchRequest] - collectFetchReqsFromMergedBlocks(address, blocksToRequest.toSeq, additionalRemoteReqs) + collectFetchRequests(address, blocksToRequest, additionalRemoteReqs) fetchRequests ++= additionalRemoteReqs // Set result to null to force another iteration. result = null @@ -1196,7 +941,8 @@ final class ShuffleBlockFetcherIterator( case MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address, _) => // If we fail to fetch the merged status of a merged block, we fall back to fetching the // unmerged blocks. - initiateFallbackBlockFetchForMergedBlock(ShuffleBlockId(shuffleId, -1, reduceId), address) + numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock( + ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address) // Set result to null to force another iteration. result = null } @@ -1222,26 +968,6 @@ final class ShuffleBlockFetcherIterator( onCompleteCallback.onComplete(context)) } - private def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) - } - - /** - * Returns true if the address is of a remote block manager or host-local. false otherwise. 
- */
-  private def isRemote(address: BlockManagerId): Boolean = {
-    // If the executor id is empty then it is a merged block so we compare the hosts to check
-    // if it's remote. Otherwise, compare the blockManager Id.
-    (BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) &&
-      address.host != blockManager.blockManagerId.host) ||
-      (!BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) &&
-        address != blockManager.blockManagerId)
-  }
-
-  private def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = {
-    chunksMetaMap(blockId).getCardinality
-  }
-
   private def fetchUpToMaxBytes(): Unit = {
     if (isNettyOOMOnShuffle.get()) {
       if (reqsInFlight > 0) {
@@ -1326,6 +1052,82 @@ final class ShuffleBlockFetcherIterator(
         "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = {
+    numBlocksToFetch += moreBlocksToFetch
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure with a shuffle merged block/chunk.
+   */
+  private[storage] def fetchFallbackBlocks(
+      fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val fallbackHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr,
+      fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(fallbackRemoteReqs)
+    logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged blocks")
+    // If there is any fall back block that's a local block, we get them here. The original
+    // invocation to fetchLocalBlocks might have already returned by this time, so we need
+    // to invoke it again here.
+    fetchLocalBlocks(fallbackLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(fallbackMergedLocalBlocks.isEmpty,
+      "There should be zero merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host as the block chunk that had
+   * a fetch failure.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+ */ + private[storage] def removePendingChunks( + failedBlockId: ShuffleBlockChunkId, + address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { + val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() + + def sameShuffleBlockChunk(block: BlockId): Boolean = { + val chunkId = block.asInstanceOf[ShuffleBlockChunkId] + chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId + } + + def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { + val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() + fetchRequestsToRemove ++= queue.dequeueAll(req => { + val firstBlock = req.blocks.head + firstBlock.blockId.isShuffleChunk && req.address.equals(address) && + sameShuffleBlockChunk(firstBlock.blockId) + }) + fetchRequestsToRemove.foreach(req => { + removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]) + }) + } + + filterRequests(fetchRequests) + val defRequests = deferredFetchRequests.remove(address).orNull + if (defRequests != null) { + filterRequests(defRequests) + if (defRequests.nonEmpty) { + deferredFetchRequests(address) = defRequests + } + } + removedChunkIds + } } /** @@ -1638,3 +1440,249 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult } + +/** + * Helper class that encapsulates all the push-based functionality to fetch merged block meta + * and merged shuffle block chunks. + */ +private class PushBasedFetchHelper( + private val iterator: ShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker) extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[this] val localShuffleMergerBlockMgrId = BlockManagerId( + BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, + blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) + + /** A map for storing merged block shuffle chunk bitmap */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** + * Returns true if the address is for a push-merged block. + */ + def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) + } + + /** + * Returns true if the address is not of executor local or merged local block. false otherwise. + */ + def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = { + (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) || + (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId) + } + + /** + * Returns true if the address if of merged local block. false otherwise. 
+ */ + def isMergedLocal(address: BlockManagerId): Boolean = { + isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = { + chunksMetaMap(blockId).getCardinality + } + + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + reduceId: Int, + blockSize: Long, + numChunks: Int, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / numChunks + val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] = + new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- 0 until numChunks) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToRequest + } + + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { + logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId, + sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address)) + } catch { + case _: Throwable => + iterator.addToResultsQueue( + MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) + } + } + + override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { + logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", exception) + iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) + } + } + req.blocks.foreach(block => { + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId] + shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId, + shuffleBlockId.reduceId, mergedBlocksMetaListener) + }) + } + + // Fetch all outstanding merged local blocks + def fetchAllMergedLocalBlocks( + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (mergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) + } + } + + /** + * Fetch the merged local blocks dirs/blocks.. 
+ */ + private def fetchMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( + BlockManagerId.SHUFFLE_MERGER_IDENTIFIER) + if (cachedMergerDirs.isDefined) { + logDebug(s"Fetching local merged blocks with cached executors dir: " + + s"${cachedMergerDirs.get.mkString(", ")}") + mergedLocalBlocks.foreach(blockId => + fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId)) + } else { + logDebug(s"Asynchronous fetching local merged blocks without cached executors dir") + hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, + localShuffleMergerBlockMgrId.port, Array(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + mergedLocalBlocks.takeWhile { + blockId => + logDebug(s"Successfully fetched local dirs: " + + s"${dirs.get(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchMergedLocalBlock(blockId, dirs(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + logDebug(s"Got local merged blocks (without cached executors' dir) in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + case Failure(throwable) => + // If we see an exception with getting the local dirs for local merged blocks, + // we fallback to fetch the original unmerged blocks. We do not report block fetch + // failure. + logWarning(s"Error occurred while getting the local dirs for local merged " + + s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + throwable) + mergedLocalBlocks.foreach( + blockId => iterator.addToResultsQueue( + IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) + ) + } + } + } + + /** + * Fetch a single local merged block generated. + * @param blockId ShuffleBlockId to be fetched + * @param localDirs Local directories where the merged shuffle files are stored + * @param blockManagerId BlockManagerId + * @return Boolean represents successful or failed fetch + */ + private[this] def fetchMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs) + .readChunkBitmaps() + // Fetch local merged shuffle block data as multiple chunks + val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs) + // Update total number of blocks to fetch, reflecting the multiple local chunks + iterator.foundMoreBlocksToFetch(bufs.size - 1) + for (chunkId <- bufs.indices) { + val buf = bufs(chunkId) + buf.retain() + val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, + shuffleBlockId.reduceId, chunkId) + iterator.addToResultsQueue( + SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf, + isNetworkReqDone = false)) + chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) + } + true + } catch { + case e: Exception => + // If we see an exception with reading a local merged block, we fallback to + // fetch the original unmerged blocks. We do not report block fetch failure + // and will continue with the remaining local block read. 
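fetchMergedLocalBlocks above takes the cached merger directories when present and otherwise resolves them through an asynchronous callback, falling back to the original blocks when resolution fails. A generic sketch of that cached-or-async lookup, with hypothetical names and paths:

import scala.util.{Failure, Success, Try}

object CachedDirsSketch {
  // Minimal model of the cached-vs-async local-dirs lookup: use cached directories
  // when available, otherwise resolve them through an asynchronous callback, and
  // route resolution failures to `onError` (names hypothetical).
  def withLocalDirs(
      cached: Map[String, Array[String]],
      key: String,
      resolveAsync: (String, Try[Array[String]] => Unit) => Unit)(
      use: Array[String] => Unit,
      onError: Throwable => Unit): Unit = {
    cached.get(key) match {
      case Some(dirs) => use(dirs)
      case None => resolveAsync(key, {
        case Success(dirs) => use(dirs)
        case Failure(t) => onError(t)
      })
    }
  }

  def main(args: Array[String]): Unit = {
    val cache = Map("shuffle-push-merger" -> Array("/tmp/merge_mgr/dir0"))
    withLocalDirs(cache, "shuffle-push-merger",
      resolveAsync = (_, cb) => cb(Success(Array("/tmp/resolved"))))(
      use = dirs => println(s"reading merged chunks from ${dirs.mkString(",")}"),
      onError = t => println(s"falling back to original blocks: $t"))
  }
}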
+ logWarning(s"Error occurred while fetching local merged block, " + + s"prepare to fetch the original blocks", e) + iterator.addToResultsQueue( + IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + false + } + } + + /** + * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed + * to fetch. + * It calls out to map output tracker to get the list of original blocks for the + * given merged blocks, split them into remote and local blocks, and process them + * accordingly. + * The fallback happens when: + * 1. There is an exception while creating shuffle block chunk from local merged shuffle block. + * See fetchLocalBlock. + * 2. There is a failure when fetching remote shuffle block chunks. + * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk + * (local or remote). + * + * @return number of blocks processed + */ + def initiateFallbackBlockFetchForMergedBlock( + blockId: BlockId, + address: BlockManagerId): Int = { + logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId") + // Increase the blocks processed since we will process another block in the next iteration of + // the while loop in ShuffleBlockFetcherIterator.next(). + var blocksProcessed = 1 + val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = + if (blockId.isShuffle) { + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + mapOutputTracker.getMapSizesForMergeResult( + shuffleBlockId.shuffleId, shuffleBlockId.reduceId) + } else { + val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull + if (isNotExecutorOrMergedLocal(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + if (pendingShuffleChunks.nonEmpty) { + pendingShuffleChunks.foreach { pendingBlockId => + logWarning(s"Falling back immediately for merged block $pendingBlockId") + val bitmapOfPendingChunk: RoaringBitmap = + chunksMetaMap.remove(pendingBlockId).orNull + assert(bitmapOfPendingChunk != null) + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + } + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap) + } + iterator.fetchFallbackBlocks(fallbackBlocksByAddr) + blocksProcessed + } +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 64d5bc354e49..004ee51a0e41 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -36,12 +36,13 @@ import org.roaringbitmap.RoaringBitmap import org.scalatest.PrivateMethodTester import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext} +import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import 
org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo +import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.apache.spark.util.Utils @@ -1057,7 +1058,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("fetch merged blocks meta") { val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), - toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1)), + toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), (BlockManagerId("remote-client-1", "remote-host-1", 1), toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)) ) @@ -1110,7 +1111,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1) val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), - toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1)), + toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))) val blockChunks = Map[BlockId, ManagedBuffer]( @@ -1188,7 +1189,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val dirsForMergedData = localDirsMap(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER) doReturn(Seq(createMockManagedBuffer(2))).when(blockManager) - .getMergedBlockData(ShuffleBlockId(0, -1, 2), dirsForMergedData) + .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData) // Get a valid chunk meta for this test val bitmaps = Array(new RoaringBitmap) @@ -1199,8 +1200,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } else { createMockMergedBlockMeta(bitmaps.length, bitmaps) } - when(blockManager.getMergedBlockMeta(ShuffleBlockId(0, -1, 2), dirsForMergedData)) - .thenReturn(mergedBlockMeta) + when(blockManager.getMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), + dirsForMergedData)).thenReturn(mergedBlockMeta) when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( Seq((localBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) @@ -1210,7 +1211,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val mergedBmId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1) Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)), - (mergedBmId, toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1))) + (mergedBmId, toBlockList( + Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID))) } private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = { @@ -1231,7 +1233,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) doThrow(new RuntimeException("Forced error")).when(blockManager) - .getMergedBlockData(ShuffleBlockId(0, -1, 2), localDirs) + .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), 
localDirs) val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, blockManager = Some(blockManager)) verifyLocalBlocksFromFallback(iterator) @@ -1247,7 +1249,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(blockManager, hostLocalDirs) doThrow(new RuntimeException("Forced error")).when(blockManager) - .getMergedBlockData(ShuffleBlockId(0, -1, 2), Array("local-dir")) + .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir")) // host local read for a shuffle block doReturn(createMockManagedBuffer()).when(blockManager) .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir")) @@ -1277,7 +1279,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager, hostLocalDirs) ++ hostLocalBlocks doThrow(new RuntimeException("Forced error")).when(blockManager) - .getMergedBlockData(ShuffleBlockId(0, -1, 2), Array("local-dir")) + .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir")) // host Local read for this original shuffle block doReturn(createMockManagedBuffer()).when(blockManager) .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir")) @@ -1311,7 +1313,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // This will throw an IOException when input stream is created from the ManagedBuffer doReturn(Seq({ new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100) - })).when(blockManager).getMergedBlockData(ShuffleBlockId(0, -1, 2), localDirs) + })).when(blockManager).getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, blockManager = Some(blockManager)) verifyLocalBlocksFromFallback(iterator) @@ -1324,7 +1326,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) val corruptBuffer = createMockManagedBuffer(2) doReturn(Seq({corruptBuffer})).when(blockManager) - .getMergedBlockData(ShuffleBlockId(0, -1, 2), localDirs) + .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) val corruptStream = mock(classOf[InputStream]) when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt")) doReturn(corruptStream).when(corruptBuffer).createInputStream() @@ -1376,7 +1378,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .thenReturn(fallbackBlocksByAddr.iterator) val iterator = createShuffleBlockIteratorWithDefaults(Map( BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) -> - toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 12L, -1))) + toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID))) val (id1, _) = iterator.next() blocksSem.acquire(1) assert(id1 === ShuffleBlockChunkId(0, 2, 0)) @@ -1413,7 +1415,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteMergedBlockMgrId = BlockManagerId( BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1) val iterator = createShuffleBlockIteratorWithDefaults( - Map(remoteMergedBlockMgrId -> toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 2L, -1))) + Map(remoteMergedBlockMgrId -> toBlockList( + Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID))) val (id1, _) = iterator.next() blocksSem.acquire(2) assert(id1 === 
ShuffleBlockId(0, 0, 2)) @@ -1465,7 +1468,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val iterator = createShuffleBlockIteratorWithDefaults(Map( BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> - toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 16L, -1)), + toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)), maxBytesInFlight = 4) metaSem.acquire(1) val (id1, _) = iterator.next() @@ -1531,13 +1534,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any())) .thenReturn(fallbackBlocksByAddr.iterator) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2), - Seq((ShuffleBlockId(0, -1, 2) - , 24L, -1)))) val iterator = createShuffleBlockIteratorWithDefaults(Map( BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> - toBlockList(Seq(ShuffleBlockId(0, -1, 2)), 24L, -1)), + toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)), maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1) metaSem.acquire(1) val (id1, _) = iterator.next() From 2d9a98a6fb82dc169b963e229c5394c77f55bc4a Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Wed, 26 May 2021 18:12:06 -0700 Subject: [PATCH 07/27] Fixing my review comments --- .../storage/ShuffleBlockFetcherIterator.scala | 20 +++++++------ .../ShuffleBlockFetcherIteratorSuite.scala | 29 ++++++++++--------- 2 files changed, 26 insertions(+), 23 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 69f392456de5..700d943dc178 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -38,6 +38,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER import org.apache.spark.storage.ShuffleBlockFetcherIterator.{FetchBlockInfo, FetchRequest, IgnoreFetchResult, MergedBlocksMetaFailedFetchResult, MergedBlocksMetaFetchResult, SuccessFetchResult} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} @@ -923,7 +924,7 @@ final class ShuffleBlockFetcherIterator( // Set result to null to trigger another iteration of the while loop to get either // a SuccessFetchResult or a FailureFetchResult. result = null - + case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, address, _) => // The original meta request is processed so we decrease numBlocksToFetch by 1. 
We will @@ -1454,7 +1455,7 @@ private class PushBasedFetchHelper( private[this] val startTimeNs = System.nanoTime() private[this] val localShuffleMergerBlockMgrId = BlockManagerId( - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, + SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) /** A map for storing merged block shuffle chunk bitmap */ @@ -1464,7 +1465,7 @@ private class PushBasedFetchHelper( * Returns true if the address is for a push-merged block. */ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) + SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) } /** @@ -1550,13 +1551,14 @@ private class PushBasedFetchHelper( } /** - * Fetch the merged local blocks dirs/blocks.. + * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local + * blocks. */ private def fetchMergedLocalBlocks( hostLocalDirManager: HostLocalDirManager, mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER) + SHUFFLE_MERGER_IDENTIFIER) if (cachedMergerDirs.isDefined) { logDebug(s"Fetching local merged blocks with cached executors dir: " + s"${cachedMergerDirs.get.mkString(", ")}") @@ -1565,13 +1567,13 @@ private class PushBasedFetchHelper( } else { logDebug(s"Asynchronous fetching local merged blocks without cached executors dir") hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, - localShuffleMergerBlockMgrId.port, Array(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER)) { + localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) { case Success(dirs) => mergedLocalBlocks.takeWhile { blockId => logDebug(s"Successfully fetched local dirs: " + - s"${dirs.get(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") - fetchMergedLocalBlock(blockId, dirs(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER), + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER), localShuffleMergerBlockMgrId) } logDebug(s"Got local merged blocks (without cached executors' dir) in " + @@ -1685,4 +1687,4 @@ private class PushBasedFetchHelper( iterator.fetchFallbackBlocks(fallbackBlocksByAddr) blocksProcessed } -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 004ee51a0e41..6c387152803d 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -42,6 +42,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} +import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER import org.apache.spark.storage.ShuffleBlockFetcherIterator._ import org.apache.spark.util.Utils @@ -1057,7 +1058,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 
test("fetch merged blocks meta") { val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( - (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), (BlockManagerId("remote-client-1", "remote-host-1", 1), toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)) @@ -1110,7 +1111,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("failed to fetch merged blocks meta") { val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1) val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( - (BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))) @@ -1187,7 +1188,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT doReturn(blockChunks(ShuffleBlockId(0, 3, 2))).when(blockManager) .getLocalBlockData(ShuffleBlockId(0, 3, 2)) - val dirsForMergedData = localDirsMap(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER) + val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER) doReturn(Seq(createMockManagedBuffer(2))).when(blockManager) .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData) @@ -1208,7 +1209,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(mapOutputTracker.getMapSizesForMergeResult(0, 2, bitmaps(0))) .thenReturn(Seq((localBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) - val mergedBmId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1) + val mergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1) Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)), (mergedBmId, toBlockList( @@ -1231,7 +1232,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blockManager = mock(classOf[BlockManager]) val localDirs = Array("testPath1", "testPath2") val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( - blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) doThrow(new RuntimeException("Forced error")).when(blockManager) .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, @@ -1245,7 +1246,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // BlockManagerId from another executor on the same host val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1) val hostLocalDirs = Map("test-client-1" -> Array("local-dir"), - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir")) + SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir")) val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(blockManager, hostLocalDirs) doThrow(new RuntimeException("Forced error")).when(blockManager) @@ -1269,7 +1270,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val hostLocalBmId = BlockManagerId("test-client-1", 
"test-local-host", 1) val fallbackHostLocalBmId = BlockManagerId("test-client-2", "test-local-host", 1) val hostLocalDirs = Map(hostLocalBmId.executorId -> Array("local-dir"), - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"), + SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"), fallbackHostLocalBmId.executorId -> Array("local-dir")) val hostLocalBlocks = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( @@ -1309,7 +1310,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blockManager = mock(classOf[BlockManager]) val localDirs = Array("local-dir") val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( - blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) // This will throw an IOException when input stream is created from the ManagedBuffer doReturn(Seq({ new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100) @@ -1323,7 +1324,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blockManager = mock(classOf[BlockManager]) val localDirs = Array("local-dir") val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( - blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) val corruptBuffer = createMockManagedBuffer(2) doReturn(Seq({corruptBuffer})).when(blockManager) .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) @@ -1340,7 +1341,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blockManager = mock(classOf[BlockManager]) val localDirs = Array("local-dir") val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( - blockManager, Map(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -> localDirs), + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs), failReadingLocalChunksMeta = true) val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) @@ -1377,7 +1378,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any())) .thenReturn(fallbackBlocksByAddr.iterator) val iterator = createShuffleBlockIteratorWithDefaults(Map( - BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) -> + BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) -> toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID))) val (id1, _) = iterator.next() blocksSem.acquire(1) @@ -1413,7 +1414,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } }) val remoteMergedBlockMgrId = BlockManagerId( - BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1) + SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1) val iterator = createShuffleBlockIteratorWithDefaults( Map(remoteMergedBlockMgrId -> toBlockList( Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID))) @@ -1467,7 +1468,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .thenReturn(fallbackBlocksByAddr.iterator) val iterator = createShuffleBlockIteratorWithDefaults(Map( - BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> + BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)), 
maxBytesInFlight = 4) metaSem.acquire(1) @@ -1535,7 +1536,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .thenReturn(fallbackBlocksByAddr.iterator) val iterator = createShuffleBlockIteratorWithDefaults(Map( - BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> + BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) -> toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)), maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1) metaSem.acquire(1) From 6fb6f16e48c6cf31ce1fa2e24ea72491d3802d22 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Wed, 26 May 2021 23:24:13 -0700 Subject: [PATCH 08/27] More styling and nit fixes --- .../storage/ShuffleBlockFetcherIterator.scala | 24 +++++++++---------- .../ShuffleBlockFetcherIteratorSuite.scala | 8 +++---- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 700d943dc178..313dbb4a70e5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -934,7 +934,7 @@ final class ShuffleBlockFetcherIterator( val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( shuffleId, reduceId, blockSize, numChunks, bitmaps) val additionalRemoteReqs = new ArrayBuffer[FetchRequest] - collectFetchRequests(address, blocksToRequest, additionalRemoteReqs) + collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs) fetchRequests ++= additionalRemoteReqs // Set result to null to force another iteration. result = null @@ -1492,11 +1492,11 @@ private class PushBasedFetchHelper( } def createChunkBlockInfosFromMetaResponse( - shuffleId: Int, - reduceId: Int, - blockSize: Long, - numChunks: Int, - bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + shuffleId: Int, + reduceId: Int, + blockSize: Long, + numChunks: Int, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { val approxChunkSize = blockSize / numChunks val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] = new ArrayBuffer[(BlockId, Long, Int)]() @@ -1544,7 +1544,7 @@ private class PushBasedFetchHelper( // Fetch all outstanding merged local blocks def fetchAllMergedLocalBlocks( - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { if (mergedLocalBlocks.nonEmpty) { blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) } @@ -1555,8 +1555,8 @@ private class PushBasedFetchHelper( * blocks. 
*/ private def fetchMergedLocalBlocks( - hostLocalDirManager: HostLocalDirManager, - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + hostLocalDirManager: HostLocalDirManager, + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( SHUFFLE_MERGER_IDENTIFIER) if (cachedMergerDirs.isDefined) { @@ -1601,9 +1601,9 @@ private class PushBasedFetchHelper( * @return Boolean represents successful or failed fetch */ private[this] def fetchMergedLocalBlock( - blockId: BlockId, - localDirs: Array[String], - blockManagerId: BlockManagerId): Boolean = { + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { try { val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 6c387152803d..f4777198072e 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1426,8 +1426,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("failure to fetch a remote merged block chunk initiates the fallback of" + - " deferred shuffle chunks immediately") { + test("failure to fetch a remote merged block chunk initiates the fallback of " + + "deferred shuffle chunks immediately") { val blockChunks = Map[BlockId, ManagedBuffer]( ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), // ShuffleBlockChunk(0, 2, 1) will cause a failure as it is not in block-chunks. 
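The two tests named here exercise the deferred-request path: requests parked because the per-address in-flight cap was reached must all fall back as soon as one chunk from that address fails, rather than failing one at a time. A minimal model of parking and draining per-address requests (names hypothetical):

import scala.collection.mutable

object DeferredRequestsSketch {
  case class Request(address: String, id: Int)

  // Requests that cannot be sent yet (per-address in-flight cap reached) are parked
  // in a per-address queue, mirroring deferredFetchRequests in the iterator.
  val deferred = mutable.HashMap.empty[String, mutable.Queue[Request]]

  def defer(r: Request): Unit =
    deferred.getOrElseUpdate(r.address, mutable.Queue.empty[Request]) += r

  // On a fetch failure for an address, drain its deferred queue so every parked
  // request can fall back immediately.
  def drainOnFailure(address: String): Seq[Request] =
    deferred.remove(address).map(_.toSeq).getOrElse(Seq.empty)

  def main(args: Array[String]): Unit = {
    defer(Request("host-a:7337", 1))
    defer(Request("host-a:7337", 2))
    defer(Request("host-b:7337", 3))
    println(drainOnFailure("host-a:7337")) // both host-a requests fall back together
    println(deferred.keys)                 // the host-b request is still parked
  }
}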
@@ -1493,8 +1493,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2))) } - test("failure to fetch a remote merged block chunk initiates the fallback of" + - " deferred shuffle chunks immediately which got deferred") { + test("failure to fetch a remote merged block chunk initiates the fallback of " + + "deferred shuffle chunks immediately which got deferred") { val blockChunks = Map[BlockId, ManagedBuffer]( ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(), From a1a0674426ce40f0c4e2b7078e3086ed173b8b5e Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Fri, 4 Jun 2021 15:56:24 -0700 Subject: [PATCH 09/27] Addressed more of Mridul's comments --- .../org/apache/spark/storage/BlockIdSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index fc48dd85b22c..e8c3c2df261c 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -211,6 +211,18 @@ class BlockIdSuite extends SparkFunSuite { assertSame(id, BlockId(id.toString)) } + test("merged shuffle id") { + val id = ShuffleBlockId(1, -1, 0) + assertSame(id, ShuffleBlockId(1, -1, 0)) + assertDifferent(id, ShuffleBlockId(1, 1, 1)) + assert(id.name === "shuffle_1_-1_0") + assert(id.asRDDId === None) + assert(id.shuffleId === 1) + assert(id.mapId === -1) + assert(id.reduceId === 0) + assertSame(id, BlockId(id.toString)) + } + test("shuffle chunk") { val id = ShuffleBlockChunkId(1, 1, 0) assertSame(id, ShuffleBlockChunkId(1, 1, 0)) From fcde43dd2a05ca20d95846b30b58a10c5a6dd89f Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Tue, 8 Jun 2021 00:18:44 -0700 Subject: [PATCH 10/27] Removed all the changes which are part of SPARK-35671 --- .../network/client/BaseResponseCallback.java | 31 ----- .../network/client/RpcResponseCallback.java | 5 +- .../spark/network/client/TransportClient.java | 29 +---- .../client/TransportResponseHandler.java | 27 +---- .../spark/network/crypto/AuthRpcHandler.java | 5 - .../protocol/MergedBlockMetaSuccess.java | 92 -------------- .../network/protocol/MessageDecoder.java | 6 - .../server/AbstractAuthRpcHandler.java | 5 - .../server/TransportRequestHandler.java | 26 ---- .../network/TransportRequestHandlerSuite.java | 55 --------- .../TransportResponseHandlerSuite.java | 39 ------ .../protocol/MergedBlockMetaSuccessSuite.java | 101 ---------------- .../shuffle/ExternalShuffleBlockResolver.java | 4 +- .../protocol/AbstractFetchShuffleBlocks.java | 88 -------------- .../protocol/BlockTransferMessage.java | 4 +- .../shuffle/protocol/FetchShuffleBlocks.java | 45 ++++--- .../shuffle/ExternalBlockHandlerSuite.java | 112 ------------------ .../FetchShuffleBlockChunksSuite.java | 42 ------- .../protocol/FetchShuffleBlocksSuite.java | 42 ------- .../yarn/YarnShuffleServiceMetricsSuite.scala | 3 +- 20 files changed, 41 insertions(+), 720 deletions(-) delete mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java delete mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java delete mode 100644 common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java delete mode 100644 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java delete mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java delete mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java deleted file mode 100644 index d9b7fb2b3bb8..000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.client; - -/** - * A basic callback. This is extended by {@link RpcResponseCallback} and - * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests - * can be handled in {@link TransportResponseHandler} a similar way. - * - * @since 3.2.0 - */ -public interface BaseResponseCallback { - - /** Exception either propagated from server or raised on client side. */ - void onFailure(Throwable e); -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java index a3b8cb1d90a2..6afc63f71bb3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java @@ -23,7 +23,7 @@ * Callback for the result of a single RPC. This will be invoked once with either success or * failure. */ -public interface RpcResponseCallback extends BaseResponseCallback { +public interface RpcResponseCallback { /** * Successful serialized result from server. * @@ -31,4 +31,7 @@ public interface RpcResponseCallback extends BaseResponseCallback { * Please copy the content of `response` if you want to use it after `onSuccess` returns. */ void onSuccess(ByteBuffer response); + + /** Exception either propagated from server or raised on client side. 
*/ + void onFailure(Throwable e); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index a50c04cf802a..eb2882074d7c 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -200,31 +200,6 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) { return requestId; } - /** - * Sends a MergedBlockMetaRequest message to the server. The response of this message is - * either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}. - * - * @param appId applicationId. - * @param shuffleId shuffle id. - * @param reduceId reduce id. - * @param callback callback the handle the reply. - */ - public void sendMergedBlockMetaReq( - String appId, - int shuffleId, - int reduceId, - MergedBlockMetaResponseCallback callback) { - long requestId = requestId(); - if (logger.isTraceEnabled()) { - logger.trace( - "Sending RPC {} to fetch merged block meta to {}", requestId, getRemoteAddress(channel)); - } - handler.addRpcRequest(requestId, callback); - RpcChannelListener listener = new RpcChannelListener(requestId, callback); - channel.writeAndFlush( - new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener); - } - /** * Send data to the remote end as a stream. This differs from stream() in that this is a request * to *send* data to the remote end, not to receive it from the remote. @@ -374,9 +349,9 @@ void handleFailure(String errorMsg, Throwable cause) throws Exception {} private class RpcChannelListener extends StdChannelListener { final long rpcRequestId; - final BaseResponseCallback callback; + final RpcResponseCallback callback; - RpcChannelListener(long rpcRequestId, BaseResponseCallback callback) { + RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) { super("RPC " + rpcRequestId); this.rpcRequestId = rpcRequestId; this.callback = callback; diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index 576c08858d6c..3aac2d2441d2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -33,7 +33,6 @@ import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.MergedBlockMetaSuccess; import org.apache.spark.network.protocol.ResponseMessage; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; @@ -57,7 +56,7 @@ public class TransportResponseHandler extends MessageHandler { private final Map outstandingFetches; - private final Map outstandingRpcs; + private final Map outstandingRpcs; private final Queue> streamCallbacks; private volatile boolean streamActive; @@ -82,7 +81,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) { outstandingFetches.remove(streamChunkId); } - public void addRpcRequest(long requestId, BaseResponseCallback callback) { + public void addRpcRequest(long requestId, RpcResponseCallback callback) { updateTimeOfLastRequest(); outstandingRpcs.put(requestId, callback); 
} @@ -113,7 +112,7 @@ private void failOutstandingRequests(Throwable cause) { logger.warn("ChunkReceivedCallback.onFailure throws exception", e); } } - for (Map.Entry entry : outstandingRpcs.entrySet()) { + for (Map.Entry entry : outstandingRpcs.entrySet()) { try { entry.getValue().onFailure(cause); } catch (Exception e) { @@ -185,7 +184,7 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof RpcResponse) { RpcResponse resp = (RpcResponse) message; - RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId); + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size()); @@ -200,7 +199,7 @@ public void handle(ResponseMessage message) throws Exception { } } else if (message instanceof RpcFailure) { RpcFailure resp = (RpcFailure) message; - BaseResponseCallback listener = outstandingRpcs.get(resp.requestId); + RpcResponseCallback listener = outstandingRpcs.get(resp.requestId); if (listener == null) { logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding", resp.requestId, getRemoteAddress(channel), resp.errorString); @@ -208,22 +207,6 @@ public void handle(ResponseMessage message) throws Exception { outstandingRpcs.remove(resp.requestId); listener.onFailure(new RuntimeException(resp.errorString)); } - } else if (message instanceof MergedBlockMetaSuccess) { - MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message; - try { - MergedBlockMetaResponseCallback listener = - (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId); - if (listener == null) { - logger.warn( - "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not" - + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size()); - } else { - outstandingRpcs.remove(resp.requestId); - listener.onSuccess(resp.getNumChunks(), resp.body()); - } - } finally { - resp.body().release(); - } } else if (message instanceof StreamResponse) { StreamResponse resp = (StreamResponse) message; Pair entry = streamCallbacks.poll(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java index 8f0a40c38021..dd31c955350f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java @@ -138,9 +138,4 @@ protected boolean doAuthChallenge( LOG.debug("Authorization successful for client {}.", channel.remoteAddress()); return true; } - - @Override - public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { - return saslHandler.getMergedBlockMetaReqHandler(); - } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java deleted file mode 100644 index d2edaf4532e1..000000000000 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; -import org.apache.commons.lang3.builder.ToStringBuilder; -import org.apache.commons.lang3.builder.ToStringStyle; - -import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.buffer.NettyManagedBuffer; - -/** - * Response to {@link MergedBlockMetaRequest} request. - * Note that the server-side encoding of this messages does NOT include the buffer itself. - * - * @since 3.2.0 - */ -public class MergedBlockMetaSuccess extends AbstractResponseMessage { - public final long requestId; - public final int numChunks; - - public MergedBlockMetaSuccess( - long requestId, - int numChunks, - ManagedBuffer chunkBitmapsBuffer) { - super(chunkBitmapsBuffer, true); - this.requestId = requestId; - this.numChunks = numChunks; - } - - @Override - public Type type() { - return Type.MergedBlockMetaSuccess; - } - - @Override - public int hashCode() { - return Objects.hashCode(requestId, numChunks); - } - - @Override - public String toString() { - return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) - .append("requestId", requestId).append("numChunks", numChunks).toString(); - } - - @Override - public int encodedLength() { - return 8 + 4; - } - - /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */ - @Override - public void encode(ByteBuf buf) { - buf.writeLong(requestId); - buf.writeInt(numChunks); - } - - public int getNumChunks() { - return numChunks; - } - - /** Decoding uses the given ByteBuf as our data, and will retain() it. 
*/ - public static MergedBlockMetaSuccess decode(ByteBuf buf) { - long requestId = buf.readLong(); - int numChunks = buf.readInt(); - buf.retain(); - NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate()); - return new MergedBlockMetaSuccess(requestId, numChunks, managedBuf); - } - - @Override - public ResponseMessage createFailureResponse(String error) { - return new RpcFailure(requestId, error); - } -} diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java index 98f7f612a486..bf80aed0afe1 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java @@ -83,12 +83,6 @@ private Message decode(Message.Type msgType, ByteBuf in) { case UploadStream: return UploadStream.decode(in); - case MergedBlockMetaRequest: - return MergedBlockMetaRequest.decode(in); - - case MergedBlockMetaSuccess: - return MergedBlockMetaSuccess.decode(in); - default: throw new IllegalArgumentException("Unexpected message type: " + msgType); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java index 95fde677624f..92eb88628344 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java @@ -104,9 +104,4 @@ public void exceptionCaught(Throwable cause, TransportClient client) { public boolean isAuthenticated() { return isAuthenticated; } - - @Override - public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { - return delegate.getMergedBlockMetaReqHandler(); - } } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index ab2deac20fcd..4a30f8de0782 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -113,8 +113,6 @@ public void handle(RequestMessage request) throws Exception { processStreamRequest((StreamRequest) request); } else if (request instanceof UploadStream) { processStreamUpload((UploadStream) request); - } else if (request instanceof MergedBlockMetaRequest) { - processMergedBlockMetaRequest((MergedBlockMetaRequest) request); } else { throw new IllegalArgumentException("Unknown request type: " + request); } @@ -262,30 +260,6 @@ private void processOneWayMessage(OneWayMessage req) { } } - private void processMergedBlockMetaRequest(final MergedBlockMetaRequest req) { - try { - rpcHandler.getMergedBlockMetaReqHandler().receiveMergeBlockMetaReq(reverseClient, req, - new MergedBlockMetaResponseCallback() { - - @Override - public void onSuccess(int numChunks, ManagedBuffer buffer) { - logger.trace("Sending meta for request {} numChunks {}", req, numChunks); - respond(new MergedBlockMetaSuccess(req.requestId, numChunks, buffer)); - } - - @Override - public void onFailure(Throwable e) { - logger.trace("Failed to send meta for {}", req); - respond(new RpcFailure(req.requestId, 
Throwables.getStackTraceAsString(e))); - } - }); - } catch (Exception e) { - logger.error("Error while invoking receiveMergeBlockMetaReq() for appId {} shuffleId {} " - + "reduceId {}", req.appId, req.shuffleId, req.appId, e); - respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e))); - } - } - /** * Responds to a single message with some Encodable object. If a failure occurs while sending, * it will be logged and the channel closed. diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index b3befb8baf2d..0a6447176237 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -17,7 +17,6 @@ package org.apache.spark.network; -import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; @@ -25,19 +24,16 @@ import org.junit.Assert; import org.junit.Test; -import static org.junit.Assert.*; import static org.mockito.Mockito.*; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.apache.spark.network.buffer.ManagedBuffer; -import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.*; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; -import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportRequestHandler; public class TransportRequestHandlerSuite { @@ -113,55 +109,4 @@ public void handleStreamRequest() throws Exception { streamManager.connectionTerminated(channel); Assert.assertEquals(0, streamManager.numStreamStates()); } - - @Test - public void handleMergedBlockMetaRequest() throws Exception { - RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, request, callback) -> { - if (request.shuffleId != -1 && request.reduceId != -1) { - callback.onSuccess(2, mock(ManagedBuffer.class)); - } else { - callback.onFailure(new RuntimeException("empty block")); - } - }; - RpcHandler rpcHandler = new RpcHandler() { - @Override - public void receive( - TransportClient client, - ByteBuffer message, - RpcResponseCallback callback) {} - - @Override - public StreamManager getStreamManager() { - return null; - } - - @Override - public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { - return metaHandler; - } - }; - Channel channel = mock(Channel.class); - List> responseAndPromisePairs = new ArrayList<>(); - when(channel.writeAndFlush(any())).thenAnswer(invocationOnMock0 -> { - Object response = invocationOnMock0.getArguments()[0]; - ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); - responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); - return channelFuture; - }); - - TransportClient reverseClient = mock(TransportClient.class); - TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, - rpcHandler, 2L, null); - MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0); - requestHandler.handle(validMetaReq); - assertEquals(1, responseAndPromisePairs.size()); - assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof 
MergedBlockMetaSuccess); - assertEquals(2, - ((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks()); - - MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1); - requestHandler.handle(invalidMetaReq); - assertEquals(2, responseAndPromisePairs.size()); - assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure); - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index 4de13f951d49..b4032c4c3f03 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -23,20 +23,17 @@ import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; -import org.mockito.ArgumentCaptor; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; -import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; -import org.apache.spark.network.protocol.MergedBlockMetaSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; @@ -170,40 +167,4 @@ public void failOutstandingStreamCallbackOnException() throws Exception { verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); } - - @Test - public void handleSuccessfulMergedBlockMeta() throws Exception { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); - handler.addRpcRequest(13, callback); - assertEquals(1, handler.numOutstandingRequests()); - - // This response should be ignored. - handler.handle(new MergedBlockMetaSuccess(22, 2, - new NioManagedBuffer(ByteBuffer.allocate(7)))); - assertEquals(1, handler.numOutstandingRequests()); - - ByteBuffer resp = ByteBuffer.allocate(10); - handler.handle(new MergedBlockMetaSuccess(13, 2, new NioManagedBuffer(resp))); - ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(NioManagedBuffer.class); - verify(callback, times(1)).onSuccess(eq(2), bufferCaptor.capture()); - assertEquals(resp, bufferCaptor.getValue().nioByteBuffer()); - assertEquals(0, handler.numOutstandingRequests()); - } - - @Test - public void handleFailedMergedBlockMeta() throws Exception { - TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); - MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); - handler.addRpcRequest(51, callback); - assertEquals(1, handler.numOutstandingRequests()); - - // This response should be ignored. 
- handler.handle(new RpcFailure(6, "failed")); - assertEquals(1, handler.numOutstandingRequests()); - - handler.handle(new RpcFailure(51, "failed")); - verify(callback, times(1)).onFailure(any()); - assertEquals(0, handler.numOutstandingRequests()); - } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java deleted file mode 100644 index f4a055188c86..000000000000 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.protocol; - -import java.io.DataOutputStream; -import java.io.File; -import java.io.FileOutputStream; -import java.nio.file.Files; -import java.util.List; - -import com.google.common.collect.Lists; -import io.netty.buffer.ByteBuf; -import io.netty.buffer.ByteBufAllocator; -import io.netty.buffer.Unpooled; -import io.netty.channel.ChannelHandlerContext; -import org.junit.Assert; -import org.junit.Test; -import org.roaringbitmap.RoaringBitmap; - -import static org.mockito.Mockito.*; - -import org.apache.spark.network.buffer.FileSegmentManagedBuffer; -import org.apache.spark.network.util.ByteArrayWritableChannel; -import org.apache.spark.network.util.TransportConf; - -/** - * Test for {@link MergedBlockMetaSuccess}. 
- */ -public class MergedBlockMetaSuccessSuite { - - @Test - public void testMergedBlocksMetaEncodeDecode() throws Exception { - File chunkMetaFile = new File("target/mergedBlockMetaTest"); - Files.deleteIfExists(chunkMetaFile.toPath()); - RoaringBitmap chunk1 = new RoaringBitmap(); - chunk1.add(1); - chunk1.add(3); - RoaringBitmap chunk2 = new RoaringBitmap(); - chunk2.add(2); - chunk2.add(4); - RoaringBitmap[] expectedChunks = new RoaringBitmap[]{chunk1, chunk2}; - try (DataOutputStream metaOutput = new DataOutputStream(new FileOutputStream(chunkMetaFile))) { - for (int i = 0; i < expectedChunks.length; i++) { - expectedChunks[i].serialize(metaOutput); - } - } - TransportConf conf = mock(TransportConf.class); - when(conf.lazyFileDescriptor()).thenReturn(false); - long requestId = 1L; - MergedBlockMetaSuccess expectedMeta = new MergedBlockMetaSuccess(requestId, 2, - new FileSegmentManagedBuffer(conf, chunkMetaFile, 0, chunkMetaFile.length())); - - List out = Lists.newArrayList(); - ChannelHandlerContext context = mock(ChannelHandlerContext.class); - when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT); - - MessageEncoder.INSTANCE.encode(context, expectedMeta, out); - Assert.assertEquals(1, out.size()); - MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0); - - ByteArrayWritableChannel writableChannel = - new ByteArrayWritableChannel((int) msgWithHeader.count()); - while (msgWithHeader.transfered() < msgWithHeader.count()) { - msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered()); - } - ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData()); - messageBuf.readLong(); // frame length - MessageDecoder.INSTANCE.decode(mock(ChannelHandlerContext.class), messageBuf, out); - Assert.assertEquals(1, out.size()); - MergedBlockMetaSuccess decoded = (MergedBlockMetaSuccess) out.get(0); - Assert.assertEquals("merged block", expectedMeta.requestId, decoded.requestId); - Assert.assertEquals("num chunks", expectedMeta.getNumChunks(), decoded.getNumChunks()); - - ByteBuf responseBuf = Unpooled.wrappedBuffer(decoded.body().nioByteBuffer()); - RoaringBitmap[] responseBitmaps = new RoaringBitmap[expectedMeta.getNumChunks()]; - for (int i = 0; i < expectedMeta.getNumChunks(); i++) { - responseBitmaps[i] = Encoders.Bitmaps.decode(responseBuf); - } - Assert.assertEquals( - "num of roaring bitmaps", expectedMeta.getNumChunks(), responseBitmaps.length); - for (int i = 0; i < expectedMeta.getNumChunks(); i++) { - Assert.assertEquals("chunk bitmap " + i, expectedChunks[i], responseBitmaps[i]); - } - Files.delete(chunkMetaFile.toPath()); - } -} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index 493edd2b3462..a095bf272341 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -361,8 +361,8 @@ public int removeBlocks(String appId, String execId, String[] blockIds) { return numRemovedBlocks; } - public Map getLocalDirs(String appId, Set execIds) { - return execIds.stream() + public Map getLocalDirs(String appId, String[] execIds) { + return Arrays.stream(execIds) .map(exec -> { ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec)); if (info == null) { diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java deleted file mode 100644 index 0fca27cf26df..000000000000 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.network.shuffle.protocol; - -import com.google.common.base.Objects; -import io.netty.buffer.ByteBuf; - -import org.apache.commons.lang3.builder.ToStringBuilder; -import org.apache.commons.lang3.builder.ToStringStyle; -import org.apache.spark.network.protocol.Encoders; - -/** - * Base class for fetch shuffle blocks and chunks. - * - * @since 3.2.0 - */ -public abstract class AbstractFetchShuffleBlocks extends BlockTransferMessage { - public final String appId; - public final String execId; - public final int shuffleId; - - protected AbstractFetchShuffleBlocks( - String appId, - String execId, - int shuffleId) { - this.appId = appId; - this.execId = execId; - this.shuffleId = shuffleId; - } - - public ToStringBuilder toStringHelper() { - return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) - .append("appId", appId) - .append("execId", execId) - .append("shuffleId", shuffleId); - } - - /** - * Returns number of blocks in the request. 
- */ - public abstract int getNumBlocks(); - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - AbstractFetchShuffleBlocks that = (AbstractFetchShuffleBlocks) o; - return shuffleId == that.shuffleId - && Objects.equal(appId, that.appId) && Objects.equal(execId, that.execId); - } - - @Override - public int hashCode() { - int result = appId.hashCode(); - result = 31 * result + execId.hashCode(); - result = 31 * result + shuffleId; - return result; - } - - @Override - public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + 4; /* encoded length of shuffleId */ - } - - @Override - public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - buf.writeInt(shuffleId); - } -} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index a55a6cf7ed93..7f5058124988 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -48,8 +48,7 @@ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11), - PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14), - FETCH_SHUFFLE_BLOCK_CHUNKS(15); + PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14); private final byte id; @@ -83,7 +82,6 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 12: return PushBlockStream.decode(buf); case 13: return FinalizeShuffleMerge.decode(buf); case 14: return MergeStatuses.decode(buf); - case 15: return FetchShuffleBlockChunks.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 68550a2fba86..98057d58f7ab 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -20,6 +20,8 @@ import java.util.Arrays; import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; import org.apache.spark.network.protocol.Encoders; @@ -27,7 +29,10 @@ import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ -public class FetchShuffleBlocks extends AbstractFetchShuffleBlocks { +public class FetchShuffleBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. 
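The comment above pins down the wire layout: mapIds and reduceIds are parallel arrays, so reduceIds[i] lists every reduce id requested from mapIds[i]. That same invariant drives the block count that the getNumBlocks method removed just below used to compute. A self-contained Scala sketch of that arithmetic, with made-up ids:

```scala
object FetchShuffleBlocksCount extends App {
  // Parallel arrays, as described in the comment above: reduceIds(i) lists the
  // reduce ids requested for mapIds(i). Values here are made up for illustration.
  val mapIds: Array[Long] = Array(3L, 7L)
  val reduceIds: Array[Array[Int]] = Array(Array(0, 1), Array(0, 1, 2))
  require(mapIds.length == reduceIds.length)

  // Without batch fetch, every (mapId, reduceId) pair is a separate block; with
  // batch fetch, each map id is served as one contiguous block.
  def numBlocks(batchFetchEnabled: Boolean): Int =
    if (batchFetchEnabled) mapIds.length else reduceIds.map(_.length).sum

  assert(numBlocks(batchFetchEnabled = false) == 5)
  assert(numBlocks(batchFetchEnabled = true) == 2)
}
```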
public final long[] mapIds; @@ -45,7 +50,9 @@ public FetchShuffleBlocks( long[] mapIds, int[][] reduceIds, boolean batchFetchEnabled) { - super(appId, execId, shuffleId); + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; this.mapIds = mapIds; this.reduceIds = reduceIds; assert(mapIds.length == reduceIds.length); @@ -62,7 +69,10 @@ public FetchShuffleBlocks( @Override public String toString() { - return toStringHelper() + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("shuffleId", shuffleId) .append("mapIds", Arrays.toString(mapIds)) .append("reduceIds", Arrays.deepToString(reduceIds)) .append("batchFetchEnabled", batchFetchEnabled) @@ -75,40 +85,35 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; FetchShuffleBlocks that = (FetchShuffleBlocks) o; - if (!super.equals(that)) return false; + + if (shuffleId != that.shuffleId) return false; if (batchFetchEnabled != that.batchFetchEnabled) return false; + if (!appId.equals(that.appId)) return false; + if (!execId.equals(that.execId)) return false; if (!Arrays.equals(mapIds, that.mapIds)) return false; return Arrays.deepEquals(reduceIds, that.reduceIds); } @Override public int hashCode() { - int result = super.hashCode(); + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + shuffleId; result = 31 * result + Arrays.hashCode(mapIds); result = 31 * result + Arrays.deepHashCode(reduceIds); result = 31 * result + (batchFetchEnabled ? 1 : 0); return result; } - @Override - public int getNumBlocks() { - if (batchFetchEnabled) { - return mapIds.length; - } - int numBlocks = 0; - for (int[] ids : reduceIds) { - numBlocks += ids.length; - } - return numBlocks; - } - @Override public int encodedLength() { int encodedLengthOfReduceIds = 0; for (int[] ids: reduceIds) { encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids); } - return super.encodedLength() + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 4 /* encoded length of shuffleId */ + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds @@ -117,7 +122,9 @@ public int encodedLength() { @Override public void encode(ByteBuf buf) { - super.encode(buf); + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); Encoders.LongArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index dc41e957f0fc..531657bde481 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -36,16 +36,13 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; -import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; -import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import 
org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; -import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; import org.apache.spark.network.shuffle.protocol.MergeStatuses; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -266,113 +263,4 @@ public void testFinalizeShuffleMerge() throws IOException { .get("finalizeShuffleMergeLatencyMillis"); assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount()); } - - @Test - public void testFetchMergedBlocksMeta() { - when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn( - new MergedBlockMeta(1, mock(ManagedBuffer.class))); - when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn( - new MergedBlockMeta(3, mock(ManagedBuffer.class))); - when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn( - new MergedBlockMeta(5, mock(ManagedBuffer.class))); - - int[] expectedCount = new int[]{1, 3, 5}; - String appId = "app0"; - long requestId = 0L; - for (int reduceId = 0; reduceId < 3; reduceId++) { - MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId); - MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); - handler.getMergedBlockMetaReqHandler() - .receiveMergeBlockMetaReq(client, req, callback); - verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId); - - ArgumentCaptor numChunksResponse = ArgumentCaptor.forClass(Integer.class); - ArgumentCaptor chunkBitmapResponse = - ArgumentCaptor.forClass(ManagedBuffer.class); - verify(callback, times(1)).onSuccess(numChunksResponse.capture(), - chunkBitmapResponse.capture()); - assertEquals("num chunks in merged block " + reduceId, expectedCount[reduceId], - numChunksResponse.getValue().intValue()); - assertNotNull("chunks bitmap buffer " + reduceId, chunkBitmapResponse.getValue()); - } - } - - @Test - public void testOpenBlocksWithShuffleChunks() { - verifyBlockChunkFetches(true); - } - - @Test - public void testFetchShuffleChunks() { - verifyBlockChunkFetches(false); - } - - private void verifyBlockChunkFetches(boolean useOpenBlocks) { - RpcResponseCallback callback = mock(RpcResponseCallback.class); - ByteBuffer buffer; - if (useOpenBlocks) { - OpenBlocks openBlocks = - new OpenBlocks("app0", "exec1", - new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0", - "shuffleChunk_0_1_1"}); - buffer = openBlocks.toByteBuffer(); - } else { - FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks( - "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}}); - buffer = fetchChunks.toByteBuffer(); - } - ManagedBuffer[][] buffers = new ManagedBuffer[][] { - { - new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), - new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) - }, - { - new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), - new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) - } - }; - for (int reduceId = 0; reduceId < 2; reduceId++) { - for (int chunkId = 0; chunkId < 2; chunkId++) { - when(mergedShuffleManager.getMergedBlockData( - "app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]); - } - } - handler.receive(client, buffer, callback); - ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); - 
verify(callback, times(1)).onSuccess(response.capture()); - verify(callback, never()).onFailure(any()); - StreamHandle handle = - (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); - assertEquals(4, handle.numChunks); - - @SuppressWarnings("unchecked") - ArgumentCaptor> stream = (ArgumentCaptor>) - (ArgumentCaptor) ArgumentCaptor.forClass(Iterator.class); - verify(streamManager, times(1)).registerStream(any(), stream.capture(), any()); - Iterator bufferIter = stream.getValue(); - for (int reduceId = 0; reduceId < 2; reduceId++) { - for (int chunkId = 0; chunkId < 2; chunkId++) { - assertEquals(buffers[reduceId][chunkId], bufferIter.next()); - } - } - assertFalse(bufferIter.hasNext()); - verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt()); - verify(blockResolver, never()).getBlockData( - anyString(), anyString(), anyInt(), anyInt(), anyInt()); - verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0); - verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1); - - // Verify open block request latency metrics - Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler) - .getAllMetrics() - .getMetrics() - .get("openBlockRequestLatencyMillis"); - assertEquals(1, openBlockRequestLatencyMillis.getCount()); - // Verify block transfer metrics - Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler) - .getAllMetrics() - .getMetrics() - .get("blockTransferRateBytes"); - assertEquals(24, blockTransferRateBytes.getCount()); - } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java deleted file mode 100644 index 91f319ded493..000000000000 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.shuffle.protocol; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import org.junit.Assert; -import org.junit.Test; - -import static org.junit.Assert.*; - -public class FetchShuffleBlockChunksSuite { - - @Test - public void testFetchShuffleBlockChunksEncodeDecode() { - FetchShuffleBlockChunks shuffleBlockChunks = - new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}}); - Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks()); - int len = shuffleBlockChunks.encodedLength(); - Assert.assertEquals(45, len); - ByteBuf buf = Unpooled.buffer(len); - shuffleBlockChunks.encode(buf); - - FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf); - assertEquals(shuffleBlockChunks, decoded); - } -} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java deleted file mode 100644 index a1681f58e7ea..000000000000 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.network.shuffle.protocol; - -import io.netty.buffer.ByteBuf; -import io.netty.buffer.Unpooled; -import org.junit.Assert; -import org.junit.Test; - -import static org.junit.Assert.*; - -public class FetchShuffleBlocksSuite { - - @Test - public void testFetchShuffleBlockEncodeDecode() { - FetchShuffleBlocks fetchShuffleBlocks = - new FetchShuffleBlocks("app0", "exec1", 0, new long[] {0}, new int[][] {{0, 1}}, false); - Assert.assertEquals(2, fetchShuffleBlocks.getNumBlocks()); - int len = fetchShuffleBlocks.encodedLength(); - Assert.assertEquals(50, len); - ByteBuf buf = Unpooled.buffer(len); - fetchShuffleBlocks.encode(buf); - - FetchShuffleBlocks decoded = FetchShuffleBlocks.decode(buf); - assertEquals(fetchShuffleBlocks, decoded); - } -} diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index eff2de714394..388a86037594 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -42,8 +42,7 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", "blockTransferRate", "blockTransferMessageRate", "blockTransferAvgSize_1min", "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", - "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis", - "fetchMergedBlocksMetaLatencyMillis") + "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis") // Use sorted Seq instead of Set for easier comparison when there is a mismatch metrics.getMetrics.keySet().asScala.toSeq.sorted should be (allMetrics.sorted) From cbc62fa19109c85f55c24b3a723d25cefb482e8c Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Wed, 9 Jun 2021 14:00:36 -0700 Subject: [PATCH 11/27] Addressed wuyi's comments --- .../spark/storage/PushBasedFetchHelper.scala | 289 +++++++++++++++++ .../storage/ShuffleBlockFetcherIterator.scala | 301 ++---------------- .../ShuffleBlockFetcherIteratorSuite.scala | 33 ++ 3 files changed, 347 insertions(+), 276 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala new file mode 100644 index 000000000000..9cc9fda31a75 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + + package org.apache.spark.storage + + import java.util.concurrent.TimeUnit + + import scala.collection.mutable + import scala.collection.mutable.ArrayBuffer + import scala.util.{Failure, Success} + + import org.roaringbitmap.RoaringBitmap + + import org.apache.spark.MapOutputTracker + import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID + import org.apache.spark.internal.Logging + import org.apache.spark.network.buffer.ManagedBuffer + import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} + import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER + import org.apache.spark.storage.ShuffleBlockFetcherIterator._ + + /** + * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based + * functionality to fetch merged block meta and merged shuffle block chunks. + */ + private class PushBasedFetchHelper( + private val iterator: ShuffleBlockFetcherIterator, + private val shuffleClient: BlockStoreClient, + private val blockManager: BlockManager, + private val mapOutputTracker: MapOutputTracker) extends Logging { + + private[this] val startTimeNs = System.nanoTime() + + private[this] val localShuffleMergerBlockMgrId = BlockManagerId( + SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, + blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) + + /** A map for storing merged block shuffle chunk bitmaps */ + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + + /** + * Returns true if the address is for a push-merged block. + */ + def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) + } + + /** + * Returns true if the address is neither that of an executor-local nor a merged-local block, false otherwise. + */ + def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = { + (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) || + (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId) + } + + /** + * Returns true if the address is that of a merged-local block, false otherwise.
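Concretely, a push-merged address is recognized purely by its executor id, which is set to the synthetic SHUFFLE_MERGER_IDENTIFIER; host equality is then all that separates merged-local from merged-remote. A toy Scala illustration of these predicates follows (the Addr type and the identifier's literal value are stand-in assumptions for the sketch, not the real BlockManagerId API):

```scala
object MergedAddressSketch extends App {
  // Assumed stand-ins: the real code compares BlockManagerId.executorId against
  // BlockManagerId.SHUFFLE_MERGER_IDENTIFIER; the literal below is illustrative.
  val ShuffleMergerIdentifier = "shuffle-push-merger"
  case class Addr(executorId: String, host: String)

  def isMergedShuffleBlockAddress(a: Addr): Boolean =
    a.executorId == ShuffleMergerIdentifier

  def isMergedLocal(a: Addr, localHost: String): Boolean =
    isMergedShuffleBlockAddress(a) && a.host == localHost

  val local = Addr(ShuffleMergerIdentifier, "host-a")
  assert(isMergedLocal(local, localHost = "host-a"))
  assert(!isMergedLocal(local.copy(host = "host-b"), localHost = "host-a"))
}
```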
+ */ + def isMergedLocal(address: BlockManagerId): Boolean = { + isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + } + + def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = { + chunksMetaMap(blockId).getCardinality + } + + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) + } + + def createChunkBlockInfosFromMetaResponse( + shuffleId: Int, + reduceId: Int, + blockSize: Long, + numChunks: Int, + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + val approxChunkSize = blockSize / numChunks + val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() + for (i <- 0 until numChunks) { + val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i) + chunksMetaMap.put(blockChunkId, bitmaps(i)) + logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") + blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) + } + blocksToFetch + } + + def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { + val sizeMap = req.blocks.map { + case FetchBlockInfo(blockId, size, _) => + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap + val address = req.address + val mergedBlocksMetaListener = new MergedBlocksMetaListener { + override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { + logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}") + try { + iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId, + sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address)) + } catch { + case exception: Throwable => + logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", exception) + iterator.addToResultsQueue( + MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) + } + } + + override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { + logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " + + s"from ${req.address.host}:${req.address.port}", exception) + iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) + } + } + req.blocks.foreach { block => + val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId] + shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId, + shuffleBlockId.reduceId, mergedBlocksMetaListener) + } + } + + // Fetch all outstanding merged local blocks + def fetchAllMergedLocalBlocks( + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (mergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) + } + } + + /** + * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local + * blocks. 
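createChunkBlockInfosFromMetaResponse above expands one merged block into numChunks fetchable entries. Only the total block size is known at that point, so every chunk is assigned the average size blockSize / numChunks. A stripped-down sketch of the expansion; the string block names and the -1 stand-in for SHUFFLE_PUSH_MAP_ID are assumptions for illustration:

```scala
import scala.collection.mutable.ArrayBuffer
import org.roaringbitmap.RoaringBitmap

object ChunkExpansionSketch extends App {
  val ShufflePushMapId = -1 // assumed stand-in for MapOutputTracker.SHUFFLE_PUSH_MAP_ID

  // Expand a merged block's meta into per-chunk entries: only the total size is
  // known, so each chunk is assigned the average size as an estimate.
  def expand(shuffleId: Int, reduceId: Int, blockSize: Long,
      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(String, Long, Int)] = {
    val numChunks = bitmaps.length
    val approxChunkSize = blockSize / numChunks
    val out = new ArrayBuffer[(String, Long, Int)]()
    for (i <- 0 until numChunks) {
      out += ((s"shuffleChunk_${shuffleId}_${reduceId}_$i", approxChunkSize, ShufflePushMapId))
    }
    out
  }

  // One ~3000-byte merged block whose meta reports two chunks.
  val bitmaps = Array(RoaringBitmap.bitmapOf(1, 3), RoaringBitmap.bitmapOf(2, 4))
  assert(expand(0, 5, 3000L, bitmaps).map(_._2).sum == 3000L)
}
```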
+ */ + private def fetchMergedLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( + SHUFFLE_MERGER_IDENTIFIER) + if (cachedMergerDirs.isDefined) { + logDebug(s"Fetching local merged blocks with cached executors dir: " + + s"${cachedMergerDirs.get.mkString(", ")}") + mergedLocalBlocks.foreach(blockId => + fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId)) + } else { + logDebug(s"Asynchronously fetching local merged blocks without cached executors dir") + hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, + localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) { + case Success(dirs) => + mergedLocalBlocks.takeWhile { + blockId => + logDebug(s"Successfully fetched local dirs: " + + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") + fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER), + localShuffleMergerBlockMgrId) + } + logDebug(s"Got local merged blocks (without cached executors' dir) in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + case Failure(throwable) => + // If we see an exception while getting the local dirs for local merged blocks, + // we fall back to fetching the original unmerged blocks. We do not report block fetch + // failure. + logWarning(s"Error occurred while getting the local dirs for local merged " + + s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetching the original blocks instead", + throwable) + mergedLocalBlocks.foreach( + blockId => iterator.addToResultsQueue( + IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) + ) + } + } + } + + /** + * Fetch a single local merged block. + * @param blockId ShuffleBlockId to be fetched + * @param localDirs Local directories where the merged shuffle files are stored + * @param blockManagerId BlockManagerId + * @return true if the fetch succeeded, false otherwise + */ + private[this] def fetchMergedLocalBlock( + blockId: BlockId, + localDirs: Array[String], + blockManagerId: BlockManagerId): Boolean = { + try { + val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] + val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs) + .readChunkBitmaps() + // Fetch local merged shuffle block data as multiple chunks + val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs) + // Update total number of blocks to fetch, reflecting the multiple local chunks + iterator.foundMoreBlocksToFetch(bufs.size - 1) + for (chunkId <- bufs.indices) { + val buf = bufs(chunkId) + buf.retain() + val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, + shuffleBlockId.reduceId, chunkId) + iterator.addToResultsQueue( + SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf, + isNetworkReqDone = false)) + chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) + } + true + } catch { + case e: Exception => + // If we see an exception while reading a local merged block, we fall back to + // fetching the original unmerged blocks. We do not report block fetch failure + // and will continue with the remaining local block read.
+        logWarning(s"Error occurred while fetching local merged block, " +
+          s"falling back to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+        false
+    }
+  }
+
+  /**
+   * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed
+   * to fetch.
+   * It calls out to map output tracker to get the list of original blocks for the
+   * given merged blocks, split them into remote and local blocks, and process them
+   * accordingly.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle block chunk from local merged shuffle block.
+   * See fetchMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle block chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk
+   * (local or remote).
+   *
+   * @return number of blocks processed
+   */
+  def initiateFallbackBlockFetchForMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Int = {
+    logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    var blocksProcessed = 1
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      if (blockId.isShuffle) {
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        mapOutputTracker.getMapSizesForMergeResult(
+          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+      } else {
+        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+        // fallback not only for that particular remote shuffle block chunk but also for all the
+        // pending block chunks that belong to the same host. The reason for doing so is that it is
+        // very likely that the subsequent requests for merged block chunks from this host will fail
+        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
+        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
+        // fetchRequests queue.
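+        // A minimal sketch of the bitmap union performed below, assuming two hypothetical
+        // pending chunk bitmaps (RoaringBitmap.or mutates the receiver in place):
+        //   import org.roaringbitmap.RoaringBitmap
+        //   val chunkBitmap = RoaringBitmap.bitmapOf(0, 1)            // maps in the failed chunk
+        //   val pending = Seq(RoaringBitmap.bitmapOf(2), RoaringBitmap.bitmapOf(3, 4))
+        //   pending.foreach(b => chunkBitmap.or(b))                   // chunkBitmap: {0, 1, 2, 3, 4}
+        //   // chunkBitmap now names every map output that must be re-fetched unmerged.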
+ if (isNotExecutorOrMergedLocal(address)) { + // Fallback for all the pending fetch requests + val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) + if (pendingShuffleChunks.nonEmpty) { + pendingShuffleChunks.foreach { pendingBlockId => + logWarning(s"Falling back immediately for merged block $pendingBlockId") + val bitmapOfPendingChunk: RoaringBitmap = + chunksMetaMap.remove(pendingBlockId).orNull + assert(bitmapOfPendingChunk != null) + chunkBitmap.or(bitmapOfPendingChunk) + } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size + } + } + mapOutputTracker.getMapSizesForMergeResult( + shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap) + } + iterator.fetchFallbackBlocks(fallbackBlocksByAddr) + blocksProcessed + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 313dbb4a70e5..ed2e2d31f315 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -38,8 +38,6 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle._ import org.apache.spark.network.util.{NettyUtils, TransportConf} import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} -import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER -import org.apache.spark.storage.ShuffleBlockFetcherIterator.{FetchBlockInfo, FetchRequest, IgnoreFetchResult, MergedBlocksMetaFailedFetchResult, MergedBlocksMetaFetchResult, SuccessFetchResult} import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} /** @@ -251,10 +249,6 @@ final class ShuffleBlockFetcherIterator( private[this] def sendRequest(req: FetchRequest): Unit = { logDebug("Sending request for %d blocks (%s) from %s".format( req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort)) - if (req.hasMergedBlocks) { - pushBasedFetchHelper.sendFetchMergedStatusRequest(req) - return - } bytesInFlight += req.size reqsInFlight += 1 @@ -367,7 +361,7 @@ final class ShuffleBlockFetcherIterator( */ private[this] def partitionBlocksByFetchMode( blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], - localBlocks: scala.collection.mutable.LinkedHashSet[(BlockId, Int)], + localBlocks: mutable.LinkedHashSet[(BlockId, Int)], hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]], mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " @@ -389,13 +383,14 @@ final class ShuffleBlockFetcherIterator( // These are push-based merged blocks or chunks of these merged blocks. 
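        // A sketch of the routing applied here, assuming a hypothetical BlockManagerId
        // (the merger executor id marks merged blocks; host equality marks local ones):
        //   val isMergedAddress = address.executorId == SHUFFLE_MERGER_IDENTIFIER
        //   val isLocalHost = address.host == blockManager.blockManagerId.host
        //   (isMergedAddress, isLocalHost) match {
        //     case (true, true)  => // read merged data directly from the local disk
        //     case (true, false) => // queue a merged-meta request, then fetch the chunks
        //     case _             => // regular local / host-local / remote handling
        //   }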
if (address.host == blockManager.blockManagerId.host) { checkBlockSizes(blockInfos) - val pushMergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( - blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch = false) + val pushMergedBlockInfos = blockInfos.map( + info => FetchBlockInfo(info._1, info._2, info._3)) numBlocksToFetch += pushMergedBlockInfos.size mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId) - mergedLocalBlockBytes += pushMergedBlockInfos.map(_.size).sum + val size = pushMergedBlockInfos.map(_.size).sum logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " + - s"of size $mergedLocalBlockBytes") + s"of size $size") + mergedLocalBlockBytes += size } else { remoteBlockBytes += blockInfos.map(_._2).sum collectFetchRequests(address, blockInfos, collectedRemoteRequests) @@ -438,7 +433,7 @@ final class ShuffleBlockFetcherIterator( s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " + s"the number of merged-local blocks ${mergedLocalBlocks.size} " + s"+ the number of remote blocks ${numRemoteBlocks} ") - logInfo(s"[${context.taskAttemptId()}] Getting $blocksToFetchCurrentIteration " + + logInfo(s"Getting $blocksToFetchCurrentIteration " + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " + @@ -454,10 +449,10 @@ final class ShuffleBlockFetcherIterator( private def createFetchRequest( blocks: Seq[FetchBlockInfo], address: BlockManagerId, - areMergedBlocks: Boolean = false): FetchRequest = { + forMergedMetas: Boolean = false): FetchRequest = { logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + s"with ${blocks.size} blocks") - FetchRequest(address, blocks, areMergedBlocks) + FetchRequest(address, blocks, forMergedMetas) } private def createFetchRequests( @@ -466,16 +461,16 @@ final class ShuffleBlockFetcherIterator( isLast: Boolean, collectedRemoteRequests: ArrayBuffer[FetchRequest], enableBatchFetch: Boolean, - areMergedBlocks: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + forMergedBlocks: Boolean = false): ArrayBuffer[FetchBlockInfo] = { val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) numBlocksToFetch += mergedBlocks.size val retBlocks = new ArrayBuffer[FetchBlockInfo] if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { - collectedRemoteRequests += createFetchRequest(mergedBlocks, address, areMergedBlocks) + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) } else { mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks => if (blocks.length == maxBlocksInFlightPerAddress || isLast) { - collectedRemoteRequests += createFetchRequest(blocks, address, areMergedBlocks) + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) } else { // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back // to `curBlocks`. @@ -503,7 +498,7 @@ final class ShuffleBlockFetcherIterator( blockId match { // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks. // Based on these types, we decide to do batch fetch and create FetchRequests with - // hasMergedBlocks set. + // forMergedMetas set. 
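          // For illustration, the three block-id shapes this match distinguishes, assuming
          // SHUFFLE_PUSH_MAP_ID marks a merged block (as the cases below do):
          //   ShuffleBlockChunkId(shuffleId, reduceId, chunkId)        // chunk of a merged block
          //   ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId) // merged block -> meta request
          //   ShuffleBlockId(shuffleId, mapId, reduceId)               // original unmerged block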
case ShuffleBlockChunkId(_, _, _) => if (curRequestSize >= targetRemoteRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { @@ -515,7 +510,7 @@ final class ShuffleBlockFetcherIterator( if (curRequestSize >= targetRemoteRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { curBlocks = createFetchRequests(curBlocks, address, isLast = false, - collectedRemoteRequests, enableBatchFetch = false, areMergedBlocks = true) + collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true) curRequestSize = curBlocks.map(_.size).sum } case _ => @@ -538,7 +533,7 @@ final class ShuffleBlockFetcherIterator( } } createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests, - enableBatchFetch = enableBatchFetch, areMergedBlocks = areMergedBlocks) + enableBatchFetch = enableBatchFetch, forMergedBlocks = areMergedBlocks) } } @@ -570,8 +565,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(new SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, - buf.size(), buf, false)) + results.put(SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, + buf.size(), buf, isNetworkReqDone = false)) } catch { // If we see an exception, stop immediately. case e: Exception => @@ -926,7 +921,7 @@ final class ShuffleBlockFetcherIterator( result = null case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, - address, _) => + address, _) => // The original meta request is processed so we decrease numBlocksToFetch by 1. We will // collect new chunks request and the count of this is added to numBlocksToFetch in // collectFetchReqsFromMergedBlocks. @@ -1013,9 +1008,10 @@ final class ShuffleBlockFetcherIterator( } def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { - sendRequest(request) - if (!request.hasMergedBlocks) { - // Not updating any metrics for chunk count requests. + if (request.forMergedMetas) { + pushBasedFetchHelper.sendFetchMergedStatusRequest(request) + } else { + sendRequest(request) numBlocksInFlightPerAddress(remoteAddress) = numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size } @@ -1338,13 +1334,13 @@ object ShuffleBlockFetcherIterator { * A request to fetch blocks from a remote BlockManager. * @param address remote BlockManager to fetch from. * @param blocks Sequence of the information for blocks to fetch from the same address. - * @param hasMergedBlocks true if this request contains merged blocks; false if it contains - * regular or shuffle block chunks. + * @param forMergedMetas true if this request is for requesting merged meta information; + * false if it is for regular or shuffle block chunks. */ case class FetchRequest( address: BlockManagerId, blocks: Seq[FetchBlockInfo], - hasMergedBlocks: Boolean = false) { + forMergedMetas: Boolean = false) { val size = blocks.map(_.size).sum } @@ -1441,250 +1437,3 @@ object ShuffleBlockFetcherIterator { address: BlockManagerId, blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult } - -/** - * Helper class that encapsulates all the push-based functionality to fetch merged block meta - * and merged shuffle block chunks. 
- */ -private class PushBasedFetchHelper( - private val iterator: ShuffleBlockFetcherIterator, - private val shuffleClient: BlockStoreClient, - private val blockManager: BlockManager, - private val mapOutputTracker: MapOutputTracker) extends Logging { - - private[this] val startTimeNs = System.nanoTime() - - private[this] val localShuffleMergerBlockMgrId = BlockManagerId( - SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, - blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) - - /** A map for storing merged block shuffle chunk bitmap */ - private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() - - /** - * Returns true if the address is for a push-merged block. - */ - def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { - SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) - } - - /** - * Returns true if the address is not of executor local or merged local block. false otherwise. - */ - def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = { - (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) || - (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId) - } - - /** - * Returns true if the address if of merged local block. false otherwise. - */ - def isMergedLocal(address: BlockManagerId): Boolean = { - isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host - } - - def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = { - chunksMetaMap(blockId).getCardinality - } - - def removeChunk(blockId: ShuffleBlockChunkId): Unit = { - chunksMetaMap.remove(blockId) - } - - def createChunkBlockInfosFromMetaResponse( - shuffleId: Int, - reduceId: Int, - blockSize: Long, - numChunks: Int, - bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { - val approxChunkSize = blockSize / numChunks - val blocksToRequest: ArrayBuffer[(BlockId, Long, Int)] = - new ArrayBuffer[(BlockId, Long, Int)]() - for (i <- 0 until numChunks) { - val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i) - chunksMetaMap.put(blockChunkId, bitmaps(i)) - logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") - blocksToRequest += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID)) - } - blocksToRequest - } - - def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { - val sizeMap = req.blocks.map { - case FetchBlockInfo(blockId, size, _) => - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap - val address = req.address - val mergedBlocksMetaListener = new MergedBlocksMetaListener { - override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { - logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " + - s"from ${req.address.host}:${req.address.port}") - try { - iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId, - sizeMap(shuffleId, reduceId), meta.getNumChunks, meta.readChunkBitmaps(), address)) - } catch { - case _: Throwable => - iterator.addToResultsQueue( - MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) - } - } - - override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { - logError(s"Failed to get the meta of merged blocks for ($shuffleId, $reduceId) " + - s"from ${req.address.host}:${req.address.port}", exception) - 
iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) - } - } - req.blocks.foreach(block => { - val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId] - shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId, - shuffleBlockId.reduceId, mergedBlocksMetaListener) - }) - } - - // Fetch all outstanding merged local blocks - def fetchAllMergedLocalBlocks( - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { - if (mergedLocalBlocks.nonEmpty) { - blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) - } - } - - /** - * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local - * blocks. - */ - private def fetchMergedLocalBlocks( - hostLocalDirManager: HostLocalDirManager, - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { - val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( - SHUFFLE_MERGER_IDENTIFIER) - if (cachedMergerDirs.isDefined) { - logDebug(s"Fetching local merged blocks with cached executors dir: " + - s"${cachedMergerDirs.get.mkString(", ")}") - mergedLocalBlocks.foreach(blockId => - fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId)) - } else { - logDebug(s"Asynchronous fetching local merged blocks without cached executors dir") - hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, - localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) { - case Success(dirs) => - mergedLocalBlocks.takeWhile { - blockId => - logDebug(s"Successfully fetched local dirs: " + - s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") - fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER), - localShuffleMergerBlockMgrId) - } - logDebug(s"Got local merged blocks (without cached executors' dir) in " + - s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") - case Failure(throwable) => - // If we see an exception with getting the local dirs for local merged blocks, - // we fallback to fetch the original unmerged blocks. We do not report block fetch - // failure. - logWarning(s"Error occurred while getting the local dirs for local merged " + - s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", - throwable) - mergedLocalBlocks.foreach( - blockId => iterator.addToResultsQueue( - IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) - ) - } - } - } - - /** - * Fetch a single local merged block generated. 
- * @param blockId ShuffleBlockId to be fetched - * @param localDirs Local directories where the merged shuffle files are stored - * @param blockManagerId BlockManagerId - * @return Boolean represents successful or failed fetch - */ - private[this] def fetchMergedLocalBlock( - blockId: BlockId, - localDirs: Array[String], - blockManagerId: BlockManagerId): Boolean = { - try { - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs) - .readChunkBitmaps() - // Fetch local merged shuffle block data as multiple chunks - val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs) - // Update total number of blocks to fetch, reflecting the multiple local chunks - iterator.foundMoreBlocksToFetch(bufs.size - 1) - for (chunkId <- bufs.indices) { - val buf = bufs(chunkId) - buf.retain() - val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, - shuffleBlockId.reduceId, chunkId) - iterator.addToResultsQueue( - SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf, - isNetworkReqDone = false)) - chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) - } - true - } catch { - case e: Exception => - // If we see an exception with reading a local merged block, we fallback to - // fetch the original unmerged blocks. We do not report block fetch failure - // and will continue with the remaining local block read. - logWarning(s"Error occurred while fetching local merged block, " + - s"prepare to fetch the original blocks", e) - iterator.addToResultsQueue( - IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) - false - } - } - - /** - * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed - * to fetch. - * It calls out to map output tracker to get the list of original blocks for the - * given merged blocks, split them into remote and local blocks, and process them - * accordingly. - * The fallback happens when: - * 1. There is an exception while creating shuffle block chunk from local merged shuffle block. - * See fetchLocalBlock. - * 2. There is a failure when fetching remote shuffle block chunks. - * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk - * (local or remote). - * - * @return number of blocks processed - */ - def initiateFallbackBlockFetchForMergedBlock( - blockId: BlockId, - address: BlockManagerId): Int = { - logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId") - // Increase the blocks processed since we will process another block in the next iteration of - // the while loop in ShuffleBlockFetcherIterator.next(). 
- var blocksProcessed = 1 - val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = - if (blockId.isShuffle) { - val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - mapOutputTracker.getMapSizesForMergeResult( - shuffleBlockId.shuffleId, shuffleBlockId.reduceId) - } else { - val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] - val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull - if (isNotExecutorOrMergedLocal(address)) { - // Fallback for all the pending fetch requests - val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) - if (pendingShuffleChunks.nonEmpty) { - pendingShuffleChunks.foreach { pendingBlockId => - logWarning(s"Falling back immediately for merged block $pendingBlockId") - val bitmapOfPendingChunk: RoaringBitmap = - chunksMetaMap.remove(pendingBlockId).orNull - assert(bitmapOfPendingChunk != null) - chunkBitmap.or(bitmapOfPendingChunk) - } - // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed - blocksProcessed += pendingShuffleChunks.size - } - } - mapOutputTracker.getMapSizesForMergeResult( - shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap) - } - iterator.fetchFallbackBlocks(fallbackBlocksByAddr) - blocksProcessed - } -} diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index f4777198072e..c5e6d366c2e7 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1149,6 +1149,39 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } + test("iterator has just 1 merged block and fails to fetch the meta") { + val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1) + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID))) + + val blockChunks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer() + ) + when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( + Seq((remoteBmId, toBlockList( + Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator) + val blocksSem = new Semaphore(0) + configureMockTransferForPushShuffle(blocksSem, blockChunks) + when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) + .thenAnswer((invocation: InvocationOnMock) => { + val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] + val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] + val reduceId = invocation.getArguments()(3).asInstanceOf[Int] + Future { + metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error")) + } + }) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress) + val (id1, _) = iterator.next() + blocksSem.acquire(2) + assert(id1 === ShuffleBlockId(0, 0, 2)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 1, 2)) + assert(!iterator.hasNext) + } + private def createMockMergedBlockMeta( numChunks: Int, bitmaps: Array[RoaringBitmap]): MergedBlockMeta = { From c073a8e021454fff3cb71a362a278d0642dde1a5 Mon Sep 17 00:00:00 2001 From: 
Chandni Singh
Date: Wed, 9 Jun 2021 14:11:46 -0700
Subject: [PATCH 12/27] Fixed indentation of PushBasedFetchHelper

---
 .../org/apache/spark/storage/PushBasedFetchHelper.scala  | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
index 9cc9fda31a75..4d89986e78b9 100644
--- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -38,10 +38,10 @@ import org.apache.spark.storage.ShuffleBlockFetcherIterator._
  * functionality to fetch merged block meta and merged shuffle block chunks.
  */
 private class PushBasedFetchHelper(
-  private val iterator: ShuffleBlockFetcherIterator,
-  private val shuffleClient: BlockStoreClient,
-  private val blockManager: BlockManager,
-  private val mapOutputTracker: MapOutputTracker) extends Logging {
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
 
   private[this] val startTimeNs = System.nanoTime()
 
From 52b53a6e4ec979546887bf91833417c71f1f7fd7 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Wed, 9 Jun 2021 15:41:34 -0700
Subject: [PATCH 13/27] Addressed Mridul's comments, accounting for meta requests in numBlocksInFlightPerAddress, and added more comments

---
 .../storage/ShuffleBlockFetcherIterator.scala | 30 +++++++++++++------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index ed2e2d31f315..368c32fc452e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -507,11 +507,9 @@ final class ShuffleBlockFetcherIterator(
           curRequestSize = curBlocks.map(_.size).sum
         }
       case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
-        if (curRequestSize >= targetRemoteRequestSize ||
-          curBlocks.size >= maxBlocksInFlightPerAddress) {
+        if (curBlocks.size >= maxBlocksInFlightPerAddress) {
           curBlocks = createFetchRequests(curBlocks, address, isLast = false,
             collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true)
-          curRequestSize = curBlocks.map(_.size).sum
         }
       case _ =>
         // For batch fetch, the actual block in flight should count for merged block.
@@ -906,6 +904,12 @@ final class ShuffleBlockFetcherIterator(
         result = null
 
       case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+        // We get this result in 3 cases:
+        // 1. Failure to fetch the data of a remote merged shuffle chunk. In this case, the
+        //    blockId is a ShuffleBlockChunkId.
+        // 2. Failure to read the local merged data. In this case, the blockId is ShuffleBlockId.
+        // 3. Failure to get the local merged directories from the ESS. In this case, the blockId
+        //    is ShuffleBlockId.
         if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) {
           numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
           bytesInFlight -= size
@@ -922,9 +926,10 @@ final class ShuffleBlockFetcherIterator(
 
       case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps,
         address, _) =>
-        // The original meta request is processed so we decrease numBlocksToFetch by 1.
We will
-        // collect new chunks request and the count of this is added to numBlocksToFetch in
-        // collectFetchReqsFromMergedBlocks.
+        // The original meta request is processed so we decrease numBlocksToFetch and
+        // numBlocksInFlightPerAddress by 1. We will collect new chunks request and the count of
+        // this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks.
+        numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
         numBlocksToFetch -= 1
         val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
           shuffleId, reduceId, blockSize, numChunks, bitmaps)
@@ -935,6 +940,10 @@ final class ShuffleBlockFetcherIterator(
         result = null
 
       case MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address, _) =>
+        // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1.
+        // However, instead of decreasing numBlocksToFetch by 1, we increment numBlocksProcessed
+        // which has the same effect.
+        numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
         // If we fail to fetch the merged status of a merged block, we fall back to fetching the
         // unmerged blocks.
         numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock(
@@ -1012,9 +1021,9 @@ final class ShuffleBlockFetcherIterator(
         pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
       } else {
         sendRequest(request)
-        numBlocksInFlightPerAddress(remoteAddress) =
-          numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
       }
+      numBlocksInFlightPerAddress(remoteAddress) =
+        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
     }
 
     def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
@@ -1391,7 +1400,10 @@ object ShuffleBlockFetcherIterator {
   case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
 
   /**
-   * Result of a fetch from a remote merged block unsuccessfully.
+   * Result of an unsuccessful fetch of either of these:
+   * 1) Remote shuffle block chunk.
+   * 2) Local merged block data.
+   *
    * Instead of treating this as a FailureFetchResult, we ignore this failure
    * and fallback to fetch the original unmerged blocks.
    * @param blockId block id
From 9ae32fb797d2602909f3b5fe5562c1d36e2d76d6 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Wed, 9 Jun 2021 22:12:10 -0700
Subject: [PATCH 14/27] Addressed wuyi's comments

---
 .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 368c32fc452e..40fa4ff18d97 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -379,10 +379,10 @@ final class ShuffleBlockFetcherIterator(
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
 
     for ((address, blockInfos) <- blocksByAddress) {
+      checkBlockSizes(blockInfos)
       if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
         // These are push-based merged blocks or chunks of these merged blocks.
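        // With checkBlockSizes hoisted above the branch, every address type now gets the same
        // size validation. A minimal sketch of that check, assuming a hypothetical entry
        // (the real helper is expected to reject non-positive sizes):
        //   val blockInfos = Seq((ShuffleBlockId(0, 5, 2), 100L, 5)) // (blockId, size, mapIndex)
        //   blockInfos.foreach { case (blockId, size, _) =>
        //     require(size > 0, s"Invalid size $size for $blockId")
        //   }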
if (address.host == blockManager.blockManagerId.host) { - checkBlockSizes(blockInfos) val pushMergedBlockInfos = blockInfos.map( info => FetchBlockInfo(info._1, info._2, info._3)) numBlocksToFetch += pushMergedBlockInfos.size @@ -397,7 +397,6 @@ final class ShuffleBlockFetcherIterator( } } else if ( Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) { - checkBlockSizes(blockInfos) val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) numBlocksToFetch += mergedBlockInfos.size @@ -405,7 +404,6 @@ final class ShuffleBlockFetcherIterator( localBlockBytes += mergedBlockInfos.map(_.size).sum } else if (blockManager.hostLocalDirManager.isDefined && address.host == blockManager.blockManagerId.host) { - checkBlockSizes(blockInfos) val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) numBlocksToFetch += mergedBlockInfos.size @@ -492,7 +490,6 @@ final class ShuffleBlockFetcherIterator( while (iterator.hasNext) { val (blockId, size, mapIndex) = iterator.next() - assertPositiveBlockSize(blockId, size) curBlocks += FetchBlockInfo(blockId, size, mapIndex) curRequestSize += size blockId match { From daf8d8e3f1893af4f5b9042af5da5ac5a6dc1792 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Thu, 10 Jun 2021 11:06:56 -0700 Subject: [PATCH 15/27] Addressing Mridul's comments --- .../spark/storage/PushBasedFetchHelper.scala | 39 ++++++++++++++++--- .../storage/ShuffleBlockFetcherIterator.scala | 10 +++-- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index 4d89986e78b9..4405dd1f03ba 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -82,6 +82,18 @@ private class PushBasedFetchHelper( chunksMetaMap.remove(blockId) } + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.MergedBlocksMetaFetchResult]]. + * + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param blockSize size of the merged block. + * @param numChunks number of chunks in the merged block. + * @param bitmaps per chunk bitmap, where each bitmap contains all the mapIds that are merged + * to that chunk. + * @return shuffle chunks to fetch. + */ def createChunkBlockInfosFromMetaResponse( shuffleId: Int, reduceId: Int, @@ -99,6 +111,13 @@ private class PushBasedFetchHelper( blocksToFetch } + /** + * This is executed by the task thread when the iterator is initialized and only if it has + * push-merged blocks for which it needs to fetch the metadata. + * + * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch + * metadata of merged blocks. + */ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { val sizeMap = req.blocks.map { case FetchBlockInfo(blockId, size, _) => @@ -134,9 +153,13 @@ private class PushBasedFetchHelper( } } - // Fetch all outstanding merged local blocks + /** + * This is executed by the task thread when the iterator is initialized. It fetches all the + * outstanding merged local blocks. + * @param mergedLocalBlocks set of identified merged local blocks. 
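+   *
+   * A minimal usage sketch, assuming two hypothetical merged blocks were collected while
+   * partitioning (reduce partitions 1 and 2 of shuffle 0):
+   * {{{
+   *   val mergedLocal = mutable.LinkedHashSet[BlockId](
+   *     ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 1),
+   *     ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2))
+   *   pushBasedFetchHelper.fetchAllMergedLocalBlocks(mergedLocal)
+   * }}}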
+ */ def fetchAllMergedLocalBlocks( - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { if (mergedLocalBlocks.nonEmpty) { blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) } @@ -229,9 +252,15 @@ private class PushBasedFetchHelper( } /** - * Initiate fetching fallback blocks for a merged block (or a merged block chunk) that's failed - * to fetch. - * It calls out to map output tracker to get the list of original blocks for the + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type: + * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] + * 2) [[ShuffleBlockFetcherIterator.IgnoreFetchResult]] + * 3) [[ShuffleBlockFetcherIterator.MergedBlocksMetaFailedFetchResult]] + * + * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that + * failed to fetch. + * It makes a call to the map output tracker to get the list of original blocks for the * given merged blocks, split them into remote and local blocks, and process them * accordingly. * The fallback happens when: diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 40fa4ff18d97..192c0d1d6bf9 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -928,10 +928,10 @@ final class ShuffleBlockFetcherIterator( // this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 numBlocksToFetch -= 1 - val blocksToRequest = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( + val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( shuffleId, reduceId, blockSize, numChunks, bitmaps) val additionalRemoteReqs = new ArrayBuffer[FetchRequest] - collectFetchRequests(address, blocksToRequest.toSeq, additionalRemoteReqs) + collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) fetchRequests ++= additionalRemoteReqs // Set result to null to force another iteration. result = null @@ -1069,7 +1069,9 @@ final class ShuffleBlockFetcherIterator( /** * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch - * failure with a shuffle merged block/chunk. + * failure for a shuffle merged block/chunk. + * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates + * fallback. */ private[storage] def fetchFallbackBlocks( fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { @@ -1096,6 +1098,8 @@ final class ShuffleBlockFetcherIterator( /** * Removes all the pending shuffle chunks that are on the same host as the block chunk that had * a fetch failure. + * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates + * fallback. * * @return set of all the removed shuffle chunk Ids. 
*/ From 6e602af9731b6c329ba9cb07b69b16d286ff55f3 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Thu, 10 Jun 2021 16:01:44 -0700 Subject: [PATCH 16/27] Rebasing against master --- .../apache/spark/network/BlockDataManager.scala | 15 --------------- .../spark/shuffle/ShuffleBlockResolver.scala | 5 ----- .../org/apache/spark/storage/BlockManager.scala | 7 ------- 3 files changed, 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 205addb2097c..cafb39ea82ad 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -22,7 +22,6 @@ import scala.reflect.ClassTag import org.apache.spark.TaskContext import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.StreamCallbackWithID -import org.apache.spark.network.shuffle.MergedBlockMeta import org.apache.spark.storage.{BlockId, StorageLevel} private[spark] @@ -72,18 +71,4 @@ trait BlockDataManager { * Release locks acquired by [[putBlockData()]] and [[getLocalBlockData()]]. */ def releaseLock(blockId: BlockId, taskContext: Option[TaskContext]): Unit - - /** - * Interface to get merged shuffle block data. Throws an exception if the block cannot be found - * or cannot be read successfully. - */ - // PART OF SPARK-33350 - def getMergedBlockData(blockId: BlockId, dirs: Array[String]): Seq[ManagedBuffer] - - /** - * Interface to get merged shuffle block meta. Throws an exception if the meta cannot be found - * or cannot be read successfully. - */ - // PART OF SPARK-33350 - def getMergedBlockMeta(blockId: BlockId, dirs: Array[String]): MergedBlockMeta } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 2d839254e1a7..49e59298cc0c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -51,10 +51,5 @@ trait ShuffleBlockResolver { */ def getMergedBlockMeta(blockId: ShuffleBlockId, dirs: Option[Array[String]]): MergedBlockMeta - /** - * Retrieve the meta data for the specified merged shuffle block. - */ - def getMergedBlockMeta(blockId: ShuffleBlockId): MergedBlockMeta - def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a7718184701f..df449fba24e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -750,13 +750,6 @@ private[spark] class BlockManager( shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId, Some(dirs)) } - /** - * Get the local merged shuffle block metada data for the given block ID. - */ - def getMergedBlockMeta(blockId: ShuffleBlockId): MergedBlockMeta = { - shuffleManager.shuffleBlockResolver.getMergedBlockMeta(blockId) - } - /** * Get the BlockStatus for the block identified by the given ID, if it exists. * NOTE: This is mainly for testing. 
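// A sketch of the consolidated read path that survives this cleanup, using the dir-aware
// variants kept in the diffs above (dirs are assumed to be the merger's local dirs):
//   val blockId = ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)
//   val meta: MergedBlockMeta = blockManager.getMergedBlockMeta(blockId, dirs)
//   val chunks: Seq[ManagedBuffer] = blockManager.getMergedBlockData(blockId, dirs)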
From c698c37cce88b4204d34c83f92ccc3d84a25ecf9 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Thu, 10 Jun 2021 16:49:22 -0700 Subject: [PATCH 17/27] Addressed Mridul's comments, rebased with master, and another UT --- .../spark/storage/PushBasedFetchHelper.scala | 101 ++++++++++-------- .../storage/ShuffleBlockFetcherIterator.scala | 67 +++++------- .../ShuffleBlockFetcherIteratorSuite.scala | 61 ++++++++--- 3 files changed, 135 insertions(+), 94 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index 4405dd1f03ba..b2c4aa016d8f 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -56,15 +56,15 @@ private class PushBasedFetchHelper( * Returns true if the address is for a push-merged block. */ def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { - SHUFFLE_MERGER_IDENTIFIER.equals(address.executorId) + SHUFFLE_MERGER_IDENTIFIER == address.executorId } /** - * Returns true if the address is not of executor local or merged local block. false otherwise. + * Returns true if the address is of a remote merged block. */ - def isNotExecutorOrMergedLocal(address: BlockManagerId): Boolean = { - (isMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host) || - (!isMergedShuffleBlockAddress(address) && address != blockManager.blockManagerId) + def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = { + assert(isMergedShuffleBlockAddress(address)) + address.host != blockManager.blockManagerId.host } /** @@ -74,17 +74,29 @@ private class PushBasedFetchHelper( isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host } + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId shuffle block chunk id. + */ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = { chunksMetaMap(blockId).getCardinality } + /** + * This is executed by the task thread when the `iterator.next()` is invoked and the iterator + * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * + * @param blockId shuffle block chunk id. + */ def removeChunk(blockId: ShuffleBlockChunkId): Unit = { chunksMetaMap.remove(blockId) } /** * This is executed by the task thread when the `iterator.next()` is invoked and the iterator - * processes a response of type [[ShuffleBlockFetcherIterator.MergedBlocksMetaFetchResult]]. + * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]]. * * @param shuffleId shuffle id. * @param reduceId reduce id. 
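   *
   * For a worked illustration, assume a hypothetical 4000-byte merged block advertised with
   * 4 chunks; the estimate used below then yields:
   * {{{
   *   approxChunkSize = blockSize / numChunks        // 4000 / 4 = 1000 bytes per chunk
   *   // entry i: (ShuffleBlockChunkId(shuffleId, reduceId, i), 1000L, SHUFFLE_PUSH_MAP_ID)
   * }}}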
@@ -122,28 +134,29 @@ private class PushBasedFetchHelper( val sizeMap = req.blocks.map { case FetchBlockInfo(blockId, size, _) => val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)}.toMap + ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size) + }.toMap val address = req.address val mergedBlocksMetaListener = new MergedBlocksMetaListener { override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}") try { - iterator.addToResultsQueue(MergedBlocksMetaFetchResult(shuffleId, reduceId, + iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId, sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address)) } catch { case exception: Throwable => logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}", exception) iterator.addToResultsQueue( - MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) + MergedMetaFailedFetchResult(shuffleId, reduceId, address)) } } override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}", exception) - iterator.addToResultsQueue(MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address)) + iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address)) } } req.blocks.foreach { block => @@ -221,12 +234,12 @@ private class PushBasedFetchHelper( blockManagerId: BlockManagerId): Boolean = { try { val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] - val chunksMeta = blockManager.getMergedBlockMeta(shuffleBlockId, localDirs) + val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) .readChunkBitmaps() // Fetch local merged shuffle block data as multiple chunks - val bufs: Seq[ManagedBuffer] = blockManager.getMergedBlockData(shuffleBlockId, localDirs) + val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) // Update total number of blocks to fetch, reflecting the multiple local chunks - iterator.foundMoreBlocksToFetch(bufs.size - 1) + iterator.incrementNumBlocksToFetch(bufs.size - 1) for (chunkId <- bufs.indices) { val buf = bufs(chunkId) buf.retain() @@ -256,7 +269,7 @@ private class PushBasedFetchHelper( * processes a response of type: * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] * 2) [[ShuffleBlockFetcherIterator.IgnoreFetchResult]] - * 3) [[ShuffleBlockFetcherIterator.MergedBlocksMetaFailedFetchResult]] + * 3) [[ShuffleBlockFetcherIterator.MergedMetaFailedFetchResult]] * * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that * failed to fetch. @@ -275,42 +288,44 @@ private class PushBasedFetchHelper( def initiateFallbackBlockFetchForMergedBlock( blockId: BlockId, address: BlockManagerId): Int = { + assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId") // Increase the blocks processed since we will process another block in the next iteration of // the while loop in ShuffleBlockFetcherIterator.next(). 
var blocksProcessed = 1
     val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
-      if (blockId.isShuffle) {
-        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
-        mapOutputTracker.getMapSizesForMergeResult(
-          shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
-      } else {
-        val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
-        val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
-        // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
-        // fallback not only for that particular remote shuffle block chunk but also for all the
-        // pending block chunks that belong to the same host. The reason for doing so is that it is
-        // very likely that the subsequent requests for merged block chunks from this host will fail
-        // as well. Since, push-based shuffle is best effort and we try not to increase the delay
-        // of the fetches, we immediately fallback for all the pending shuffle chunks in the
-        // fetchRequests queue.
+      blockId match {
+        case shuffleBlockId: ShuffleBlockId =>
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+        case _ =>
+          val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+          val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull
+          assert(chunkBitmap != null)
+          // When there is a failure to fetch a remote merged shuffle block chunk, then we try to
+          // fallback not only for that particular remote shuffle block chunk but also for all the
+          // pending block chunks that belong to the same host. The reason for doing so is that it
+          // is very likely that the subsequent requests for merged block chunks from this host will
+          // fail as well. Since push-based shuffle is best effort and we try not to increase the
+          // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the
+          // fetchRequests queue.
+          if (isMergedBlockAddressRemote(address)) {
+            // Fallback for all the pending fetch requests
+            val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+            if (pendingShuffleChunks.nonEmpty) {
+              pendingShuffleChunks.foreach { pendingBlockId =>
+                logInfo(s"Falling back immediately for merged block $pendingBlockId")
+                val bitmapOfPendingChunk: RoaringBitmap =
+                  chunksMetaMap.remove(pendingBlockId).orNull
+                assert(bitmapOfPendingChunk != null)
+                chunkBitmap.or(bitmapOfPendingChunk)
+              }
+              // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed
+              blocksProcessed += pendingShuffleChunks.size
+            }
+          }
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
       }
     iterator.fetchFallbackBlocks(fallbackBlocksByAddr)
     blocksProcessed
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 192c0d1d6bf9..d9f9969e98b8 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -383,20 +383,15 @@ final class ShuffleBlockFetcherIterator(
       if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) {
         // These are push-based merged blocks or chunks of these merged blocks.
         if (address.host == blockManager.blockManagerId.host) {
-          val pushMergedBlockInfos = blockInfos.map(
-            info => FetchBlockInfo(info._1, info._2, info._3))
-          numBlocksToFetch += pushMergedBlockInfos.size
-          mergedLocalBlocks ++= pushMergedBlockInfos.map(info => info.blockId)
-          val size = pushMergedBlockInfos.map(_.size).sum
-          logInfo(s"Got ${pushMergedBlockInfos.size} local merged blocks " +
-            s"of size $size")
-          mergedLocalBlockBytes += size
+          numBlocksToFetch += blockInfos.size
+          mergedLocalBlocks ++= blockInfos.map(_._1)
+          mergedLocalBlockBytes += blockInfos.map(_._2).sum
         } else {
           remoteBlockBytes += blockInfos.map(_._2).sum
           collectFetchRequests(address, blockInfos, collectedRemoteRequests)
         }
-      } else if (
-        Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
+      } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback)
+        .contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
@@ -438,16 +433,14 @@ final class ShuffleBlockFetcherIterator(
         s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " +
         s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
         s"remote blocks")
-      if (hostLocalBlocksCurrentIteration.nonEmpty) {
-        this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
-      }
+      this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration
       collectedRemoteRequests
     }
 
     private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
       address: BlockManagerId,
-      forMergedMetas: Boolean = false): FetchRequest = {
+      forMergedMetas: Boolean): FetchRequest = {
       logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " +
       s"with ${blocks.size} blocks")
       FetchRequest(address, blocks,
forMergedMetas) @@ -907,7 +900,7 @@ final class ShuffleBlockFetcherIterator( // 2. Failure to read the local merged data. In this case, the blockId is ShuffleBlockId. // 3. Failure to get the local merged directories from the ESS. In this case, the blockId // is ShuffleBlockId. - if (pushBasedFetchHelper.isNotExecutorOrMergedLocal(address)) { + if (pushBasedFetchHelper.isMergedBlockAddressRemote(address)) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 bytesInFlight -= size } @@ -921,7 +914,7 @@ final class ShuffleBlockFetcherIterator( // a SuccessFetchResult or a FailureFetchResult. result = null - case MergedBlocksMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, + case MergedMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, address, _) => // The original meta request is processed so we decrease numBlocksToFetch and // numBlocksInFlightPerAddress by 1. We will collect new chunks request and the count of @@ -936,7 +929,7 @@ final class ShuffleBlockFetcherIterator( // Set result to null to force another iteration. result = null - case MergedBlocksMetaFailedFetchResult(shuffleId, reduceId, address, _) => + case MergedMetaFailedFetchResult(shuffleId, reduceId, address, _) => // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. // However, instead of decreasing numBlocksToFetch by 1, we increment numBlocksProcessed // which has the same effect. @@ -1063,7 +1056,7 @@ final class ShuffleBlockFetcherIterator( results.put(result) } - private[storage] def foundMoreBlocksToFetch(moreBlocksToFetch: Int): Unit = { + private[storage] def incrementNumBlocksToFetch(moreBlocksToFetch: Int): Unit = { numBlocksToFetch += moreBlocksToFetch } @@ -1084,9 +1077,7 @@ final class ShuffleBlockFetcherIterator( // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(fallbackRemoteReqs) logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged") - // If there is any fall back block that's a local block, we get them here. The original - // invocation to fetchLocalBlocks might have already returned by this time, so we need - // to invoke it again here. + // fetch all the fallback blocks that are local. fetchLocalBlocks(fallbackLocalBlocks) // Merged local blocks should be empty during fallback assert(fallbackMergedLocalBlocks.isEmpty, @@ -1126,13 +1117,10 @@ final class ShuffleBlockFetcherIterator( } filterRequests(fetchRequests) - val defRequests = deferredFetchRequests.remove(address).orNull - if (defRequests != null) { + deferredFetchRequests.get(address).foreach(defRequests => { filterRequests(defRequests) - if (defRequests.nonEmpty) { - deferredFetchRequests(address) = defRequests - } - } + if (defRequests.isEmpty) deferredFetchRequests.remove(address) + }) removedChunkIds } } @@ -1243,8 +1231,8 @@ object ShuffleBlockFetcherIterator { } /** - * Dummy shuffle block id to fill into [[MergedBlocksMetaFetchResult]] and - * [[MergedBlocksMetaFailedFetchResult]], to match the [[FetchResult]] trait. + * Dummy shuffle block id to fill into [[MergedMetaFetchResult]] and + * [[MergedMetaFailedFetchResult]], to match the [[FetchResult]] trait. */ private val DUMMY_SHUFFLE_BLOCK_ID = ShuffleBlockId(-1, -1, -1) @@ -1405,8 +1393,9 @@ object ShuffleBlockFetcherIterator { * 1) Remote shuffle block chunk. * 2) Local merged block data. * - * Instead of treating this as a FailureFetchResult, we ignore this failure - * and fallback to fetch the original unmerged blocks. 
+ * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original + * unmerged blocks. + * * @param blockId block id * @param address BlockManager that the merged block was attempted to be fetched from * @param size size of the block, used to update bytesInFlight. @@ -1421,14 +1410,14 @@ object ShuffleBlockFetcherIterator { /** * Result of a successful fetch of meta information for a merged block. * - * @param shuffleId shuffle id. - * @param reduceId reduce id. - * @param blockSize size of each merged block. - * @param numChunks number of chunks in the merged block. - * @param bitmaps bitmaps for every chunk. - * @param address BlockManager that the merged status was fetched from. + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param blockSize size of each merged block. + * @param numChunks number of chunks in the merged block. + * @param bitmaps bitmaps for every chunk. + * @param address BlockManager that the merged status was fetched from. */ - private[storage] case class MergedBlocksMetaFetchResult( + private[storage] case class MergedMetaFetchResult( shuffleId: Int, reduceId: Int, blockSize: Long, @@ -1444,7 +1433,7 @@ object ShuffleBlockFetcherIterator { * @param reduceId reduce id. * @param address BlockManager that the merged status was fetched from. */ - private[storage] case class MergedBlocksMetaFailedFetchResult( + private[storage] case class MergedMetaFailedFetchResult( shuffleId: Int, reduceId: Int, address: BlockManagerId, diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c5e6d366c2e7..adaec777205f 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1205,25 +1205,25 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT doReturn(localBmId).when(blockManager).blockManagerId initHostLocalDirManager(blockManager, localDirsMap) - val blockChunks = Map[BlockId, ManagedBuffer]( + val blockBuffers = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(), ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(), ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer() ) - doReturn(blockChunks(ShuffleBlockId(0, 0, 2))).when(blockManager) + doReturn(blockBuffers(ShuffleBlockId(0, 0, 2))).when(blockManager) .getLocalBlockData(ShuffleBlockId(0, 0, 2)) - doReturn(blockChunks(ShuffleBlockId(0, 1, 2))).when(blockManager) + doReturn(blockBuffers(ShuffleBlockId(0, 1, 2))).when(blockManager) .getLocalBlockData(ShuffleBlockId(0, 1, 2)) - doReturn(blockChunks(ShuffleBlockId(0, 2, 2))).when(blockManager) + doReturn(blockBuffers(ShuffleBlockId(0, 2, 2))).when(blockManager) .getLocalBlockData(ShuffleBlockId(0, 2, 2)) - doReturn(blockChunks(ShuffleBlockId(0, 3, 2))).when(blockManager) + doReturn(blockBuffers(ShuffleBlockId(0, 3, 2))).when(blockManager) .getLocalBlockData(ShuffleBlockId(0, 3, 2)) val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER) doReturn(Seq(createMockManagedBuffer(2))).when(blockManager) - .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData) + .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData) // Get a valid chunk meta for this test val bitmaps = Array(new RoaringBitmap) @@ -1234,7 
+1234,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       } else {
         createMockMergedBlockMeta(bitmaps.length, bitmaps)
       }
-    when(blockManager.getMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
       dirsForMergedData)).thenReturn(mergedBlockMeta)
     when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
       Seq((localBmId,
@@ -1267,12 +1267,48 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
   }
 
+  test("failed to fetch merged block as well as fallback block should throw " +
+    "a FetchFailedException") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val localBmId = BlockManagerId("test-client", "test-local-host", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    val localDirsMap = Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)
+    initHostLocalDirManager(blockManager, localDirsMap)
+
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 0, 2))
+    // Force the read of original block (0, 1, 2) to fail; it surfaces as a FetchFailedException.
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 1, 2))
+
+    val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
+    // Since bitmaps is null, reading the merged block meta will fail, causing the fallback to
+    // initiate.
+    val mergedBlockMeta: MergedBlockMeta = createMockMergedBlockMeta(2, null)
+    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+      dirsForMergedData)).thenReturn(mergedBlockMeta)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1), toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    // The 1st call to iterator.next() returns the original shuffle block (0, 0, 2).
+    assert(iterator.next()._1 === ShuffleBlockId(0, 0, 2))
+    // The 2nd call to iterator.next() throws a FetchFailedException.
+    intercept[FetchFailedException] { iterator.next() }
+  }
+
   test("failed to fetch local merged blocks then fallback to fetch original shuffle " +
     "blocks which contains host-local blocks") {
     val blockManager = mock(classOf[BlockManager])
@@ -1283,7 +1319,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(blockManager, hostLocalDirs)
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
     // Host-local read for a shuffle block.
     doReturn(createMockManagedBuffer()).when(blockManager)
       .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
@@ -1313,7 +1349,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager, hostLocalDirs) ++ hostLocalBlocks
 
     doThrow(new RuntimeException("Forced error")).when(blockManager)
-      .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
     // Host-local read for this original shuffle block.
     doReturn(createMockManagedBuffer()).when(blockManager)
       .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
@@ -1347,7 +1383,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     // This will throw an IOException when an input stream is created from the ManagedBuffer.
     doReturn(Seq({
       new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
-    })).when(blockManager).getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    })).when(blockManager).getLocalMergedBlockData(
+      ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
     val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
       blockManager = Some(blockManager))
     verifyLocalBlocksFromFallback(iterator)
@@ -1360,7 +1397,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
     val corruptBuffer = createMockManagedBuffer(2)
     doReturn(Seq({corruptBuffer})).when(blockManager)
-      .getMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
     val corruptStream = mock(classOf[InputStream])
     when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
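+    // Any read from corruptStream fails, simulating a corrupt local merged block; the iterator
+    // is expected to fall back to the original blocks rather than fail the fetch.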
doReturn(corruptStream).when(corruptBuffer).createInputStream() From 429b75933896ee81aa1faa222466b94a2618927c Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Thu, 10 Jun 2021 23:28:39 -0700 Subject: [PATCH 18/27] Addressed Mridul's comments --- .../spark/storage/PushBasedFetchHelper.scala | 23 +++++++++---------- .../storage/ShuffleBlockFetcherIterator.scala | 6 ++--- 2 files changed, 14 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index b2c4aa016d8f..6268b8251574 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -50,7 +50,7 @@ private class PushBasedFetchHelper( blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) /** A map for storing merged block shuffle chunk bitmap */ - private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() + private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]() /** * Returns true if the address is for a push-merged block. @@ -81,7 +81,7 @@ private class PushBasedFetchHelper( * @param blockId shuffle block chunk id. */ def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = { - chunksMetaMap(blockId).getCardinality + chunksMetaMap.get(blockId).getCardinality } /** @@ -145,7 +145,7 @@ private class PushBasedFetchHelper( iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId, sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address)) } catch { - case exception: Throwable => + case exception: Exception => logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}", exception) iterator.addToResultsQueue( @@ -214,8 +214,8 @@ private class PushBasedFetchHelper( s"blocks: ${mergedLocalBlocks.mkString(", ")}. 
Fetch the original blocks instead", throwable) mergedLocalBlocks.foreach( - blockId => iterator.addToResultsQueue( - IgnoreFetchResult(blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) + blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult( + blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) ) } } @@ -245,10 +245,10 @@ private class PushBasedFetchHelper( buf.retain() val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunkId) + chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) iterator.addToResultsQueue( SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf, isNetworkReqDone = false)) - chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) } true } catch { @@ -259,7 +259,7 @@ private class PushBasedFetchHelper( logWarning(s"Error occurred while fetching local merged block, " + s"prepare to fetch the original blocks", e) iterator.addToResultsQueue( - IgnoreFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + FallbackOnMergedFailureFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) false } } @@ -268,7 +268,7 @@ private class PushBasedFetchHelper( * This is executed by the task thread when the `iterator.next()` is invoked and the iterator * processes a response of type: * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] - * 2) [[ShuffleBlockFetcherIterator.IgnoreFetchResult]] + * 2) [[ShuffleBlockFetcherIterator.FallbackOnMergedFailureFetchResult]] * 3) [[ShuffleBlockFetcherIterator.MergedMetaFailedFetchResult]] * * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that @@ -300,7 +300,7 @@ private class PushBasedFetchHelper( shuffleBlockId.shuffleId, shuffleBlockId.reduceId) case _ => val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] - val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).orNull + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId) assert(chunkBitmap != null) // When there is a failure to fetch a remote merged shuffle block chunk, then we try to // fallback not only for that particular remote shuffle block chunk but also for all the @@ -315,8 +315,7 @@ private class PushBasedFetchHelper( if (pendingShuffleChunks.nonEmpty) { pendingShuffleChunks.foreach { pendingBlockId => logInfo(s"Falling back immediately for merged block $pendingBlockId") - val bitmapOfPendingChunk: RoaringBitmap = - chunksMetaMap.remove(pendingBlockId).orNull + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId) assert(bitmapOfPendingChunk != null) chunkBitmap.or(bitmapOfPendingChunk) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index d9f9969e98b8..de74dcfd21a8 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -334,7 +334,7 @@ final class ShuffleBlockFetcherIterator( if (block.isShuffleChunk) { remainingBlocks -= blockId results.put( - IgnoreFetchResult(block, address, infoMap(blockId)._1, remainingBlocks.isEmpty)) + FallbackOnMergedFailureFetchResult(block, address, infoMap(blockId)._1, remainingBlocks.isEmpty)) } else { results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) } @@ -893,7 +893,7 @@ final class ShuffleBlockFetcherIterator( 
            defReqQueue.enqueue(request)
             result = null
 
-      case IgnoreFetchResult(blockId, address, size, isNetworkReqDone) =>
+      case FallbackOnMergedFailureFetchResult(blockId, address, size, isNetworkReqDone) =>
         // We get this result in 3 cases:
         // 1. Failure to fetch the data of a remote merged shuffle chunk. In this case, the
         //    blockId is a ShuffleBlockChunkId.
@@ -1402,7 +1402,7 @@ object ShuffleBlockFetcherIterator {
    * @param isNetworkReqDone Is this the last network request for this host in this fetch
    *                         request. Used to update reqsInFlight.
    */
-  private[storage] case class IgnoreFetchResult(blockId: BlockId,
+  private[storage] case class FallbackOnMergedFailureFetchResult(blockId: BlockId,
       address: BlockManagerId,
       size: Long,
       isNetworkReqDone: Boolean) extends FetchResult

From 926f0b9c87e89372f44b13115b95d2d4977fe9d1 Mon Sep 17 00:00:00 2001
From: Chandni Singh
Date: Thu, 10 Jun 2021 23:35:47 -0700
Subject: [PATCH 19/27] Added clarifying comments about chunksMetaMap concurrency

---
 .../org/apache/spark/storage/PushBasedFetchHelper.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
index 6268b8251574..9b3051f28670 100644
--- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -49,7 +49,10 @@ private class PushBasedFetchHelper(
     SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
     blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
 
-  /** A map for storing merged block shuffle chunk bitmap */
+  /**
+   * A map for storing merged block shuffle chunk bitmaps. This is a concurrent hashmap because
+   * it can be modified by both the task thread and the netty thread.
+   */
   private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]()
 
   /**
@@ -222,7 +225,8 @@
   }
 
   /**
-   * Fetch a single local merged block generated.
+   * Fetch a single local merged block. This can be executed by the task thread as well as the
+   * netty thread.
* @param blockId ShuffleBlockId to be fetched * @param localDirs Local directories where the merged shuffle files are stored * @param blockManagerId BlockManagerId From 94d7c5e1461f6324834c4e2a00fb9447530cc895 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Fri, 11 Jun 2021 07:10:13 -0700 Subject: [PATCH 20/27] Fixed the line length --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index de74dcfd21a8..3d42bbc8bd08 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -333,8 +333,8 @@ final class ShuffleBlockFetcherIterator( val block = BlockId(blockId) if (block.isShuffleChunk) { remainingBlocks -= blockId - results.put( - FallbackOnMergedFailureFetchResult(block, address, infoMap(blockId)._1, remainingBlocks.isEmpty)) + results.put(FallbackOnMergedFailureFetchResult( + block, address, infoMap(blockId)._1, remainingBlocks.isEmpty)) } else { results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) } From 67bd821632824c66589d8528f035045739ec3ba9 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Wed, 16 Jun 2021 15:06:29 -0700 Subject: [PATCH 21/27] Addressed yi.wu's comments --- .../network/client/BaseResponseCallback.java | 31 ++ .../network/client/RpcResponseCallback.java | 5 +- .../spark/network/client/TransportClient.java | 29 +- .../client/TransportResponseHandler.java | 27 +- .../spark/network/crypto/AuthRpcHandler.java | 5 + .../protocol/MergedBlockMetaSuccess.java | 92 ++++++ .../network/protocol/MessageDecoder.java | 6 + .../server/AbstractAuthRpcHandler.java | 5 + .../server/TransportRequestHandler.java | 26 ++ .../network/TransportRequestHandlerSuite.java | 55 ++++ .../TransportResponseHandlerSuite.java | 39 +++ .../protocol/MergedBlockMetaSuccessSuite.java | 101 +++++++ .../shuffle/ExternalShuffleBlockResolver.java | 4 +- .../protocol/AbstractFetchShuffleBlocks.java | 88 ++++++ .../protocol/BlockTransferMessage.java | 4 +- .../shuffle/protocol/FetchShuffleBlocks.java | 45 ++- .../shuffle/ExternalBlockHandlerSuite.java | 112 +++++++ .../FetchShuffleBlockChunksSuite.java | 42 +++ .../protocol/FetchShuffleBlocksSuite.java | 42 +++ .../spark/storage/PushBasedFetchHelper.scala | 219 +++++++------- .../storage/ShuffleBlockFetcherIterator.scala | 276 ++++++++++-------- .../ShuffleBlockFetcherIteratorSuite.scala | 171 ++++++----- .../yarn/YarnShuffleServiceMetricsSuite.scala | 3 +- 23 files changed, 1075 insertions(+), 352 deletions(-) create mode 100644 common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java create mode 100644 common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java create mode 100644 common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java create mode 100644 common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java create mode 100644 common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
new file mode 100644
index 000000000000..d9b7fb2b3bb8
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/BaseResponseCallback.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * A basic callback. This is extended by {@link RpcResponseCallback} and
+ * {@link MergedBlockMetaResponseCallback} so that both RpcRequests and MergedBlockMetaRequests
+ * can be handled in {@link TransportResponseHandler} in a similar way.
+ *
+ * @since 3.2.0
+ */
+public interface BaseResponseCallback {
+
+  /** Exception either propagated from server or raised on client side. */
+  void onFailure(Throwable e);
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
index 6afc63f71bb3..a3b8cb1d90a2 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -23,7 +23,7 @@
  * Callback for the result of a single RPC. This will be invoked once with either success or
  * failure.
  */
-public interface RpcResponseCallback {
+public interface RpcResponseCallback extends BaseResponseCallback {
   /**
    * Successful serialized result from server.
    *
@@ -31,7 +31,4 @@ public interface RpcResponseCallback {
    * Please copy the content of `response` if you want to use it after `onSuccess` returns.
    */
   void onSuccess(ByteBuffer response);
-
-  /** Exception either propagated from server or raised on client side. */
-  void onFailure(Throwable e);
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
index eb2882074d7c..a50c04cf802a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -200,6 +200,31 @@ public long sendRpc(ByteBuffer message, RpcResponseCallback callback) {
     return requestId;
   }
 
+  /**
+   * Sends a MergedBlockMetaRequest message to the server. The response of this message is
+   * either a {@link MergedBlockMetaSuccess} or {@link RpcFailure}.
+   *
+   * @param appId applicationId.
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param callback callback to handle the reply.
+   */
+  public void sendMergedBlockMetaReq(
+      String appId,
+      int shuffleId,
+      int reduceId,
+      MergedBlockMetaResponseCallback callback) {
+    long requestId = requestId();
+    if (logger.isTraceEnabled()) {
+      logger.trace(
+        "Sending RPC {} to fetch merged block meta to {}", requestId, getRemoteAddress(channel));
+    }
+    handler.addRpcRequest(requestId, callback);
+    RpcChannelListener listener = new RpcChannelListener(requestId, callback);
+    channel.writeAndFlush(
+      new MergedBlockMetaRequest(requestId, appId, shuffleId, reduceId)).addListener(listener);
+  }
+
   /**
    * Send data to the remote end as a stream. This differs from stream() in that this is a request
    * to *send* data to the remote end, not to receive it from the remote.
@@ -349,9 +374,9 @@ void handleFailure(String errorMsg, Throwable cause) throws Exception {}
 
   private class RpcChannelListener extends StdChannelListener {
     final long rpcRequestId;
-    final RpcResponseCallback callback;
+    final BaseResponseCallback callback;
 
-    RpcChannelListener(long rpcRequestId, RpcResponseCallback callback) {
+    RpcChannelListener(long rpcRequestId, BaseResponseCallback callback) {
       super("RPC " + rpcRequestId);
       this.rpcRequestId = rpcRequestId;
       this.callback = callback;
diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 3aac2d2441d2..576c08858d6c 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -33,6 +33,7 @@
 
 import org.apache.spark.network.protocol.ChunkFetchFailure;
 import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.MergedBlockMetaSuccess;
 import org.apache.spark.network.protocol.ResponseMessage;
 import org.apache.spark.network.protocol.RpcFailure;
 import org.apache.spark.network.protocol.RpcResponse;
@@ -56,7 +57,7 @@ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
 
   private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
 
-  private final Map<Long, RpcResponseCallback> outstandingRpcs;
+  private final Map<Long, BaseResponseCallback> outstandingRpcs;
 
   private final Queue<Pair<String, StreamCallback>> streamCallbacks;
   private volatile boolean streamActive;
@@ -81,7 +82,7 @@ public void removeFetchRequest(StreamChunkId streamChunkId) {
     outstandingFetches.remove(streamChunkId);
   }
 
-  public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+  public void addRpcRequest(long requestId, BaseResponseCallback callback) {
     updateTimeOfLastRequest();
     outstandingRpcs.put(requestId, callback);
   }
@@ -112,7 +113,7 @@ private void failOutstandingRequests(Throwable cause) {
         logger.warn("ChunkReceivedCallback.onFailure throws exception", e);
       }
     }
-    for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+    for (Map.Entry<Long, BaseResponseCallback> entry : outstandingRpcs.entrySet()) {
       try {
         entry.getValue().onFailure(cause);
       } catch (Exception e) {
@@ -184,7 +185,7 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof RpcResponse) {
       RpcResponse resp = (RpcResponse) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      RpcResponseCallback listener = (RpcResponseCallback) outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.body().size());
@@ -199,7 +200,7 @@ public void handle(ResponseMessage message) throws Exception {
       }
     } else if (message instanceof RpcFailure) {
       RpcFailure resp = (RpcFailure) message;
-      RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+      BaseResponseCallback listener = outstandingRpcs.get(resp.requestId);
       if (listener == null) {
         logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
           resp.requestId, getRemoteAddress(channel), resp.errorString);
@@ -207,6 +208,22 @@ public void handle(ResponseMessage message) throws Exception {
         outstandingRpcs.remove(resp.requestId);
         listener.onFailure(new RuntimeException(resp.errorString));
       }
+    } else if (message instanceof MergedBlockMetaSuccess) {
+      MergedBlockMetaSuccess resp = (MergedBlockMetaSuccess) message;
+      try {
+        MergedBlockMetaResponseCallback listener =
+          (MergedBlockMetaResponseCallback) outstandingRpcs.get(resp.requestId);
+        if (listener == null) {
+          logger.warn(
+            "Ignoring response for MergedBlockMetaRequest {} from {} ({} bytes) since it is not"
+              + " outstanding", resp.requestId, getRemoteAddress(channel), resp.body().size());
+        } else {
+          outstandingRpcs.remove(resp.requestId);
+          listener.onSuccess(resp.getNumChunks(), resp.body());
+        }
+      } finally {
+        resp.body().release();
+      }
     } else if (message instanceof StreamResponse) {
       StreamResponse resp = (StreamResponse) message;
       Pair<String, StreamCallback> entry = streamCallbacks.poll();
diff --git a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
index dd31c955350f..8f0a40c38021 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/crypto/AuthRpcHandler.java
@@ -138,4 +138,9 @@ protected boolean doAuthChallenge(
     LOG.debug("Authorization successful for client {}.", channel.remoteAddress());
     return true;
   }
+
+  @Override
+  public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
+    return saslHandler.getMergedBlockMetaReqHandler();
+  }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java
new file mode 100644
index 000000000000..d2edaf4532e1
--- /dev/null
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaSuccess.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to a {@link MergedBlockMetaRequest} request.
+ * Note that the server-side encoding of this message does NOT include the buffer itself.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaSuccess extends AbstractResponseMessage {
+  public final long requestId;
+  public final int numChunks;
+
+  public MergedBlockMetaSuccess(
+      long requestId,
+      int numChunks,
+      ManagedBuffer chunkBitmapsBuffer) {
+    super(chunkBitmapsBuffer, true);
+    this.requestId = requestId;
+    this.numChunks = numChunks;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaSuccess;
+  }
+
+  @Override
+  public int hashCode() {
+    return Objects.hashCode(requestId, numChunks);
+  }
+
+  @Override
+  public String toString() {
+    return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
+      .append("requestId", requestId).append("numChunks", numChunks).toString();
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + 4;
+  }
+
+  /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+  @Override
+  public void encode(ByteBuf buf) {
+    buf.writeLong(requestId);
+    buf.writeInt(numChunks);
+  }
+
+  public int getNumChunks() {
+    return numChunks;
+  }
+
+  /** Decoding uses the given ByteBuf as our data, and will retain() it. */
+  public static MergedBlockMetaSuccess decode(ByteBuf buf) {
+    long requestId = buf.readLong();
+    int numChunks = buf.readInt();
+    buf.retain();
+    NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
+    return new MergedBlockMetaSuccess(requestId, numChunks, managedBuf);
+  }
+
+  @Override
+  public ResponseMessage createFailureResponse(String error) {
+    return new RpcFailure(requestId, error);
+  }
+}
diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
index bf80aed0afe1..98f7f612a486 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -83,6 +83,12 @@ private Message decode(Message.Type msgType, ByteBuf in) {
       case UploadStream:
         return UploadStream.decode(in);
 
+      case MergedBlockMetaRequest:
+        return MergedBlockMetaRequest.decode(in);
+
+      case MergedBlockMetaSuccess:
+        return MergedBlockMetaSuccess.decode(in);
+
       default:
         throw new IllegalArgumentException("Unexpected message type: " + msgType);
     }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
index 92eb88628344..95fde677624f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/AbstractAuthRpcHandler.java
@@ -104,4 +104,9 @@ public void exceptionCaught(Throwable cause, TransportClient client) {
   public boolean isAuthenticated() {
     return isAuthenticated;
   }
+
+  @Override
+  public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() {
+    return delegate.getMergedBlockMetaReqHandler();
+  }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index 4a30f8de0782..ab2deac20fcd 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -113,6 +113,8 @@ public void handle(RequestMessage request) throws Exception {
       processStreamRequest((StreamRequest) request);
     } else if (request instanceof UploadStream) {
       processStreamUpload((UploadStream) request);
+    } else if (request instanceof MergedBlockMetaRequest) {
+      processMergedBlockMetaRequest((MergedBlockMetaRequest) request);
     } else {
       throw new IllegalArgumentException("Unknown request type: " + request);
     }
@@ -260,6 +262,30 @@ private void processOneWayMessage(OneWayMessage req) {
     }
   }
 
+  private void processMergedBlockMetaRequest(final MergedBlockMetaRequest req) {
+    try {
+      rpcHandler.getMergedBlockMetaReqHandler().receiveMergeBlockMetaReq(reverseClient, req,
+        new MergedBlockMetaResponseCallback() {
+
+          @Override
+          public void onSuccess(int numChunks, ManagedBuffer buffer) {
+            logger.trace("Sending meta for request {} numChunks {}", req, numChunks);
+            respond(new MergedBlockMetaSuccess(req.requestId, numChunks, buffer));
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            logger.trace("Failed to send meta for {}", req);
+            respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+          }
+        });
+    } catch (Exception e) {
+      logger.error("Error while invoking receiveMergeBlockMetaReq() for appId {} shuffleId {} "
+        + "reduceId {}", req.appId, req.shuffleId, req.reduceId, e);
+      respond(new RpcFailure(req.requestId, Throwables.getStackTraceAsString(e)));
+    }
+  }
+
   /**
    * Responds to a single message with some Encodable object. If a failure occurs while sending,
    * it will be logged and the channel closed.
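
A minimal sketch of how the server-side pieces above fit together (illustration only, not part
of this patch): an RpcHandler exposes a MergedBlockMetaReqHandler, which looks up the meta for
(appId, shuffleId, reduceId) and answers through the callback; an exception thrown from the
handler, or a call to onFailure, reaches the client as an RpcFailure for the same requestId.
The mergeManager reference and its getChunksBitmapBuffer() accessor are assumed here:

    // Hypothetical wiring; the lambda shape mirrors the one used in
    // TransportRequestHandlerSuite below.
    RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, req, callback) -> {
      try {
        // Resolve the merged-block meta for the requested shuffle partition.
        MergedBlockMeta meta =
          mergeManager.getMergedBlockMeta(req.appId, req.shuffleId, req.reduceId);
        // numChunks rides in the fixed-width header; the chunk bitmaps travel as the body.
        callback.onSuccess(meta.getNumChunks(), meta.getChunksBitmapBuffer());
      } catch (RuntimeException e) {
        // Surfaces to the remote client as an RpcFailure.
        callback.onFailure(e);
      }
    };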
diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 0a6447176237..b3befb8baf2d 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.network; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; @@ -24,16 +25,19 @@ import org.junit.Assert; import org.junit.Test; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; import org.apache.spark.network.protocol.*; import org.apache.spark.network.server.NoOpRpcHandler; import org.apache.spark.network.server.OneForOneStreamManager; import org.apache.spark.network.server.RpcHandler; +import org.apache.spark.network.server.StreamManager; import org.apache.spark.network.server.TransportRequestHandler; public class TransportRequestHandlerSuite { @@ -109,4 +113,55 @@ public void handleStreamRequest() throws Exception { streamManager.connectionTerminated(channel); Assert.assertEquals(0, streamManager.numStreamStates()); } + + @Test + public void handleMergedBlockMetaRequest() throws Exception { + RpcHandler.MergedBlockMetaReqHandler metaHandler = (client, request, callback) -> { + if (request.shuffleId != -1 && request.reduceId != -1) { + callback.onSuccess(2, mock(ManagedBuffer.class)); + } else { + callback.onFailure(new RuntimeException("empty block")); + } + }; + RpcHandler rpcHandler = new RpcHandler() { + @Override + public void receive( + TransportClient client, + ByteBuffer message, + RpcResponseCallback callback) {} + + @Override + public StreamManager getStreamManager() { + return null; + } + + @Override + public MergedBlockMetaReqHandler getMergedBlockMetaReqHandler() { + return metaHandler; + } + }; + Channel channel = mock(Channel.class); + List> responseAndPromisePairs = new ArrayList<>(); + when(channel.writeAndFlush(any())).thenAnswer(invocationOnMock0 -> { + Object response = invocationOnMock0.getArguments()[0]; + ExtendedChannelPromise channelFuture = new ExtendedChannelPromise(channel); + responseAndPromisePairs.add(ImmutablePair.of(response, channelFuture)); + return channelFuture; + }); + + TransportClient reverseClient = mock(TransportClient.class); + TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, + rpcHandler, 2L, null); + MergedBlockMetaRequest validMetaReq = new MergedBlockMetaRequest(19, "app1", 0, 0); + requestHandler.handle(validMetaReq); + assertEquals(1, responseAndPromisePairs.size()); + assertTrue(responseAndPromisePairs.get(0).getLeft() instanceof MergedBlockMetaSuccess); + assertEquals(2, + ((MergedBlockMetaSuccess) (responseAndPromisePairs.get(0).getLeft())).getNumChunks()); + + MergedBlockMetaRequest invalidMetaReq = new MergedBlockMetaRequest(21, "app1", -1, 1); + requestHandler.handle(invalidMetaReq); + assertEquals(2, responseAndPromisePairs.size()); + assertTrue(responseAndPromisePairs.get(1).getLeft() instanceof RpcFailure); + } } diff --git 
a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java index b4032c4c3f03..4de13f951d49 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportResponseHandlerSuite.java @@ -23,17 +23,20 @@ import io.netty.channel.Channel; import io.netty.channel.local.LocalChannel; import org.junit.Test; +import org.mockito.ArgumentCaptor; import static org.junit.Assert.assertEquals; import static org.mockito.Mockito.*; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.StreamCallback; import org.apache.spark.network.client.TransportResponseHandler; import org.apache.spark.network.protocol.ChunkFetchFailure; import org.apache.spark.network.protocol.ChunkFetchSuccess; +import org.apache.spark.network.protocol.MergedBlockMetaSuccess; import org.apache.spark.network.protocol.RpcFailure; import org.apache.spark.network.protocol.RpcResponse; import org.apache.spark.network.protocol.StreamChunkId; @@ -167,4 +170,40 @@ public void failOutstandingStreamCallbackOnException() throws Exception { verify(cb).onFailure(eq("stream-1"), isA(IOException.class)); } + + @Test + public void handleSuccessfulMergedBlockMeta() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.addRpcRequest(13, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored. + handler.handle(new MergedBlockMetaSuccess(22, 2, + new NioManagedBuffer(ByteBuffer.allocate(7)))); + assertEquals(1, handler.numOutstandingRequests()); + + ByteBuffer resp = ByteBuffer.allocate(10); + handler.handle(new MergedBlockMetaSuccess(13, 2, new NioManagedBuffer(resp))); + ArgumentCaptor bufferCaptor = ArgumentCaptor.forClass(NioManagedBuffer.class); + verify(callback, times(1)).onSuccess(eq(2), bufferCaptor.capture()); + assertEquals(resp, bufferCaptor.getValue().nioByteBuffer()); + assertEquals(0, handler.numOutstandingRequests()); + } + + @Test + public void handleFailedMergedBlockMeta() throws Exception { + TransportResponseHandler handler = new TransportResponseHandler(new LocalChannel()); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.addRpcRequest(51, callback); + assertEquals(1, handler.numOutstandingRequests()); + + // This response should be ignored. 
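+    // (Request id 6 does not match the outstanding request registered above with id 51.)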
+ handler.handle(new RpcFailure(6, "failed")); + assertEquals(1, handler.numOutstandingRequests()); + + handler.handle(new RpcFailure(51, "failed")); + verify(callback, times(1)).onFailure(any()); + assertEquals(0, handler.numOutstandingRequests()); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java new file mode 100644 index 000000000000..f4a055188c86 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/MergedBlockMetaSuccessSuite.java @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import java.io.DataOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.nio.file.Files; +import java.util.List; + +import com.google.common.collect.Lists; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.Assert; +import org.junit.Test; +import org.roaringbitmap.RoaringBitmap; + +import static org.mockito.Mockito.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.util.ByteArrayWritableChannel; +import org.apache.spark.network.util.TransportConf; + +/** + * Test for {@link MergedBlockMetaSuccess}. 
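+ * Round-trips a message through {@link MessageEncoder} and {@link MessageDecoder} and checks
+ * that the request id, the chunk count, and the chunk bitmaps survive the round trip.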
+ */ +public class MergedBlockMetaSuccessSuite { + + @Test + public void testMergedBlocksMetaEncodeDecode() throws Exception { + File chunkMetaFile = new File("target/mergedBlockMetaTest"); + Files.deleteIfExists(chunkMetaFile.toPath()); + RoaringBitmap chunk1 = new RoaringBitmap(); + chunk1.add(1); + chunk1.add(3); + RoaringBitmap chunk2 = new RoaringBitmap(); + chunk2.add(2); + chunk2.add(4); + RoaringBitmap[] expectedChunks = new RoaringBitmap[]{chunk1, chunk2}; + try (DataOutputStream metaOutput = new DataOutputStream(new FileOutputStream(chunkMetaFile))) { + for (int i = 0; i < expectedChunks.length; i++) { + expectedChunks[i].serialize(metaOutput); + } + } + TransportConf conf = mock(TransportConf.class); + when(conf.lazyFileDescriptor()).thenReturn(false); + long requestId = 1L; + MergedBlockMetaSuccess expectedMeta = new MergedBlockMetaSuccess(requestId, 2, + new FileSegmentManagedBuffer(conf, chunkMetaFile, 0, chunkMetaFile.length())); + + List out = Lists.newArrayList(); + ChannelHandlerContext context = mock(ChannelHandlerContext.class); + when(context.alloc()).thenReturn(ByteBufAllocator.DEFAULT); + + MessageEncoder.INSTANCE.encode(context, expectedMeta, out); + Assert.assertEquals(1, out.size()); + MessageWithHeader msgWithHeader = (MessageWithHeader) out.remove(0); + + ByteArrayWritableChannel writableChannel = + new ByteArrayWritableChannel((int) msgWithHeader.count()); + while (msgWithHeader.transfered() < msgWithHeader.count()) { + msgWithHeader.transferTo(writableChannel, msgWithHeader.transfered()); + } + ByteBuf messageBuf = Unpooled.wrappedBuffer(writableChannel.getData()); + messageBuf.readLong(); // frame length + MessageDecoder.INSTANCE.decode(mock(ChannelHandlerContext.class), messageBuf, out); + Assert.assertEquals(1, out.size()); + MergedBlockMetaSuccess decoded = (MergedBlockMetaSuccess) out.get(0); + Assert.assertEquals("merged block", expectedMeta.requestId, decoded.requestId); + Assert.assertEquals("num chunks", expectedMeta.getNumChunks(), decoded.getNumChunks()); + + ByteBuf responseBuf = Unpooled.wrappedBuffer(decoded.body().nioByteBuffer()); + RoaringBitmap[] responseBitmaps = new RoaringBitmap[expectedMeta.getNumChunks()]; + for (int i = 0; i < expectedMeta.getNumChunks(); i++) { + responseBitmaps[i] = Encoders.Bitmaps.decode(responseBuf); + } + Assert.assertEquals( + "num of roaring bitmaps", expectedMeta.getNumChunks(), responseBitmaps.length); + for (int i = 0; i < expectedMeta.getNumChunks(); i++) { + Assert.assertEquals("chunk bitmap " + i, expectedChunks[i], responseBitmaps[i]); + } + Files.delete(chunkMetaFile.toPath()); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index a095bf272341..493edd2b3462 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -361,8 +361,8 @@ public int removeBlocks(String appId, String execId, String[] blockIds) { return numRemovedBlocks; } - public Map getLocalDirs(String appId, String[] execIds) { - return Arrays.stream(execIds) + public Map getLocalDirs(String appId, Set execIds) { + return execIds.stream() .map(exec -> { ExecutorShuffleInfo info = executors.get(new AppExecId(appId, exec)); if (info == null) { diff --git 
a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java new file mode 100644 index 000000000000..0fca27cf26df --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/AbstractFetchShuffleBlocks.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; + +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; +import org.apache.spark.network.protocol.Encoders; + +/** + * Base class for fetch shuffle blocks and chunks. + * + * @since 3.2.0 + */ +public abstract class AbstractFetchShuffleBlocks extends BlockTransferMessage { + public final String appId; + public final String execId; + public final int shuffleId; + + protected AbstractFetchShuffleBlocks( + String appId, + String execId, + int shuffleId) { + this.appId = appId; + this.execId = execId; + this.shuffleId = shuffleId; + } + + public ToStringBuilder toStringHelper() { + return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) + .append("appId", appId) + .append("execId", execId) + .append("shuffleId", shuffleId); + } + + /** + * Returns number of blocks in the request. 
+ */ + public abstract int getNumBlocks(); + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AbstractFetchShuffleBlocks that = (AbstractFetchShuffleBlocks) o; + return shuffleId == that.shuffleId + && Objects.equal(appId, that.appId) && Objects.equal(execId, that.execId); + } + + @Override + public int hashCode() { + int result = appId.hashCode(); + result = 31 * result + execId.hashCode(); + result = 31 * result + shuffleId; + return result; + } + + @Override + public int encodedLength() { + return Encoders.Strings.encodedLength(appId) + + Encoders.Strings.encodedLength(execId) + + 4; /* encoded length of shuffleId */ + } + + @Override + public void encode(ByteBuf buf) { + Encoders.Strings.encode(buf, appId); + Encoders.Strings.encode(buf, execId); + buf.writeInt(shuffleId); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java index 7f5058124988..a55a6cf7ed93 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/BlockTransferMessage.java @@ -48,7 +48,8 @@ public enum Type { OPEN_BLOCKS(0), UPLOAD_BLOCK(1), REGISTER_EXECUTOR(2), STREAM_HANDLE(3), REGISTER_DRIVER(4), HEARTBEAT(5), UPLOAD_BLOCK_STREAM(6), REMOVE_BLOCKS(7), BLOCKS_REMOVED(8), FETCH_SHUFFLE_BLOCKS(9), GET_LOCAL_DIRS_FOR_EXECUTORS(10), LOCAL_DIRS_FOR_EXECUTORS(11), - PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14); + PUSH_BLOCK_STREAM(12), FINALIZE_SHUFFLE_MERGE(13), MERGE_STATUSES(14), + FETCH_SHUFFLE_BLOCK_CHUNKS(15); private final byte id; @@ -82,6 +83,7 @@ public static BlockTransferMessage fromByteBuffer(ByteBuffer msg) { case 12: return PushBlockStream.decode(buf); case 13: return FinalizeShuffleMerge.decode(buf); case 14: return MergeStatuses.decode(buf); + case 15: return FetchShuffleBlockChunks.decode(buf); default: throw new IllegalArgumentException("Unknown message type: " + type); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java index 98057d58f7ab..68550a2fba86 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocks.java @@ -20,8 +20,6 @@ import java.util.Arrays; import io.netty.buffer.ByteBuf; -import org.apache.commons.lang3.builder.ToStringBuilder; -import org.apache.commons.lang3.builder.ToStringStyle; import org.apache.spark.network.protocol.Encoders; @@ -29,10 +27,7 @@ import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; /** Request to read a set of blocks. Returns {@link StreamHandle}. */ -public class FetchShuffleBlocks extends BlockTransferMessage { - public final String appId; - public final String execId; - public final int shuffleId; +public class FetchShuffleBlocks extends AbstractFetchShuffleBlocks { // The length of mapIds must equal to reduceIds.size(), for the i-th mapId in mapIds, // it corresponds to the i-th int[] in reduceIds, which contains all reduce id for this map id. 
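+  // For example, mapIds = {5, 7} with reduceIds = {{1, 2}, {3}} requests the blocks for
+  // (map 5, reduce 1), (map 5, reduce 2) and (map 7, reduce 3) of this shuffle.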
public final long[] mapIds; @@ -50,9 +45,7 @@ public FetchShuffleBlocks( long[] mapIds, int[][] reduceIds, boolean batchFetchEnabled) { - this.appId = appId; - this.execId = execId; - this.shuffleId = shuffleId; + super(appId, execId, shuffleId); this.mapIds = mapIds; this.reduceIds = reduceIds; assert(mapIds.length == reduceIds.length); @@ -69,10 +62,7 @@ public FetchShuffleBlocks( @Override public String toString() { - return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) - .append("appId", appId) - .append("execId", execId) - .append("shuffleId", shuffleId) + return toStringHelper() .append("mapIds", Arrays.toString(mapIds)) .append("reduceIds", Arrays.deepToString(reduceIds)) .append("batchFetchEnabled", batchFetchEnabled) @@ -85,35 +75,40 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; FetchShuffleBlocks that = (FetchShuffleBlocks) o; - - if (shuffleId != that.shuffleId) return false; + if (!super.equals(that)) return false; if (batchFetchEnabled != that.batchFetchEnabled) return false; - if (!appId.equals(that.appId)) return false; - if (!execId.equals(that.execId)) return false; if (!Arrays.equals(mapIds, that.mapIds)) return false; return Arrays.deepEquals(reduceIds, that.reduceIds); } @Override public int hashCode() { - int result = appId.hashCode(); - result = 31 * result + execId.hashCode(); - result = 31 * result + shuffleId; + int result = super.hashCode(); result = 31 * result + Arrays.hashCode(mapIds); result = 31 * result + Arrays.deepHashCode(reduceIds); result = 31 * result + (batchFetchEnabled ? 1 : 0); return result; } + @Override + public int getNumBlocks() { + if (batchFetchEnabled) { + return mapIds.length; + } + int numBlocks = 0; + for (int[] ids : reduceIds) { + numBlocks += ids.length; + } + return numBlocks; + } + @Override public int encodedLength() { int encodedLengthOfReduceIds = 0; for (int[] ids: reduceIds) { encodedLengthOfReduceIds += Encoders.IntArrays.encodedLength(ids); } - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(execId) - + 4 /* encoded length of shuffleId */ + return super.encodedLength() + Encoders.LongArrays.encodedLength(mapIds) + 4 /* encoded length of reduceIds.size() */ + encodedLengthOfReduceIds @@ -122,9 +117,7 @@ public int encodedLength() { @Override public void encode(ByteBuf buf) { - Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, execId); - buf.writeInt(shuffleId); + super.encode(buf); Encoders.LongArrays.encode(buf, mapIds); buf.writeInt(reduceIds.length); for (int[] ids: reduceIds) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 531657bde481..dc41e957f0fc 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -36,13 +36,16 @@ import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; +import org.apache.spark.network.client.MergedBlockMetaResponseCallback; import org.apache.spark.network.client.RpcResponseCallback; import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.protocol.MergedBlockMetaRequest; import org.apache.spark.network.server.OneForOneStreamManager; import 
org.apache.spark.network.server.RpcHandler; import org.apache.spark.network.shuffle.protocol.BlockTransferMessage; import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FetchShuffleBlocks; +import org.apache.spark.network.shuffle.protocol.FetchShuffleBlockChunks; import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; import org.apache.spark.network.shuffle.protocol.MergeStatuses; import org.apache.spark.network.shuffle.protocol.OpenBlocks; @@ -263,4 +266,113 @@ public void testFinalizeShuffleMerge() throws IOException { .get("finalizeShuffleMergeLatencyMillis"); assertEquals(1, finalizeShuffleMergeLatencyMillis.getCount()); } + + @Test + public void testFetchMergedBlocksMeta() { + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 0)).thenReturn( + new MergedBlockMeta(1, mock(ManagedBuffer.class))); + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 1)).thenReturn( + new MergedBlockMeta(3, mock(ManagedBuffer.class))); + when(mergedShuffleManager.getMergedBlockMeta("app0", 0, 2)).thenReturn( + new MergedBlockMeta(5, mock(ManagedBuffer.class))); + + int[] expectedCount = new int[]{1, 3, 5}; + String appId = "app0"; + long requestId = 0L; + for (int reduceId = 0; reduceId < 3; reduceId++) { + MergedBlockMetaRequest req = new MergedBlockMetaRequest(requestId++, appId, 0, reduceId); + MergedBlockMetaResponseCallback callback = mock(MergedBlockMetaResponseCallback.class); + handler.getMergedBlockMetaReqHandler() + .receiveMergeBlockMetaReq(client, req, callback); + verify(mergedShuffleManager, times(1)).getMergedBlockMeta("app0", 0, reduceId); + + ArgumentCaptor<Integer> numChunksResponse = ArgumentCaptor.forClass(Integer.class); + ArgumentCaptor<ManagedBuffer> chunkBitmapResponse = + ArgumentCaptor.forClass(ManagedBuffer.class); + verify(callback, times(1)).onSuccess(numChunksResponse.capture(), + chunkBitmapResponse.capture()); + assertEquals("num chunks in merged block " + reduceId, expectedCount[reduceId], + numChunksResponse.getValue().intValue()); + assertNotNull("chunks bitmap buffer " + reduceId, chunkBitmapResponse.getValue()); + } + } + + @Test + public void testOpenBlocksWithShuffleChunks() { + verifyBlockChunkFetches(true); + } + + @Test + public void testFetchShuffleChunks() { + verifyBlockChunkFetches(false); + } + + private void verifyBlockChunkFetches(boolean useOpenBlocks) { + RpcResponseCallback callback = mock(RpcResponseCallback.class); + ByteBuffer buffer; + if (useOpenBlocks) { + OpenBlocks openBlocks = + new OpenBlocks("app0", "exec1", + new String[] {"shuffleChunk_0_0_0", "shuffleChunk_0_0_1", "shuffleChunk_0_1_0", + "shuffleChunk_0_1_1"}); + buffer = openBlocks.toByteBuffer(); + } else { + FetchShuffleBlockChunks fetchChunks = new FetchShuffleBlockChunks( + "app0", "exec1", 0, new int[] {0, 1}, new int[][] {{0, 1}, {0, 1}}); + buffer = fetchChunks.toByteBuffer(); + } + ManagedBuffer[][] buffers = new ManagedBuffer[][] { + { + new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + }, + { + new NioManagedBuffer(ByteBuffer.wrap(new byte[5])), + new NioManagedBuffer(ByteBuffer.wrap(new byte[7])) + } + }; + for (int reduceId = 0; reduceId < 2; reduceId++) { + for (int chunkId = 0; chunkId < 2; chunkId++) { + when(mergedShuffleManager.getMergedBlockData( + "app0", 0, reduceId, chunkId)).thenReturn(buffers[reduceId][chunkId]); + } + } + handler.receive(client, buffer, callback); + ArgumentCaptor<ByteBuffer> response = ArgumentCaptor.forClass(ByteBuffer.class); + 
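+ // Both request forms should yield a single StreamHandle response covering all four
+ // shuffle chunks registered with the stream manager.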
verify(callback, times(1)).onSuccess(response.capture()); + verify(callback, never()).onFailure(any()); + StreamHandle handle = + (StreamHandle) BlockTransferMessage.Decoder.fromByteBuffer(response.getValue()); + assertEquals(4, handle.numChunks); + + @SuppressWarnings("unchecked") + ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>) + (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class); + verify(streamManager, times(1)).registerStream(any(), stream.capture(), any()); + Iterator<ManagedBuffer> bufferIter = stream.getValue(); + for (int reduceId = 0; reduceId < 2; reduceId++) { + for (int chunkId = 0; chunkId < 2; chunkId++) { + assertEquals(buffers[reduceId][chunkId], bufferIter.next()); + } + } + assertFalse(bufferIter.hasNext()); + verify(mergedShuffleManager, never()).getMergedBlockMeta(anyString(), anyInt(), anyInt()); + verify(blockResolver, never()).getBlockData( + anyString(), anyString(), anyInt(), anyInt(), anyInt()); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 0); + verify(mergedShuffleManager, times(1)).getMergedBlockData("app0", 0, 0, 1); + + // Verify open block request latency metrics + Timer openBlockRequestLatencyMillis = (Timer) ((ExternalBlockHandler) handler) + .getAllMetrics() + .getMetrics() + .get("openBlockRequestLatencyMillis"); + assertEquals(1, openBlockRequestLatencyMillis.getCount()); + // Verify block transfer metrics + Meter blockTransferRateBytes = (Meter) ((ExternalBlockHandler) handler) + .getAllMetrics() + .getMetrics() + .get("blockTransferRateBytes"); + assertEquals(24, blockTransferRateBytes.getCount()); + } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java new file mode 100644 index 000000000000..91f319ded493 --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunksSuite.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class FetchShuffleBlockChunksSuite { + + @Test + public void testFetchShuffleBlockChunksEncodeDecode() { + FetchShuffleBlockChunks shuffleBlockChunks = + new FetchShuffleBlockChunks("app0", "exec1", 0, new int[] {0}, new int[][] {{0, 1}}); + Assert.assertEquals(2, shuffleBlockChunks.getNumBlocks()); + int len = shuffleBlockChunks.encodedLength(); + Assert.assertEquals(45, len); + ByteBuf buf = Unpooled.buffer(len); + shuffleBlockChunks.encode(buf); + + FetchShuffleBlockChunks decoded = FetchShuffleBlockChunks.decode(buf); + assertEquals(shuffleBlockChunks, decoded); + } +} diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java new file mode 100644 index 000000000000..a1681f58e7ea --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlocksSuite.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.shuffle.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.*; + +public class FetchShuffleBlocksSuite { + + @Test + public void testFetchShuffleBlockEncodeDecode() { + FetchShuffleBlocks fetchShuffleBlocks = + new FetchShuffleBlocks("app0", "exec1", 0, new long[] {0}, new int[][] {{0, 1}}, false); + Assert.assertEquals(2, fetchShuffleBlocks.getNumBlocks()); + int len = fetchShuffleBlocks.encodedLength(); + Assert.assertEquals(50, len); + ByteBuf buf = Unpooled.buffer(len); + fetchShuffleBlocks.encode(buf); + + FetchShuffleBlocks decoded = FetchShuffleBlocks.decode(buf); + assertEquals(fetchShuffleBlocks, decoded); + } +} diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index 9b3051f28670..b8aa5c518c0d 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.TimeUnit import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -28,14 +28,16 @@ import org.roaringbitmap.RoaringBitmap import org.apache.spark.MapOutputTracker import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener} import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER import org.apache.spark.storage.ShuffleBlockFetcherIterator._ /** * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based - * functionality to fetch merged block meta and merged shuffle block chunks. + * functionality to fetch push-merged block meta and shuffle chunks. + * A push-merged block contains multiple shuffle chunks where each shuffle chunk contains multiple + * shuffle blocks that belong to the common reduce partition and were merged by the ESS to that + * chunk. */ private class PushBasedFetchHelper( private val iterator: ShuffleBlockFetcherIterator, @@ -45,67 +47,65 @@ private class PushBasedFetchHelper( private[this] val startTimeNs = System.nanoTime() - private[this] val localShuffleMergerBlockMgrId = BlockManagerId( + private[storage] val localShuffleMergerBlockMgrId = BlockManagerId( SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host, blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo) /** - * A map for storing merged block shuffle chunk bitmap. This is a concurrent hashmap because it - * can be modified by both the task thread and the netty thread. + * A map for storing shuffle chunk bitmap. */ - private[this] val chunksMetaMap = new ConcurrentHashMap[ShuffleBlockChunkId, RoaringBitmap]() + private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]() /** * Returns true if the address is for a push-merged block. */ - def isMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { + def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = { SHUFFLE_MERGER_IDENTIFIER == address.executorId } /** - * Returns true if the address is of a remote merged block. 
+ * Returns true if the address is of a remote push-merged block. false otherwise. */ - def isMergedBlockAddressRemote(address: BlockManagerId): Boolean = { - assert(isMergedShuffleBlockAddress(address)) - address.host != blockManager.blockManagerId.host + def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host } /** - * Returns true if the address if of merged local block. false otherwise. + * Returns true if the address is of a local push-merged block. false otherwise. */ - def isMergedLocal(address: BlockManagerId): Boolean = { - isMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host + def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { + isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host } /** * This is executed by the task thread when the `iterator.next()` is invoked and the iterator * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. * - * @param blockId shuffle block chunk id. + * @param blockId shuffle chunk id. */ - def getNumberOfBlocksInChunk(blockId : ShuffleBlockChunkId): Int = { - chunksMetaMap.get(blockId).getCardinality + def removeChunk(blockId: ShuffleBlockChunkId): Unit = { + chunksMetaMap.remove(blockId) } /** * This is executed by the task thread when the `iterator.next()` is invoked and the iterator - * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]]. + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]]. * - * @param blockId shuffle block chunk id. + * @param blockId shuffle chunk id. */ - def removeChunk(blockId: ShuffleBlockChunkId): Unit = { - chunksMetaMap.remove(blockId) + def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = { + chunksMetaMap(blockId) = chunkMeta } /** * This is executed by the task thread when the `iterator.next()` is invoked and the iterator - * processes a response of type [[ShuffleBlockFetcherIterator.MergedMetaFetchResult]]. + * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]]. * * @param shuffleId shuffle id. * @param reduceId reduce id. - * @param blockSize size of the merged block. - * @param numChunks number of chunks in the merged block. - * @param bitmaps per chunk bitmap, where each bitmap contains all the mapIds that are merged + * @param blockSize size of the push-merged block. + * @param numChunks number of chunks in the push-merged block. + * @param bitmaps chunk bitmaps, where each bitmap contains all the mapIds that were merged * to that chunk. * @return shuffle chunks to fetch. */ @@ -114,7 +114,7 @@ private class PushBasedFetchHelper( reduceId: Int, blockSize: Long, numChunks: Int, - bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { + bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { val approxChunkSize = blockSize / numChunks val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() for (i <- 0 until numChunks) { @@ -131,7 +131,7 @@ private class PushBasedFetchHelper( * push-merged blocks for which it needs to fetch the metadata. * * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch - * metadata of merged blocks. + * metadata of push-merged blocks. 
*/ def sendFetchMergedStatusRequest(req: FetchRequest): Unit = { val sizeMap = req.blocks.map { @@ -142,24 +142,25 @@ private class PushBasedFetchHelper( val address = req.address val mergedBlocksMetaListener = new MergedBlocksMetaListener { override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = { - logInfo(s"Received the meta of merged block for ($shuffleId, $reduceId) " + + logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}") try { - iterator.addToResultsQueue(MergedMetaFetchResult(shuffleId, reduceId, + iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId, sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address)) } catch { case exception: Exception => - logError(s"Failed to parse the meta of merged block for ($shuffleId, $reduceId) " + - s"from ${req.address.host}:${req.address.port}", exception) + logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " + + s"$reduceId) from ${req.address.host}:${req.address.port}", exception) iterator.addToResultsQueue( - MergedMetaFailedFetchResult(shuffleId, reduceId, address)) + PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address)) } } override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = { - logError(s"Failed to get the meta of merged block for ($shuffleId, $reduceId) " + + logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " + s"from ${req.address.host}:${req.address.port}", exception) - iterator.addToResultsQueue(MergedMetaFailedFetchResult(shuffleId, reduceId, address)) + iterator.addToResultsQueue( + PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address)) } } req.blocks.foreach { block => @@ -171,99 +172,89 @@ private class PushBasedFetchHelper( /** * This is executed by the task thread when the iterator is initialized. It fetches all the - * outstanding merged local blocks. - * @param mergedLocalBlocks set of identified merged local blocks. + * outstanding push-merged local blocks. + * @param pushMergedLocalBlocks set of identified push-merged local blocks. */ - def fetchAllMergedLocalBlocks( - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { - if (mergedLocalBlocks.nonEmpty) { - blockManager.hostLocalDirManager.foreach(fetchMergedLocalBlocks(_, mergedLocalBlocks)) + def fetchAllPushMergedLocalBlocks( + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + if (pushMergedLocalBlocks.nonEmpty) { + blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks)) } } /** - * Fetch the merged blocks dirs if they are not in the cache and eventually fetch merged local - * blocks. + * Fetch the dirs of push-merged blocks if they are not in the cache and eventually fetch the + * push-merged local blocks. 
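+ * The dirs are resolved through the host-local dir manager under the SHUFFLE_MERGER_IDENTIFIER
+ * key, either from its cache or via an asynchronous request, as the code below shows.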
*/ - private def fetchMergedLocalBlocks( + private def fetchPushMergedLocalBlocks( hostLocalDirManager: HostLocalDirManager, - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( SHUFFLE_MERGER_IDENTIFIER) if (cachedMergerDirs.isDefined) { - logDebug(s"Fetching local merged blocks with cached executors dir: " + + logDebug(s"Fetching local push-merged blocks with cached executors dir: " + s"${cachedMergerDirs.get.mkString(", ")}") - mergedLocalBlocks.foreach(blockId => - fetchMergedLocalBlock(blockId, cachedMergerDirs.get, localShuffleMergerBlockMgrId)) + pushMergedLocalBlocks.foreach { blockId => + fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get, + localShuffleMergerBlockMgrId) + } } else { - logDebug(s"Asynchronous fetching local merged blocks without cached executors dir") + logDebug(s"Asynchronously fetching local push-merged blocks without cached executors dir") hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) { case Success(dirs) => - mergedLocalBlocks.takeWhile { + pushMergedLocalBlocks.takeWhile { blockId => logDebug(s"Successfully fetched local dirs: " + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}") - fetchMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER), + fetchPushMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER), localShuffleMergerBlockMgrId) } - logDebug(s"Got local merged blocks (without cached executors' dir) in " + + logDebug(s"Got local push-merged blocks (without cached executors' dir) in " + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") case Failure(throwable) => - // If we see an exception with getting the local dirs for local merged blocks, - // we fallback to fetch the original unmerged blocks. We do not report block fetch - // failure. - logWarning(s"Error occurred while getting the local dirs for local merged " + - s"blocks: ${mergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", + // If we see an exception while getting the local dirs for local push-merged blocks, + // we fall back to fetching the original blocks. We do not report block fetch failure. + logWarning(s"Error occurred while getting the local dirs for local push-merged " + + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", throwable) - mergedLocalBlocks.foreach( - blockId => iterator.addToResultsQueue(FallbackOnMergedFailureFetchResult( - blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) - ) + pushMergedLocalBlocks.foreach { + blockId => + iterator.addToResultsQueue(FallbackOnPushMergedFailureResult( + blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false)) + } } } } /** - * Fetch a single local merged block generated. This can also be executed by the task thread as - * well as the netty thread. + * Fetch a single local push-merged block. This can be executed by either the task thread + * or the netty thread. 
* @param blockId ShuffleBlockId to be fetched - * @param localDirs Local directories where the merged shuffle files are stored + * @param localDirs Local directories where the push-merged shuffle files are stored * @param blockManagerId BlockManagerId * @return Boolean represents successful or failed fetch */ - private[this] def fetchMergedLocalBlock( + private[this] def fetchPushMergedLocalBlock( blockId: BlockId, localDirs: Array[String], blockManagerId: BlockManagerId): Boolean = { try { val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) - .readChunkBitmaps() - // Fetch local merged shuffle block data as multiple chunks - val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, localDirs) - // Update total number of blocks to fetch, reflecting the multiple local chunks - iterator.incrementNumBlocksToFetch(bufs.size - 1) - for (chunkId <- bufs.indices) { - val buf = bufs(chunkId) - buf.retain() - val shuffleChunkId = ShuffleBlockChunkId(shuffleBlockId.shuffleId, - shuffleBlockId.reduceId, chunkId) - chunksMetaMap.put(shuffleChunkId, chunksMeta(chunkId)) - iterator.addToResultsQueue( - SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, blockManagerId, buf.size(), buf, - isNetworkReqDone = false)) - } + iterator.addToResultsQueue(PushMergedLocalMetaFetchResult( + shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.getNumChunks, + chunksMeta.readChunkBitmaps(), localDirs)) true } catch { case e: Exception => - // If we see an exception with reading a local merged block, we fallback to - // fetch the original unmerged blocks. We do not report block fetch failure - // and will continue with the remaining local block read. - logWarning(s"Error occurred while fetching local merged block, " + - s"prepare to fetch the original blocks", e) + // If we see an exception while reading a local push-merged meta, we fall back to + // fetching the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning(s"Error occurred while fetching local push-merged meta, " + + s"falling back to fetch the original blocks", e) iterator.addToResultsQueue( - FallbackOnMergedFailureFetchResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) + FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) false } } @@ -272,65 +263,63 @@ private class PushBasedFetchHelper( * This is executed by the task thread when the `iterator.next()` is invoked and the iterator * processes a response of type: * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]] - * 2) [[ShuffleBlockFetcherIterator.FallbackOnMergedFailureFetchResult]] - * 3) [[ShuffleBlockFetcherIterator.MergedMetaFailedFetchResult]] + * 2) [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]] + * 3) [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]] * - * This initiates fetching fallback blocks for a merged block (or a merged block chunk) that + * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that * failed to fetch. * It makes a call to the map output tracker to get the list of original blocks for the - * given merged blocks, split them into remote and local blocks, and process them - * accordingly. + * given push-merged block/shuffle chunk, splits them into remote and local blocks, and + * processes them accordingly. 
+ * It also keeps numBlocksToFetch in the iterator up to date: the failed push-merged block or + * shuffle chunks are subtracted from it, and the number of fallback requests issued for the + * original blocks is added back to it. * The fallback happens when: - * 1. There is an exception while creating shuffle block chunk from local merged shuffle block. + * 1. There is an exception while creating shuffle chunks from a local push-merged shuffle block. * See fetchLocalBlock. - * 2. There is a failure when fetching remote shuffle block chunks. + * 2. There is a failure when fetching remote shuffle chunks. * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk * (local or remote). - * - * @return number of blocks processed */ - def initiateFallbackBlockFetchForMergedBlock( + def initiateFallbackFetchForPushMergedBlock( blockId: BlockId, - address: BlockManagerId): Int = { + address: BlockManagerId): Unit = { assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId]) - logWarning(s"Falling back to fetch the original unmerged blocks for merged block $blockId") + logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId") // Increase the blocks processed since we will process another block in the next iteration of // the while loop in ShuffleBlockFetcherIterator.next(). - var blocksProcessed = 1 val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = blockId match { case shuffleBlockId: ShuffleBlockId => + iterator.incrementNumBlocksToFetch(-1) mapOutputTracker.getMapSizesForMergeResult( shuffleBlockId.shuffleId, shuffleBlockId.reduceId) case _ => val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId] - val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId) - assert(chunkBitmap != null) - // When there is a failure to fetch a remote merged shuffle block chunk, then we try to - // fallback not only for that particular remote shuffle block chunk but also for all the - // pending block chunks that belong to the same host. The reason for doing so is that it - // is very likely that the subsequent requests for merged block chunks from this host will + val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get + var blocksProcessed = 1 + // When there is a failure to fetch a remote shuffle chunk, we fall back + // not only for that particular remote shuffle chunk but also for all the + // pending chunks that belong to the same host. The reason for doing so is that it + // is very likely that the subsequent requests for shuffle chunks from this host will // fail as well. Since, push-based shuffle is best effort and we try not to increase the // delay of the fetches, we immediately fallback for all the pending shuffle chunks in the // fetchRequests queue. 
- if (isMergedBlockAddressRemote(address)) { + if (isRemotePushMergedBlockAddress(address)) { // Fallback for all the pending fetch requests val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address) - if (pendingShuffleChunks.nonEmpty) { - pendingShuffleChunks.foreach { pendingBlockId => - logInfo(s"Falling back immediately for merged block $pendingBlockId") - val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId) - assert(bitmapOfPendingChunk != null) - chunkBitmap.or(bitmapOfPendingChunk) - } - // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed - blocksProcessed += pendingShuffleChunks.size + pendingShuffleChunks.foreach { pendingBlockId => + logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId") + val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get + chunkBitmap.or(bitmapOfPendingChunk) } + // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed + blocksProcessed += pendingShuffleChunks.size } + iterator.incrementNumBlocksToFetch(-blocksProcessed) mapOutputTracker.getMapSizesForMergeResult( shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap) } - iterator.fetchFallbackBlocks(fallbackBlocksByAddr) - blocksProcessed + iterator.fallbackFetch(fallbackBlocksByAddr) } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 3d42bbc8bd08..801520eac4e4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -59,8 +59,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * block, which indicate the index in the map stage. * Note that zero-sized blocks are already excluded, which happened in * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. - * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching unmerged blocks if - * we fail to fetch merged block chunks when push based shuffle is enabled. + * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching the original blocks if + * we fail to fetch shuffle chunks when push based shuffle is enabled. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. 
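As an aside on createChunkBlockInfosFromMetaResponse above: the sketch below is a standalone illustration with hypothetical names (not part of the patch) of the arithmetic that expands a push-merged block into per-chunk fetch entries of roughly equal size.

object ChunkSplitSketch {
  // A push-merged block of blockSize bytes with numChunks chunks becomes one fetch
  // entry per chunk, each sized approximately blockSize / numChunks (integer division).
  def chunkInfos(blockSize: Long, numChunks: Int): Seq[(Int, Long)] = {
    val approxChunkSize = blockSize / numChunks
    (0 until numChunks).map(i => (i, approxChunkSize))
  }

  def main(args: Array[String]): Unit = {
    // A 100-byte push-merged block with 3 chunks yields three ~33-byte entries.
    assert(chunkInfos(100L, 3) == Seq((0, 33L), (1, 33L), (2, 33L)))
  }
}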
@@ -333,7 +333,7 @@ final class ShuffleBlockFetcherIterator( val block = BlockId(blockId) if (block.isShuffleChunk) { remainingBlocks -= blockId - results.put(FallbackOnMergedFailureFetchResult( + results.put(FallbackOnPushMergedFailureResult( block, address, infoMap(blockId)._1, remainingBlocks.isEmpty)) } else { results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e)) @@ -363,29 +363,29 @@ final class ShuffleBlockFetcherIterator( blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], localBlocks: mutable.LinkedHashSet[(BlockId, Int)], hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]], - mergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { + pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = { logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: " + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress") - // Partition to local, host-local, merged-local, remote (includes merged-remote) blocks. - // Remote blocks are further split into FetchRequests of size at most maxBytesInFlight in order - // to limit the amount of data in flight + // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote) + // blocks. Remote blocks are further split into FetchRequests of size at most maxBytesInFlight + // in order to limit the amount of data in flight val collectedRemoteRequests = new ArrayBuffer[FetchRequest] val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]() var localBlockBytes = 0L var hostLocalBlockBytes = 0L - var mergedLocalBlockBytes = 0L + var pushMergedLocalBlockBytes = 0L val prevNumBlocksToFetch = numBlocksToFetch val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId for ((address, blockInfos) <- blocksByAddress) { checkBlockSizes(blockInfos) - if (pushBasedFetchHelper.isMergedShuffleBlockAddress(address)) { - // These are push-based merged blocks or chunks of these merged blocks. + if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { + // These are push-merged blocks or shuffle chunks of these blocks. 
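+ // Push-merged blocks whose merger address is this host are handled as local reads
+ // below; all others become remote fetch requests via collectFetchRequests.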
if (address.host == blockManager.blockManagerId.host) { numBlocksToFetch += blockInfos.size - mergedLocalBlocks ++= blockInfos.map(_._1) - mergedLocalBlockBytes += blockInfos.map(_._3).sum + pushMergedLocalBlocks ++= blockInfos.map(_._1) + pushMergedLocalBlockBytes += blockInfos.map(_._3).sum } else { remoteBlockBytes += blockInfos.map(_._2).sum collectFetchRequests(address, blockInfos, collectedRemoteRequests) @@ -417,21 +417,22 @@ final class ShuffleBlockFetcherIterator( val (remoteBlockBytes, numRemoteBlocks) = collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size)) val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes + - mergedLocalBlockBytes + pushMergedLocalBlockBytes val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch assert(blocksToFetchCurrentIteration == localBlocks.size + - hostLocalBlocksCurrentIteration.size + numRemoteBlocks + mergedLocalBlocks.size, + hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size, s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " + s"the number of local blocks ${localBlocks.size} + " + s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " + - s"the number of merged-local blocks ${mergedLocalBlocks.size} " + + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + s"+ the number of remote blocks ${numRemoteBlocks} ") logInfo(s"Getting $blocksToFetchCurrentIteration " + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " + - s"host-local and ${mergedLocalBlocks.size} (${Utils.bytesToString(mergedLocalBlockBytes)}) " + - s"local merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"host-local and ${pushMergedLocalBlocks.size} " + + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + + s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + s"remote blocks") this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration collectedRemoteRequests @@ -486,7 +487,7 @@ final class ShuffleBlockFetcherIterator( curBlocks += FetchBlockInfo(blockId, size, mapIndex) curRequestSize += size blockId match { - // Either all blocks are merged blocks, merged block chunks, or original non-merged blocks. + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. // Based on these types, we decide to do batch fetch and create FetchRequests with // forMergedMetas set. case ShuffleBlockChunkId(_, _, _) => @@ -677,11 +678,11 @@ final class ShuffleBlockFetcherIterator( val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() val hostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() - val mergedLocalBlocks = mutable.LinkedHashSet[BlockId]() - // Partition blocks by the different fetch modes: local, host-local, merged-local and remote - // blocks. + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. 
val remoteRequests = partitionBlocksByFetchMode( - blocksByAddress, localBlocks, hostLocalBlocksByExecutor, mergedLocalBlocks) + blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks) // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(remoteRequests) assert ((0 == reqsInFlight) == (0 == bytesInFlight), @@ -701,7 +702,7 @@ final class ShuffleBlockFetcherIterator( logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}") // Get host local blocks if any fetchAllHostLocalBlocks(hostLocalBlocksByExecutor) - pushBasedFetchHelper.fetchAllMergedLocalBlocks(mergedLocalBlocks) + pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks) } private def fetchAllHostLocalBlocks( @@ -745,31 +746,21 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { - if (pushBasedFetchHelper.isMergedLocal(address)) { - // It is a local merged block chunk - assert(blockId.isShuffleChunk) - shuffleMetrics.incLocalBlocksFetched(pushBasedFetchHelper.getNumberOfBlocksInChunk( - blockId.asInstanceOf[ShuffleBlockChunkId])) - shuffleMetrics.incLocalBytesRead(buf.size) - } else if (hostLocalBlocks.contains(blockId -> mapIndex)) { - shuffleMetrics.incLocalBlocksFetched(1) - shuffleMetrics.incLocalBytesRead(buf.size) - } else { - // Could be a remote merged block chunk or remote block - numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 - shuffleMetrics.incRemoteBytesRead(buf.size) - if (buf.isInstanceOf[FileSegmentManagedBuffer]) { - shuffleMetrics.incRemoteBytesReadToDisk(buf.size) - } - if (blockId.isShuffleChunk) { - shuffleMetrics.incRemoteBlocksFetched( - pushBasedFetchHelper.getNumberOfBlocksInChunk( - blockId.asInstanceOf[ShuffleBlockChunkId])) - } else { - shuffleMetrics.incRemoteBlocksFetched(1) - } - bytesInFlight -= size - } + if (hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) { + // It is a host local block or a local shuffle chunk + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + // Could be a remote shuffle chunk or remote block + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + bytesInFlight -= size + } } if (isNetworkReqDone) { reqsInFlight -= 1 @@ -810,8 +801,7 @@ final class ShuffleBlockFetcherIterator( } buf.release() if (blockId.isShuffleChunk) { - numBlocksProcessed += pushBasedFetchHelper - .initiateFallbackBlockFetchForMergedBlock(blockId, address) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop to get either. result = null null @@ -835,16 +825,14 @@ final class ShuffleBlockFetcherIterator( case e: IOException => buf.release() if (blockId.isShuffleChunk) { - // Retrying a corrupt block may result again in a corrupt block. For merged - // block chunks, we opt to fallback on the original shuffle blocks - // that belong to that corrupt merged block chunk immediately instead of - // retrying to fetch the corrupt chunk. 
This also makes the code simpler - // because the chunkMeta corresponding to a Shuffle Chunk is always removed - // from chunksMetaMap whenever a Shuffle Block Chunk gets processed. - // If we try to re-fetch a corrupt shuffle chunk, then it has to be added - // back to the chunksMetaMap. - numBlocksProcessed += pushBasedFetchHelper - .initiateFallbackBlockFetchForMergedBlock(blockId, address) + // Retrying a corrupt block may result again in a corrupt block. For shuffle + // chunks, we opt to fallback on the original shuffle blocks that belong to that + // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt + // chunk. This also makes the code simpler because the chunkMeta corresponding to + // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk + // gets processed. If we try to re-fetch a corrupt shuffle chunk, then it has to + // be added back to the chunksMetaMap. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop. result = null } else { @@ -893,14 +881,15 @@ final class ShuffleBlockFetcherIterator( defReqQueue.enqueue(request) result = null - case FallbackOnMergedFailureFetchResult(blockId, address, size, isNetworkReqDone) => + case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) => // We get this result in 3 cases: - // 1. Failure to fetch the data of a remote merged shuffle chunk. In this case, the + // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the // blockId is a ShuffleBlockChunkId. - // 2. Failure to read the local merged data. In this case, the blockId is ShuffleBlockId. - // 3. Failure to get the local merged directories from the ESS. In this case, the blockId - // is ShuffleBlockId. - if (pushBasedFetchHelper.isMergedBlockAddressRemote(address)) { + // 2. Failure to read the local push-merged meta. In this case, the blockId is + // ShuffleBlockId. + // 3. Failure to get the local push-merged directories from the ESS. In this case, the + // blockId is ShuffleBlockId. + if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 bytesInFlight -= size } @@ -908,17 +897,49 @@ final class ShuffleBlockFetcherIterator( reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } - numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock( - blockId, address) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address) // Set result to null to trigger another iteration of the while loop to get either // a SuccessFetchResult or a FailureFetchResult. result = null - case MergedMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, + case PushMergedLocalMetaFetchResult(shuffleId, reduceId, _, bitmaps, localDirs, _) => + // Fetch local push-merged shuffle block data as multiple shuffle chunks + val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId) + try { + val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, + localDirs) + // Since the request for local block meta completed successfully, numBlocksToFetch + // is decremented. + numBlocksToFetch -= 1 + // Update total number of blocks to fetch, reflecting the multiple local shuffle + // chunks. 
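+ // Net effect of the two updates: numBlocksToFetch grows by (bufs.size - 1) for
+ // this push-merged block.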
+ numBlocksToFetch += bufs.size + for (chunkId <- bufs.indices) { + val buf = bufs(chunkId) + buf.retain() + val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId) + pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) + results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID, + pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf, + isNetworkReqDone = false)) + } + } catch { + case e: Exception => + // If we see an exception while reading local push-merged data, we fall back to + // fetching the original blocks. We do not report block fetch failure + // and will continue with the remaining local block read. + logWarning(s"Error occurred while fetching local push-merged data, " + + s"falling back to fetch the original blocks", e) + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( + shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId) + } + result = null + + case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, address, _) => // The original meta request is processed so we decrease numBlocksToFetch and - // numBlocksInFlightPerAddress by 1. We will collect new chunks request and the count of - // this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. + // numBlocksInFlightPerAddress by 1. We will collect the new shuffle chunk requests, whose + // count is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 numBlocksToFetch -= 1 val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( @@ -929,14 +950,12 @@ final class ShuffleBlockFetcherIterator( // Set result to null to force another iteration. result = null - case MergedMetaFailedFetchResult(shuffleId, reduceId, address, _) => + case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address, _) => // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. - // However, instead of decreasing numBlocksToFetch by 1, we increment numBlocksProcessed - // which has the same effect. numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 - // If we fail to fetch the merged status of a merged block, we fall back to fetching the - // unmerged blocks. - numBlocksProcessed += pushBasedFetchHelper.initiateFallbackBlockFetchForMergedBlock( + // If we fail to fetch the meta of a push-merged block, we fall back to fetching the + // original blocks. + pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address) // Set result to null to force another iteration. result = null @@ -1062,33 +1081,33 @@ final class ShuffleBlockFetcherIterator( /** * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch - * failure for a shuffle merged block/chunk. + * failure related to a push-merged block or shuffle chunk. * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates * fallback. 
*/ - private[storage] def fetchFallbackBlocks( - fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { - val fallbackLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() - val fallbackHostLocalBlocksByExecutor = + private[storage] def fallbackFetch( + originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { + val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val originalHostLocalBlocksByExecutor = mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() - val fallbackMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() - val fallbackRemoteReqs = partitionBlocksByFetchMode(fallbackBlocksByAddr, - fallbackLocalBlocks, fallbackHostLocalBlocksByExecutor, fallbackMergedLocalBlocks) + val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr, + originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks) // Add the remote requests into our queue in a random order - fetchRequests ++= Utils.randomize(fallbackRemoteReqs) - logInfo(s"Started ${fallbackRemoteReqs.size} fallback remote requests for merged") + fetchRequests ++= Utils.randomize(originalRemoteReqs) + logInfo(s"Started ${originalRemoteReqs.size} fallback remote requests for push-merged blocks") // fetch all the fallback blocks that are local. - fetchLocalBlocks(fallbackLocalBlocks) + fetchLocalBlocks(originalLocalBlocks) // Merged local blocks should be empty during fallback - assert(fallbackMergedLocalBlocks.isEmpty, - "There should be zero merged blocks during fallback") + assert(originalMergedLocalBlocks.isEmpty, + "There should be zero push-merged blocks during fallback") // Some of the fallback local blocks could be host local blocks - fetchAllHostLocalBlocks(fallbackHostLocalBlocksByExecutor) + fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor) } /** - * Removes all the pending shuffle chunks that are on the same host as the block chunk that had - * a fetch failure. + * Removes all the pending shuffle chunks that are on the same host and have the same shuffleId + * and reduceId as the shuffle chunk that had a fetch failure. * This is executed by the task thread when the `iterator.next()` is invoked and if that initiates * fallback. 
* @@ -1099,28 +1118,29 @@ final class ShuffleBlockFetcherIterator( address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = { val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]() - def sameShuffleBlockChunk(block: BlockId): Boolean = { + def sameShuffleReducePartition(block: BlockId): Boolean = { val chunkId = block.asInstanceOf[ShuffleBlockChunkId] chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId } def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = { val fetchRequestsToRemove = new mutable.Queue[FetchRequest]() - fetchRequestsToRemove ++= queue.dequeueAll(req => { + fetchRequestsToRemove ++= queue.dequeueAll { req => val firstBlock = req.blocks.head firstBlock.blockId.isShuffleChunk && req.address.equals(address) && - sameShuffleBlockChunk(firstBlock.blockId) - }) - fetchRequestsToRemove.foreach(req => { - removedChunkIds ++= req.blocks.iterator.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]) - }) + sameShuffleReducePartition(firstBlock.blockId) + } + removedChunkIds ++= + fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId])) } filterRequests(fetchRequests) - deferredFetchRequests.get(address).foreach(defRequests => { + deferredFetchRequests.get(address).foreach { defRequests => filterRequests(defRequests) if (defRequests.isEmpty) deferredFetchRequests.remove(address) - }) + } removedChunkIds } } @@ -1231,8 +1251,8 @@ object ShuffleBlockFetcherIterator { } /** - * Dummy shuffle block id to fill into [[MergedMetaFetchResult]] and - * [[MergedMetaFailedFetchResult]], to match the [[FetchResult]] trait. + * Dummy shuffle block id to fill into [[PushMergedRemoteMetaFetchResult]] and + * [[PushMergedRemoteMetaFailedFetchResult]], to match the [[FetchResult]] trait. */ private val DUMMY_SHUFFLE_BLOCK_ID = ShuffleBlockId(-1, -1, -1) @@ -1332,8 +1352,8 @@ object ShuffleBlockFetcherIterator { * A request to fetch blocks from a remote BlockManager. * @param address remote BlockManager to fetch from. * @param blocks Sequence of the information for blocks to fetch from the same address. - * @param forMergedMetas true if this request is for requesting merged meta information; - * false if it is for regular or shuffle block chunks. + * @param forMergedMetas true if this request is for requesting push-merged meta information; - * false if it is for regular or shuffle chunks. */ case class FetchRequest( address: BlockManagerId, @@ -1390,34 +1410,33 @@ object ShuffleBlockFetcherIterator { /** * Result of an un-successful fetch of either of these: - * 1) Remote shuffle block chunk. - * 2) Local merged block data. + * 1) Remote shuffle chunk. + * 2) Local push-merged block. * - * Instead of treating this as a [[FailureFetchResult]], we fallback to fetch the original - * unmerged blocks. + * Instead of treating this as a [[FailureFetchResult]], we fall back to fetching the original blocks. * * @param blockId block id - * @param address BlockManager that the merged block was attempted to be fetched from + * @param address BlockManager that the push-merged block was attempted to be fetched from * @param size size of the block, used to update bytesInFlight. * @param isNetworkReqDone Is this the last network request for this host in this fetch * request. Used to update reqsInFlight. 
*/ - private[storage] case class FallbackOnMergedFailureFetchResult(blockId: BlockId, + private[storage] case class FallbackOnPushMergedFailureResult(blockId: BlockId, address: BlockManagerId, size: Long, isNetworkReqDone: Boolean) extends FetchResult /** - * Result of a successful fetch of meta information for a merged block. + * Result of a successful fetch of meta information for a remote push-merged block. * * @param shuffleId shuffle id. * @param reduceId reduce id. - * @param blockSize size of each merged block. - * @param numChunks number of chunks in the merged block. + * @param blockSize size of the push-merged block. + * @param numChunks number of chunks in the push-merged block. * @param bitmaps bitmaps for every chunk. - * @param address BlockManager that the merged status was fetched from. + * @param address BlockManager that the meta was fetched from. */ - private[storage] case class MergedMetaFetchResult( + private[storage] case class PushMergedRemoteMetaFetchResult( shuffleId: Int, reduceId: Int, blockSize: Long, @@ -1427,15 +1446,32 @@ object ShuffleBlockFetcherIterator { /** - * Result of a failure while fetching the meta information for a merged block. + * Result of a failure while fetching the meta information for a remote push-merged block. * * @param shuffleId shuffle id. * @param reduceId reduce id. - * @param address BlockManager that the merged status was fetched from. + * @param address BlockManager that the meta was fetched from. */ - private[storage] case class MergedMetaFailedFetchResult( + private[storage] case class PushMergedRemoteMetaFailedFetchResult( shuffleId: Int, reduceId: Int, address: BlockManagerId, blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult + + /** + * Result of a successful fetch of meta information for a local push-merged block. + * + * @param shuffleId shuffle id. + * @param reduceId reduce id. + * @param numChunks number of chunks in the push-merged block. + * @param bitmaps bitmaps for every chunk. 
+ * @param localDirs local directories where the push-merged shuffle files are stored + */ + private[storage] case class PushMergedLocalMetaFetchResult( + shuffleId: Int, + reduceId: Int, + numChunks: Int, + bitmaps: Array[RoaringBitmap], + localDirs: Array[String], + blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index adaec777205f..ca18f70e7d74 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1056,9 +1056,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT }) } - test("fetch merged blocks meta") { + test("SPARK-32922: fetch remote push-merged block meta") { val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( - (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1), toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), (BlockManagerId("remote-client-1", "remote-host-1", 1), toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)) @@ -1073,11 +1073,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransferForPushShuffle(blocksSem, blockChunks) val metaSem = new Semaphore(0) - val mergedBlockMeta = mock(classOf[MergedBlockMeta]) - when(mergedBlockMeta.getNumChunks).thenReturn(2) - when(mergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) + val pushMergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(pushMergedBlockMeta.getNumChunks).thenReturn(2) + when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap) - when(mergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) + when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] s"port = ${invocation.getArguments()(1)}, " + s"shuffleId = $shuffleId, reduceId = $reduceId") metaSem.acquire() - metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta) } }) val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress) @@ -1108,10 +1108,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("failed to fetch merged blocks meta") { + test("SPARK-32922: failed to fetch remote push-merged block meta so fallback to " + + "original blocks.") { val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1) val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( - (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1), toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)), (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))) @@ -1149,10 
@@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("iterator has just 1 merged block and fails to fetch the meta") { + test("SPARK-32922: iterator has just 1 push-merged block and fails to fetch the meta") { val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1) val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( - (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "merged-host", 1), + (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1), toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID))) val blockChunks = Map[BlockId, ManagedBuffer]( @@ -1182,26 +1183,27 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - private def createMockMergedBlockMeta( + private def createMockPushMergedBlockMeta( numChunks: Int, bitmaps: Array[RoaringBitmap]): MergedBlockMeta = { - val mergedBlockMeta = mock(classOf[MergedBlockMeta]) - when(mergedBlockMeta.getNumChunks).thenReturn(numChunks) + val pushMergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(pushMergedBlockMeta.getNumChunks).thenReturn(numChunks) if (bitmaps == null) { - when(mergedBlockMeta.readChunkBitmaps()).thenThrow(new IOException("forced error")) + when(pushMergedBlockMeta.readChunkBitmaps()).thenThrow(new IOException("forced error")) } else { - when(mergedBlockMeta.readChunkBitmaps()).thenReturn(bitmaps) + when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(bitmaps) } - doReturn(createMockManagedBuffer()).when(mergedBlockMeta).getChunksBitmapBuffer - mergedBlockMeta + doReturn(createMockManagedBuffer()).when(pushMergedBlockMeta).getChunksBitmapBuffer + pushMergedBlockMeta } - private def prepareBlocksForFallbackWhenBlocksAreLocal( + private def prepareForFallbackToLocalBlocks( blockManager: BlockManager, localDirsMap : Map[String, Array[String]], failReadingLocalChunksMeta: Boolean = false): Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = { - val localBmId = BlockManagerId("test-client", "test-local-host", 1) + val localHost = "test-local-host" + val localBmId = BlockManagerId("test-client", localHost, 1) doReturn(localBmId).when(blockManager).blockManagerId initHostLocalDirManager(blockManager, localDirsMap) @@ -1229,23 +1231,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val bitmaps = Array(new RoaringBitmap) bitmaps(0).add(1) // chunk 0 has mapId 1 bitmaps(0).add(2) // chunk 0 has mapId 2 - val mergedBlockMeta: MergedBlockMeta = if (failReadingLocalChunksMeta) { - createMockMergedBlockMeta(bitmaps.length, null) + val pushMergedBlockMeta: MergedBlockMeta = if (failReadingLocalChunksMeta) { + createMockPushMergedBlockMeta(bitmaps.length, null) } else { - createMockMergedBlockMeta(bitmaps.length, bitmaps) + createMockPushMergedBlockMeta(bitmaps.length, bitmaps) } when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), - dirsForMergedData)).thenReturn(mergedBlockMeta) + dirsForMergedData)).thenReturn(pushMergedBlockMeta) when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( Seq((localBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) when(mapOutputTracker.getMapSizesForMergeResult(0, 2, bitmaps(0))) .thenReturn(Seq((localBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator) - val mergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1) + val pushMergedBmId = 
BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1) Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)), - (mergedBmId, toBlockList( + (pushMergedBmId, toBlockList( Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID))) } @@ -1261,10 +1263,36 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("failed to fetch local merged blocks then fallback to fetch original shuffle blocks") { + test("SPARK-32922: failure to fetch local push-merged meta should fallback to fetch " + + "original shuffle blocks") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("testPath1", "testPath2") - val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + val blocksByAddress = prepareForFallbackToLocalBlocks( + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + doThrow(new RuntimeException("Forced error")).when(blockManager) + .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager)) + verifyLocalBlocksFromFallback(iterator) + } + + test("SPARK-32922: failure when reading chunkBitmaps of local push-merged meta should " + + "fallback to original shuffle blocks") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("local-dir") + val blocksByAddress = prepareForFallbackToLocalBlocks( + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs), + failReadingLocalChunksMeta = true) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) + verifyLocalBlocksFromFallback(iterator) + } + + test("SPARK-32922: failure to fetch local push-merged data should fallback to fetch " + + "original shuffle blocks") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("testPath1", "testPath2") + val blocksByAddress = prepareForFallbackToLocalBlocks( blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) doThrow(new RuntimeException("Forced error")).when(blockManager) .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) @@ -1273,7 +1301,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("failed to fetch merged block as well as fallback block should throw " + + test("SPARK-32922: failure to fetch push-merged block as well as fallback block should throw " + "a FetchFailedException") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("testPath1", "testPath2") @@ -1289,11 +1317,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .getLocalBlockData(ShuffleBlockId(0, 1, 2)) val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER) - // Since bitmaps is null, this will fail reading the merged block meta causing fallback to + // Since bitmaps are null, this will fail reading the push-merged block meta causing fallback to // initiate.
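// Sketch: the FetchFailedException test in this hunk pins down a one-shot
// fallback contract — the push-merged read fails, the original blocks are
// tried exactly once, and a second failure must surface as a fetch failure.
// A minimal standalone version of that control flow using scala.util.Try;
// all names here are illustrative, not Spark's iterator machinery.
import scala.util.{Failure, Success, Try}

def fetchWithSingleFallback[A](
    fetchPushMerged: () => Try[A],
    fetchOriginal: () => Try[A]): A =
  fetchPushMerged() match {
    case Success(data) => data // push-merged block read fine
    case Failure(_) =>
      fetchOriginal() match { // one fallback attempt to the original blocks
        case Success(data) => data
        case Failure(cause) =>
          // mirrors the FetchFailedException expectation asserted by the test
          throw new RuntimeException("fetch failed after fallback", cause)
      }
  }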
- val mergedBlockMeta: MergedBlockMeta = createMockMergedBlockMeta(2, null) + val pushMergedBlockMeta: MergedBlockMeta = createMockPushMergedBlockMeta(2, null) when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), - dirsForMergedData)).thenReturn(mergedBlockMeta) + dirsForMergedData)).thenReturn(pushMergedBlockMeta) when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn( Seq((localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator) @@ -1309,14 +1337,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - test("failed to fetch local merged blocks then fallback to fetch original shuffle " + - "blocks which contains host-local blocks") { + test("SPARK-32922: failure to fetch local push-merged block should fallback to fetch " + + "original shuffle blocks which contain host-local blocks") { val blockManager = mock(classOf[BlockManager]) // BlockManagerId from another executor on the same host val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1) val hostLocalDirs = Map("test-client-1" -> Array("local-dir"), SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir")) - val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal(blockManager, hostLocalDirs) + val blocksByAddress = prepareForFallbackToLocalBlocks(blockManager, hostLocalDirs) doThrow(new RuntimeException("Forced error")).when(blockManager) .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir")) @@ -1333,19 +1361,20 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("initialization and fallback with host locals blocks") { + test("SPARK-32922: fetch host local blocks with push-merged block during initialization " + + "and fallback to host-local blocks") { val blockManager = mock(classOf[BlockManager]) // BlockManagerId of another executor on the same host val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1) - val fallbackHostLocalBmId = BlockManagerId("test-client-2", "test-local-host", 1) + val originalHostLocalBmId = BlockManagerId("test-client-2", "test-local-host", 1) val hostLocalDirs = Map(hostLocalBmId.executorId -> Array("local-dir"), SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"), - fallbackHostLocalBmId.executorId -> Array("local-dir")) + originalHostLocalBmId.executorId -> Array("local-dir")) val hostLocalBlocks = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( (hostLocalBmId, Seq((ShuffleBlockId(0, 5, 2), 1L, 1)))) - val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + val blocksByAddress = prepareForFallbackToLocalBlocks( blockManager, hostLocalDirs) ++ hostLocalBlocks doThrow(new RuntimeException("Forced error")).when(blockManager) @@ -1358,7 +1387,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer( (_: InvocationOnMock) => { Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1)), - (fallbackHostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator + (originalHostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator }) val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, blockManager = Some(blockManager)) @@ -1375,10 +1404,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
assert(!iterator.hasNext) } - test("failure while reading shuffle chunks should fallback to original shuffle blocks") { + test("SPARK-32922: failure while reading local shuffle chunks should fallback to original " + + "shuffle blocks") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("local-dir") - val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + val blocksByAddress = prepareForFallbackToLocalBlocks( blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) // This will throw an IOException when input stream is created from the ManagedBuffer doReturn(Seq({ @@ -1390,10 +1420,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("fallback to original shuffle block when a merged block chunk is corrupt") { + test("SPARK-32922: fallback to original shuffle block when a push-merged shuffle chunk " + + "is corrupt") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("local-dir") - val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( + val blocksByAddress = prepareForFallbackToLocalBlocks( blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) val corruptBuffer = createMockManagedBuffer(2) doReturn(Seq({corruptBuffer})).when(blockManager) @@ -1406,19 +1437,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("failure when reading chunkBitmaps of local merged block should fallback to " + - "original shuffle blocks") { - val blockManager = mock(classOf[BlockManager]) - val localDirs = Array("local-dir") - val blocksByAddress = prepareBlocksForFallbackWhenBlocksAreLocal( - blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs), - failReadingLocalChunksMeta = true) - val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, - blockManager = Some(blockManager), streamWrapperLimitSize = Some(100)) - verifyLocalBlocksFromFallback(iterator) - } - - test("fallback to original blocks when failed to fetch remote shuffle chunk") { + test("SPARK-32922: fallback to original blocks when failed to fetch remote shuffle chunk") { val blockChunks = Map[BlockId, ManagedBuffer]( ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(), @@ -1431,14 +1450,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT bitmaps(1).add(3) bitmaps(1).add(4) bitmaps(1).add(5) - val mergedBlockMeta = createMockMergedBlockMeta(2, bitmaps) + val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, bitmaps) when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] val reduceId = invocation.getArguments()(3).asInstanceOf[Int] Future { - metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta) } }) val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( @@ -1463,7 +1482,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("fallback to original blocks when failed to parse remote merged block meta") { + test("SPARK-32922: fallback to original blocks when failed to parse remote merged block meta") { val blockChunks = Map[BlockId, 
ManagedBuffer]( ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(), ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer() @@ -1473,14 +1492,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator) val blocksSem = new Semaphore(0) configureMockTransferForPushShuffle(blocksSem, blockChunks) - val mergedBlockMeta = createMockMergedBlockMeta(2, null) + val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, null) when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] val shuffleId = invocation.getArguments()(2).asInstanceOf[Int] val reduceId = invocation.getArguments()(3).asInstanceOf[Int] Future { - metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta) } }) val remoteMergedBlockMgrId = BlockManagerId( @@ -1496,8 +1515,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("failure to fetch a remote merged block chunk initiates the fallback of " + - "deferred shuffle chunks immediately") { + test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " + + "pending shuffle chunks immediately") { val blockChunks = Map[BlockId, ManagedBuffer]( ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), // ShuffleBlockChunk(0, 2, 1) will cause a failure as it is not in block-chunks. @@ -1512,11 +1531,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT configureMockTransferForPushShuffle(blocksSem, blockChunks) val metaSem = new Semaphore(0) - val mergedBlockMeta = mock(classOf[MergedBlockMeta]) - when(mergedBlockMeta.getNumChunks).thenReturn(4) - when(mergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) + val pushMergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(pushMergedBlockMeta.getNumChunks).thenReturn(4) + when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap) - when(mergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) + when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] @@ -1527,7 +1546,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT s"port = ${invocation.getArguments()(1)}, " + s"shuffleId = $shuffleId, reduceId = $reduceId") metaSem.release() - metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta) } }) val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) @@ -1563,8 +1582,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2))) } - test("failure to fetch a remote merged block chunk initiates the fallback of " + - "deferred shuffle chunks immediately which got deferred") { + test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " + + "deferred pending shuffle chunks immediately") { val blockChunks = Map[BlockId,
ManagedBuffer]( ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(), ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(), @@ -1580,11 +1599,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val blocksSem = new Semaphore(0) configureMockTransferForPushShuffle(blocksSem, blockChunks) val metaSem = new Semaphore(0) - val mergedBlockMeta = mock(classOf[MergedBlockMeta]) - when(mergedBlockMeta.getNumChunks).thenReturn(6) - when(mergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) + val pushMergedBlockMeta = mock(classOf[MergedBlockMeta]) + when(pushMergedBlockMeta.getNumChunks).thenReturn(6) + when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer])) val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap) - when(mergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) + when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps) when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any())) .thenAnswer((invocation: InvocationOnMock) => { val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener] @@ -1595,7 +1614,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT s"port = ${invocation.getArguments()(1)}, " + s"shuffleId = $shuffleId, reduceId = $reduceId") metaSem.release() - metaListener.onSuccess(shuffleId, reduceId, mergedBlockMeta) + metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta) } }) val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala index 388a86037594..eff2de714394 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceMetricsSuite.scala @@ -42,7 +42,8 @@ class YarnShuffleServiceMetricsSuite extends SparkFunSuite with Matchers { "openBlockRequestLatencyMillis", "registerExecutorRequestLatencyMillis", "blockTransferRate", "blockTransferMessageRate", "blockTransferAvgSize_1min", "blockTransferRateBytes", "registeredExecutorsSize", "numActiveConnections", - "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis") + "numCaughtExceptions", "finalizeShuffleMergeLatencyMillis", + "fetchMergedBlocksMetaLatencyMillis") // Use sorted Seq instead of Set for easier comparison when there is a mismatch metrics.getMetrics.keySet().asScala.toSeq.sorted should be (allMetrics.sorted) From fc1b9f10911e52a57ba7ad3aa241ccb03433d42f Mon Sep 17 00:00:00 2001 From: otterc Date: Tue, 22 Jun 2021 19:12:29 -0700 Subject: [PATCH 22/27] Update core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala Co-authored-by: Mridul Muralidharan <1591700+mridulm@users.noreply.github.com> --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 801520eac4e4..2e55511a78f9 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -914,8 
+914,7 @@ final class ShuffleBlockFetcherIterator( // Update total number of blocks to fetch, reflecting the multiple local shuffle // chunks. numBlocksToFetch += bufs.size - for (chunkId <- bufs.indices) { - val buf = bufs(chunkId) + bufs.zipWithIndex { case (buf, chunkId) => buf.retain() val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId) pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) From 390300a1a29ee5880561302d8bc8a83de850852a Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Tue, 22 Jun 2021 23:05:59 -0700 Subject: [PATCH 23/27] Addressed Mridul's comments --- .../apache/spark/storage/BlockManager.scala | 5 +++ .../spark/storage/PushBasedFetchHelper.scala | 25 ++++++------ .../storage/ShuffleBlockFetcherIterator.scala | 33 ++++++++-------- .../ShuffleBlockFetcherIteratorSuite.scala | 38 +++++++++++++++++++ 4 files changed, 70 insertions(+), 31 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index df449fba24e9..98d094939cd4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -132,6 +132,11 @@ private[spark] class HostLocalDirManager( executorIdToLocalDirsCache.asMap().asScala.toMap } + private[spark] def getCachedHostLocalDirsFor(executorId: String): Option[Array[String]] = + executorIdToLocalDirsCache.synchronized { + Option(executorIdToLocalDirsCache.getIfPresent(executorId)) + } + private[spark] def getHostLocalDirs( host: String, port: Int, diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index b8aa5c518c0d..8a492eaf4505 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -189,33 +189,33 @@ private class PushBasedFetchHelper( private def fetchPushMergedLocalBlocks( hostLocalDirManager: HostLocalDirManager, pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = { - val cachedMergerDirs = hostLocalDirManager.getCachedHostLocalDirs.get( + val cachedPushedMergedDirs = hostLocalDirManager.getCachedHostLocalDirsFor( SHUFFLE_MERGER_IDENTIFIER) - if (cachedMergerDirs.isDefined) { - logDebug(s"Fetching local push-merged blocks with cached executors dir: " + - s"${cachedMergerDirs.get.mkString(", ")}") + if (cachedPushedMergedDirs.isDefined) { + logDebug(s"Fetch the local push-merged blocks with cached merged dirs: " + + s"${cachedPushedMergedDirs.get.mkString(", ")}") pushMergedLocalBlocks.foreach { blockId => - fetchPushMergedLocalBlock(blockId, cachedMergerDirs.get, + fetchPushMergedLocalBlock(blockId, cachedPushedMergedDirs.get, localShuffleMergerBlockMgrId) } } else { - logDebug(s"Asynchronous fetching local push-merged blocks without cached executors dir") + logDebug(s"Asynchronously fetch the local push-merged blocks without cached merged dirs") hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) { case Success(dirs) => - pushMergedLocalBlocks.takeWhile { + logDebug(s"Fetched merged dirs in " + + s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") + pushMergedLocalBlocks.foreach { blockId => logDebug(s"Successfully fetched local dirs: " + s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
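// Sketch: the hunk above is a cache-first lookup — use the cached merger dirs
// when present, otherwise resolve them asynchronously and either fetch or fall
// back without ever failing the task. The same shape in isolation, under
// illustrative signatures (the real code resolves dirs through the external
// shuffle service via getHostLocalDirs):
import scala.util.{Failure, Success, Try}

def withMergedDirs(
    cached: Option[Array[String]],
    resolve: () => Try[Array[String]])(
    fetch: Array[String] => Unit,
    fallback: () => Unit): Unit =
  cached match {
    case Some(dirs) => fetch(dirs) // fast path: dirs already cached
    case None => resolve() match {
      case Success(dirs) => fetch(dirs) // resolved from the shuffle service
      case Failure(_) => fallback() // no fetch failure reported; use original blocks
    }
  }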
fetchPushMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER), localShuffleMergerBlockMgrId) } - logDebug(s"Got local push-merged blocks (without cached executors' dir) in " + - s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms") case Failure(throwable) => // If we see an exception with getting the local dirs for local push-merged blocks, // we fallback to fetch the original blocks. We do not report block fetch failure. - logWarning(s"Error occurred while getting the local dirs for local push-merged " + + logWarning(s"Error while fetching the merged dirs for local push-merged " + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", throwable) pushMergedLocalBlocks.foreach { @@ -233,19 +233,17 @@ private class PushBasedFetchHelper( * @param blockId ShuffleBlockId to be fetched * @param localDirs Local directories where the push-merged shuffle files are stored * @param blockManagerId BlockManagerId - * @return Boolean represents successful or failed fetch */ private[this] def fetchPushMergedLocalBlock( blockId: BlockId, localDirs: Array[String], - blockManagerId: BlockManagerId): Boolean = { + blockManagerId: BlockManagerId): Unit = { try { val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) iterator.addToResultsQueue(PushMergedLocalMetaFetchResult( shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.getNumChunks, chunksMeta.readChunkBitmaps(), localDirs)) - true } catch { case e: Exception => // If we see an exception with reading a local push-merged meta, we fallback to @@ -255,7 +253,6 @@ private class PushBasedFetchHelper( s"prepare to fetch the original blocks", e) iterator.addToResultsQueue( FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) - false } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2e55511a78f9..8bbbb146ddb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -522,7 +522,7 @@ final class ShuffleBlockFetcherIterator( } } createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests, - enableBatchFetch = enableBatchFetch, forMergedBlocks = areMergedBlocks) + enableBatchFetch = enableBatchFetch, forMergedMetas = areMergedBlocks) } } @@ -746,21 +746,20 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { - if (hostLocalBlocks.contains(blockId -> mapIndex) || - pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) { - // It is a host local block or a local shuffle chunk - shuffleMetrics.incLocalBlocksFetched(1) - shuffleMetrics.incLocalBytesRead(buf.size) - } else { - // Could be a remote shuffle chunk or remote block - numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 - shuffleMetrics.incRemoteBytesRead(buf.size) - if (buf.isInstanceOf[FileSegmentManagedBuffer]) { - shuffleMetrics.incRemoteBytesReadToDisk(buf.size) - } - shuffleMetrics.incRemoteBlocksFetched(1) - bytesInFlight -= size - } + if (hostLocalBlocks.contains(blockId -> mapIndex) || + pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) { + // It is a host local 
block or a local shuffle chunk + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + } else { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 + shuffleMetrics.incRemoteBytesRead(buf.size) + if (buf.isInstanceOf[FileSegmentManagedBuffer]) { + shuffleMetrics.incRemoteBytesReadToDisk(buf.size) + } + shuffleMetrics.incRemoteBlocksFetched(1) + bytesInFlight -= size + } } if (isNetworkReqDone) { reqsInFlight -= 1 @@ -914,7 +913,7 @@ final class ShuffleBlockFetcherIterator( // Update total number of blocks to fetch, reflecting the multiple local shuffle // chunks. numBlocksToFetch += bufs.size - bufs.zipWithIndex { case (buf, chunkId) => + bufs.zipWithIndex.foreach { case (buf, chunkId) => buf.retain() val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId) pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId)) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index ca18f70e7d74..db1d8bdca91c 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1301,6 +1301,44 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } + test("SPARK-32922: failure to fetch local push-merged meta of a single merged block " + + "should not drop the fetch of other local push-merged blocks") { + val blockManager = mock(classOf[BlockManager]) + val localDirs = Array("testPath1", "testPath2") + prepareForFallbackToLocalBlocks( + blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)) + val localHost = "test-local-host" + val localBmId = BlockManagerId("test-client", localHost, 1) + val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1) + val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]]( + (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)), + (pushMergedBmId, toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), + ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3)), 2L, SHUFFLE_PUSH_MAP_ID))) + doThrow(new RuntimeException("Forced error")).when(blockManager) + .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs) + // Create a valid chunk meta for partition 3 + val bitmaps = Array(new RoaringBitmap) + bitmaps(0).add(1) // chunk 0 has mapId 1 + doReturn(createMockPushMergedBlockMeta(bitmaps.length, bitmaps)).when(blockManager) + .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs) + // Return valid buffer for chunk in partition 3 + doReturn(Seq(createMockManagedBuffer(2))).when(blockManager) + .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs) + val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress, + blockManager = Some(blockManager)) + val (id1, _) = iterator.next() + assert(id1 === ShuffleBlockId(0, 0, 2)) + val (id2, _) = iterator.next() + assert(id2 === ShuffleBlockId(0, 3, 2)) + val (id3, _) = iterator.next() + assert(id3 === ShuffleBlockId(0, 1, 2)) + val (id4, _) = iterator.next() + assert(id4 === ShuffleBlockId(0, 2, 2)) + val (id5, _) = iterator.next() + assert(id5 === ShuffleBlockChunkId(0, 3, 0)) + assert(!iterator.hasNext) + } + test("SPARK-32922: failure to fetch push-merged block as well as fallback block should 
throw " + "a FetchFailedException") { val blockManager = mock(classOf[BlockManager]) From d53293e3783981aee3e9c0ffd911fe12cd8097c0 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Wed, 23 Jun 2021 21:20:13 -0700 Subject: [PATCH 24/27] Removed passing numChunks --- .../spark/storage/PushBasedFetchHelper.scala | 12 +++--- .../storage/ShuffleBlockFetcherIterator.scala | 37 ++++++++----------- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index 8a492eaf4505..b16562e17a1d 100644 --- a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -104,7 +104,6 @@ private class PushBasedFetchHelper( * @param shuffleId shuffle id. * @param reduceId reduce id. * @param blockSize size of the push-merged block. - * @param numChunks number of chunks in the push-merged block. * @param bitmaps chunk bitmaps, where each bitmap contains all the mapIds that were merged * to that chunk. * @return shuffle chunks to fetch. @@ -113,11 +112,10 @@ private class PushBasedFetchHelper( shuffleId: Int, reduceId: Int, blockSize: Long, - numChunks: Int, bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = { - val approxChunkSize = blockSize / numChunks + val approxChunkSize = blockSize / bitmaps.length val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]() - for (i <- 0 until numChunks) { + for (i <- bitmaps.indices) { val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i) chunksMetaMap.put(blockChunkId, bitmaps(i)) logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize") @@ -146,7 +144,7 @@ private class PushBasedFetchHelper( s"from ${req.address.host}:${req.address.port}") try { iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId, - sizeMap((shuffleId, reduceId)), meta.getNumChunks, meta.readChunkBitmaps(), address)) + sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address)) } catch { case exception: Exception => logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " + @@ -242,8 +240,8 @@ private class PushBasedFetchHelper( val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId] val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs) iterator.addToResultsQueue(PushMergedLocalMetaFetchResult( - shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.getNumChunks, - chunksMeta.readChunkBitmaps(), localDirs)) + shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(), + localDirs)) } catch { case e: Exception => // If we see an exception with reading a local push-merged meta, we fallback to diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 8bbbb146ddb3..fa556d950b87 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -453,7 +453,7 @@ final class ShuffleBlockFetcherIterator( isLast: Boolean, collectedRemoteRequests: ArrayBuffer[FetchRequest], enableBatchFetch: Boolean, - forMergedBlocks: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { val mergedBlocks = 
mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) numBlocksToFetch += mergedBlocks.size val retBlocks = new ArrayBuffer[FetchBlockInfo] @@ -473,7 +473,7 @@ final class ShuffleBlockFetcherIterator( } retBlocks } - + private def collectFetchRequests( address: BlockManagerId, blockInfos: Seq[(BlockId, Long, Int)], @@ -493,13 +493,13 @@ final class ShuffleBlockFetcherIterator( case ShuffleBlockChunkId(_, _, _) => if (curRequestSize >= targetRemoteRequestSize || curBlocks.size >= maxBlocksInFlightPerAddress) { - curBlocks = createFetchRequests(curBlocks, address, isLast = false, + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, collectedRemoteRequests, enableBatchFetch = false) curRequestSize = curBlocks.map(_.size).sum } case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => if (curBlocks.size >= maxBlocksInFlightPerAddress) { - curBlocks = createFetchRequests(curBlocks, address, isLast = false, + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true) } case _ => @@ -514,7 +514,7 @@ final class ShuffleBlockFetcherIterator( } // Add in the final request if (curBlocks.nonEmpty) { - val (enableBatchFetch, areMergedBlocks) = { + val (enableBatchFetch, forMergedMetas) = { curBlocks.head.blockId match { case ShuffleBlockChunkId(_, _, _) => (false, false) case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true) @@ -522,7 +522,7 @@ final class ShuffleBlockFetcherIterator( } } createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests, - enableBatchFetch = enableBatchFetch, forMergedMetas = areMergedBlocks) + enableBatchFetch = enableBatchFetch, forMergedMetas = forMergedMetas) } } @@ -554,8 +554,8 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() - results.put(SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, - buf.size(), buf, isNetworkReqDone = false)) + results.put(new SuccessFetchResult(blockId, mapIndex, blockManager.blockManagerId, + buf.size(), buf, false)) } catch { // If we see an exception, stop immediately. case e: Exception => @@ -901,7 +901,7 @@ final class ShuffleBlockFetcherIterator( // a SuccessFetchResult or a FailureFetchResult. result = null - case PushMergedLocalMetaFetchResult(shuffleId, reduceId, _, bitmaps, localDirs, _) => + case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) => // Fetch local push-merged shuffle block data as multiple shuffle chunks val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId) try { @@ -923,7 +923,7 @@ final class ShuffleBlockFetcherIterator( } } catch { case e: Exception => - // If we see an exception with reading local push-merged data, we fallback to + // If we see an exception with reading local push-merged data, we fallback to // fetch the original blocks. We do not report block fetch failure // and will continue with the remaining local block read. logWarning(s"Error occurred while fetching local push-merged data, " + @@ -933,15 +933,14 @@ final class ShuffleBlockFetcherIterator( } result = null - case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, numChunks, bitmaps, - address, _) => + case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address, _) => // The original meta request is processed so we decrease numBlocksToFetch and // numBlocksInFlightPerAddress by 1. 
We will collect new shuffle chunks request and the // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 numBlocksToFetch -= 1 val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse( - shuffleId, reduceId, blockSize, numChunks, bitmaps) + shuffleId, reduceId, blockSize, bitmaps) val additionalRemoteReqs = new ArrayBuffer[FetchRequest] collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs) fetchRequests ++= additionalRemoteReqs @@ -1428,9 +1427,8 @@ object ShuffleBlockFetcherIterator { * Result of a successful fetch of meta information for a remote push-merged block. * * @param shuffleId shuffle id. - * @param reduceId reduce id. + * @param reduceId reduce id. * @param blockSize size of each push-merged block. - * @param numChunks number of chunks in the push-merged block. * @param bitmaps bitmaps for every chunk. * @param address BlockManager that the meta was fetched from. */ @@ -1438,7 +1436,6 @@ object ShuffleBlockFetcherIterator { shuffleId: Int, reduceId: Int, blockSize: Long, - numChunks: Int, bitmaps: Array[RoaringBitmap], address: BlockManagerId, blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult @@ -1447,8 +1444,8 @@ object ShuffleBlockFetcherIterator { * Result of a failure while fetching the meta information for a remote push-merged block. * * @param shuffleId shuffle id. - * @param reduceId reduce id. - * @param address BlockManager that the meta was fetched from. + * @param reduceId reduce id. + * @param address BlockManager that the meta was fetched from. */ private[storage] case class PushMergedRemoteMetaFailedFetchResult( shuffleId: Int, @@ -1460,15 +1457,13 @@ object ShuffleBlockFetcherIterator { * Result of a successful fetch of meta information for a local push-merged block. * * @param shuffleId shuffle id. - * @param reduceId reduce id. - * @param numChunks number of chunks in the push-merged block. + * @param reduceId reduce id. * @param bitmaps bitmaps for every chunk. 
* @param localDirs local directories where the push-merged shuffle files are stored */ private[storage] case class PushMergedLocalMetaFetchResult( shuffleId: Int, reduceId: Int, - numChunks: Int, bitmaps: Array[RoaringBitmap], localDirs: Array[String], blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult } From 69fcc23cdcde24b75696c33b3624e08ac0998641 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Mon, 28 Jun 2021 15:40:27 -0700 Subject: [PATCH 25/27] Rebased against apache master --- .../org/apache/spark/storage/ShuffleBlockFetcherIterator.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index fa556d950b87..a729b1003aef 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -387,7 +387,6 @@ final class ShuffleBlockFetcherIterator( pushMergedLocalBlocks ++= blockInfos.map(_._1) pushMergedLocalBlockBytes += blockInfos.map(_._3).sum } else { - remoteBlockBytes += blockInfos.map(_._2).sum collectFetchRequests(address, blockInfos, collectedRemoteRequests) } } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback) From c10b94337ee56757e56713005396309fb9a592fc Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Mon, 28 Jun 2021 21:42:20 -0700 Subject: [PATCH 26/27] Addressed Mridul's comments --- .../apache/spark/storage/ShuffleBlockFetcherIterator.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index a729b1003aef..db7480a03ed3 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -378,6 +378,7 @@ final class ShuffleBlockFetcherIterator( val prevNumBlocksToFetch = numBlocksToFetch val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId + val localExecIds = Set(blockManager.blockManagerId.executorId, fallback) for ((address, blockInfos) <- blocksByAddress) { checkBlockSizes(blockInfos) if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) { @@ -389,8 +390,7 @@ final class ShuffleBlockFetcherIterator( } else { collectFetchRequests(address, blockInfos, collectedRemoteRequests) } - } else if (mutable.HashSet(blockManager.blockManagerId.executorId, fallback) - .contains(address.executorId)) { + } else if (localExecIds.contains(address.executorId)) { val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded( blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch) numBlocksToFetch += mergedBlockInfos.size From ad89a0208a5e3f880fca502c297362388a104dd7 Mon Sep 17 00:00:00 2001 From: Chandni Singh Date: Tue, 29 Jun 2021 08:45:12 -0700 Subject: [PATCH 27/27] Addressed review comments --- .../spark/storage/PushBasedFetchHelper.scala | 22 +++---- .../storage/ShuffleBlockFetcherIterator.scala | 62 ++++++++----------- .../ShuffleBlockFetcherIteratorSuite.scala | 12 ++-- 3 files changed, 44 insertions(+), 52 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala index b16562e17a1d..63f42a0024e3 100644 ---
a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -71,7 +71,7 @@ private class PushBasedFetchHelper( } /** - * Returns true if the address is of a local push-merged block. false otherwise. + * Returns true if the address is of a push-merged-local block, false otherwise. */ def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = { isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host @@ -190,14 +190,14 @@ private class PushBasedFetchHelper( val cachedPushedMergedDirs = hostLocalDirManager.getCachedHostLocalDirsFor( SHUFFLE_MERGER_IDENTIFIER) if (cachedPushedMergedDirs.isDefined) { - logDebug(s"Fetch the local push-merged blocks with cached merged dirs: " + + logDebug(s"Fetch the push-merged-local blocks with cached merged dirs: " + s"${cachedPushedMergedDirs.get.mkString(", ")}") pushMergedLocalBlocks.foreach { blockId => fetchPushMergedLocalBlock(blockId, cachedPushedMergedDirs.get, localShuffleMergerBlockMgrId) } } else { - logDebug(s"Asynchronously fetch the local push-merged blocks without cached merged dirs") + logDebug(s"Asynchronously fetch the push-merged-local blocks without cached merged dirs") hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host, localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) { case Success(dirs) => @@ -211,9 +211,9 @@ private class PushBasedFetchHelper( localShuffleMergerBlockMgrId) } case Failure(throwable) => - // If we see an exception with getting the local dirs for local push-merged blocks, + // If we see an exception with getting the local dirs for push-merged-local blocks, // we fallback to fetch the original blocks. We do not report block fetch failure. - logWarning(s"Error while fetching the merged dirs for local push-merged " + + logWarning(s"Error while fetching the merged dirs for push-merged-local " + s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead", throwable) pushMergedLocalBlocks.foreach { @@ -226,7 +226,7 @@ private class PushBasedFetchHelper( } /** - * Fetch a single local push-merged block generated. This can also be executed by the task thread + * Fetch a single push-merged-local block. This can be executed by the task thread * as well as the netty thread. * @param blockId ShuffleBlockId to be fetched @@ -244,10 +244,10 @@ private class PushBasedFetchHelper( localDirs)) } catch { case e: Exception => - // If we see an exception with reading a local push-merged meta, we fallback to + // If we see an exception with reading a push-merged-local meta, we fallback to // fetch the original blocks. We do not report block fetch failure // and will continue with the remaining local block read. - logWarning(s"Error occurred while fetching local push-merged meta, " + + logWarning(s"Error occurred while fetching push-merged-local meta, " + s"prepare to fetch the original blocks", e) iterator.addToResultsQueue( FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false)) @@ -270,7 +270,7 @@ private class PushBasedFetchHelper( * finds more push-merged requests to remote and again updates it with additional requests for * original blocks. * The fallback happens when: - * 1. There is an exception while creating shuffle chunks from local push-merged shuffle block. + * 1. 
There is an exception while creating shuffle chunks from push-merged-local shuffle block. * See fetchLocalBlock. * 2. There is a failure when fetching remote shuffle chunks. * 3. There is a failure when processing SuccessFetchResult which is for a shuffle chunk @@ -286,7 +286,7 @@ private class PushBasedFetchHelper( val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = blockId match { case shuffleBlockId: ShuffleBlockId => - iterator.incrementNumBlocksToFetch(-1) + iterator.decreaseNumBlocksToFetch(1) mapOutputTracker.getMapSizesForMergeResult( shuffleBlockId.shuffleId, shuffleBlockId.reduceId) case _ => @@ -311,7 +311,7 @@ private class PushBasedFetchHelper( // These blocks were added to numBlocksToFetch so we increment numBlocksProcessed blocksProcessed += pendingShuffleChunks.size } - iterator.incrementNumBlocksToFetch(-blocksProcessed) + iterator.decreaseNumBlocksToFetch(blocksProcessed) mapOutputTracker.getMapSizesForMergeResult( shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap) } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index db7480a03ed3..094c3b5fc7c2 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -60,7 +60,7 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils} * Note that zero-sized blocks are already excluded, which happened in * [[org.apache.spark.MapOutputTracker.convertMapStatuses]]. * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching the original blocks if - * we fail to fetch shuffle chunks when push based shuffle is enabled. + * we fail to fetch shuffle chunks when push based shuffle is enabled. * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. 
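To make the fallback cases enumerated in the PushBasedFetchHelper comment above concrete: a failed shuffle chunk only forces re-fetching the map outputs recorded in that chunk's bitmap, while a failed meta request forces re-fetching every original block of the reduce partition. A self-contained sketch against RoaringBitmap (the helper and ids are illustrative; the real path goes through MapOutputTracker.getMapSizesForMergeResult):

import org.roaringbitmap.RoaringBitmap

object FallbackScopeSketch {
  // Map ids whose original shuffle blocks must be re-fetched after a failure.
  def mapsToRefetch(failedChunkBitmap: Option[RoaringBitmap], allMapIds: Seq[Int]): Seq[Int] =
    failedChunkBitmap match {
      case Some(bitmap) => bitmap.toArray.toSeq // only the maps merged into this chunk
      case None => allMapIds // meta never arrived: fall back for the whole partition
    }

  def main(args: Array[String]): Unit = {
    val chunk = new RoaringBitmap
    chunk.add(1)
    chunk.add(2) // this chunk was merged from map outputs 1 and 2
    println(mapsToRefetch(Some(chunk), Seq(0, 1, 2, 3))) // List(1, 2)
    println(mapsToRefetch(None, Seq(0, 1, 2, 3))) // List(0, 1, 2, 3)
  }
}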
@@ -371,9 +371,9 @@ final class ShuffleBlockFetcherIterator( // blocks.Remote blocks are further split into FetchRequests of size at most maxBytesInFlight // in order to limit the amount of data in flight val collectedRemoteRequests = new ArrayBuffer[FetchRequest] - val hostLocalBlocksCurrentIteration = mutable.LinkedHashSet[(BlockId, Int)]() var localBlockBytes = 0L var hostLocalBlockBytes = 0L + var numHostLocalBlocks = 0 var pushMergedLocalBlockBytes = 0L val prevNumBlocksToFetch = numBlocksToFetch @@ -404,7 +404,7 @@ final class ShuffleBlockFetcherIterator( val blocksForAddress = mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex)) hostLocalBlocksByExecutor += address -> blocksForAddress - hostLocalBlocksCurrentIteration ++= blocksForAddress.map(info => (info._1, info._3)) + numHostLocalBlocks += blocksForAddress.size hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum } else { val (_, timeCost) = Utils.timeTakenMs[Unit] { @@ -419,21 +419,22 @@ final class ShuffleBlockFetcherIterator( pushMergedLocalBlockBytes val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch assert(blocksToFetchCurrentIteration == localBlocks.size + - hostLocalBlocksCurrentIteration.size + numRemoteBlocks + pushMergedLocalBlocks.size, - s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal to " + - s"the number of local blocks ${localBlocks.size} + " + - s"the number of host-local blocks ${hostLocalBlocksCurrentIteration.size} " + + numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size, + s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal the sum " + + s"of the number of local blocks ${localBlocks.size} + " + + s"the number of host-local blocks ${numHostLocalBlocks} " + s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} " + s"+ the number of remote blocks ${numRemoteBlocks} ") logInfo(s"Getting $blocksToFetchCurrentIteration " + s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " + s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " + - s"${hostLocalBlocksCurrentIteration.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " + + s"${numHostLocalBlocks} (${Utils.bytesToString(hostLocalBlockBytes)}) " + s"host-local and ${pushMergedLocalBlocks.size} " + s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " + - s"local push-merged and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + + s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " + s"remote blocks") - this.hostLocalBlocks ++= hostLocalBlocksCurrentIteration + this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values + .flatMap { infos => infos.map(info => (info._1, info._3)) } collectedRemoteRequests } @@ -883,9 +884,9 @@ final class ShuffleBlockFetcherIterator( // We get this result in 3 cases: // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the // blockId is a ShuffleBlockChunkId. - // 2. Failure to read the local push-merged meta. In this case, the blockId is + // 2. Failure to read the push-merged-local meta. In this case, the blockId is // ShuffleBlockId. - // 3. Failure to get the local push-merged directories from the ESS. In this case, the + // 3. Failure to get the push-merged-local directories from the ESS. In this case, the // blockId is ShuffleBlockId.
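// Sketch: all three cases funnel into the bookkeeping in the lines that
// follow — only remote push-merged addresses carry an in-flight request
// count, so the two local failure modes must leave the counters untouched.
// The same rule in isolation, under illustrative types (not Spark's):
final case class Addr(host: String, port: Int, remotePushMerged: Boolean)

def onFallbackResult(address: Addr, inFlight: Map[Addr, Int]): Map[Addr, Int] =
  if (address.remotePushMerged) {
    inFlight.updated(address, inFlight(address) - 1) // case 1: remote chunk failed
  } else {
    inFlight // cases 2 and 3 are purely local; nothing was ever in flight
  }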
if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) { numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 @@ -900,8 +901,8 @@ final class ShuffleBlockFetcherIterator( // a SuccessFetchResult or a FailureFetchResult. result = null - case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs, _) => - // Fetch local push-merged shuffle block data as multiple shuffle chunks + case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) => + // Fetch push-merged-local shuffle block data as multiple shuffle chunks val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId) try { val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId, @@ -922,17 +923,17 @@ final class ShuffleBlockFetcherIterator( } } catch { case e: Exception => - // If we see an exception with reading local push-merged data, we fallback to - // fetch the original blocks. We do not report block fetch failure + // If we see an exception with reading push-merged-local index file, we fallback + // to fetch the original blocks. We do not report block fetch failure // and will continue with the remaining local block read. - logWarning(s"Error occurred while fetching local push-merged data, " + + logWarning(s"Error occurred while reading push-merged-local index, " + s"prepare to fetch the original blocks", e) pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock( shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId) } result = null - case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address, _) => + case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address) => // The original meta request is processed so we decrease numBlocksToFetch and // numBlocksInFlightPerAddress by 1. We will collect new shuffle chunks request and the // count of this is added to numBlocksToFetch in collectFetchReqsFromMergedBlocks. @@ -946,7 +947,7 @@ final class ShuffleBlockFetcherIterator( // Set result to null to force another iteration. result = null - case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address, _) => + case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address) => // The original meta request failed so we decrease numBlocksInFlightPerAddress by 1. numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 // If we fail to fetch the meta of a push-merged block, we fall back to fetching the @@ -1071,8 +1072,8 @@ final class ShuffleBlockFetcherIterator( results.put(result) } - private[storage] def incrementNumBlocksToFetch(moreBlocksToFetch: Int): Unit = { - numBlocksToFetch += moreBlocksToFetch + private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = { + numBlocksToFetch -= blocksFetched } /** @@ -1091,7 +1092,7 @@ final class ShuffleBlockFetcherIterator( originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks) // Add the remote requests into our queue in a random order fetchRequests ++= Utils.randomize(originalRemoteReqs) - logInfo(s"Started ${originalRemoteReqs.size} fallback remote requests for push-merged") + logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged") // fetch all the fallback blocks that are local. 
fetchLocalBlocks(originalLocalBlocks) // Merged local blocks should be empty during fallback @@ -1246,12 +1247,6 @@ object ShuffleBlockFetcherIterator { } } - /** - * Dummy shuffle block id to fill into [[PushMergedRemoteMetaFetchResult]] and - * [[PushMergedRemoteMetaFailedFetchResult]], to match the [[FetchResult]] trait. - */ - private val DUMMY_SHUFFLE_BLOCK_ID = ShuffleBlockId(-1, -1, -1) - /** * This function is used to merged blocks when doBatchFetch is true. Blocks which have the * same `mapId` can be merged into one block batch. The block batch is specified by a range @@ -1436,8 +1431,7 @@ object ShuffleBlockFetcherIterator { reduceId: Int, blockSize: Long, bitmaps: Array[RoaringBitmap], - address: BlockManagerId, - blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult + address: BlockManagerId) extends FetchResult /** * Result of a failure while fetching the meta information for a remote push-merged block. @@ -1449,11 +1443,10 @@ object ShuffleBlockFetcherIterator { private[storage] case class PushMergedRemoteMetaFailedFetchResult( shuffleId: Int, reduceId: Int, - address: BlockManagerId, - blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult + address: BlockManagerId) extends FetchResult /** - * Result of a successful fetch of meta information for a local push-merged block. + * Result of a successful fetch of meta information for a push-merged-local block. * * @param shuffleId shuffle id. * @param reduceId reduce id. @@ -1464,6 +1457,5 @@ object ShuffleBlockFetcherIterator { shuffleId: Int, reduceId: Int, bitmaps: Array[RoaringBitmap], - localDirs: Array[String], - blockId: BlockId = DUMMY_SHUFFLE_BLOCK_ID) extends FetchResult + localDirs: Array[String]) extends FetchResult } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index db1d8bdca91c..a5143cd95ead 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -1263,7 +1263,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(!iterator.hasNext) } - test("SPARK-32922: failure to fetch local push-merged meta should fallback to fetch " + + test("SPARK-32922: failure to fetch push-merged-local meta should fallback to fetch " + "original shuffle blocks") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("testPath1", "testPath2") @@ -1276,7 +1276,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("SPARK-32922: failure when reading chunkBitmaps of local push-merged meta should " + + test("SPARK-32922: failure when reading chunkBitmaps of push-merged-local meta should " + "fallback to original shuffle blocks") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("local-dir") @@ -1288,7 +1288,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("SPARK-32922: failure to fetch local push-merged data should fallback to fetch " + + test("SPARK-32922: failure to fetch push-merged-local data should fallback to fetch " + "original shuffle blocks") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("testPath1", "testPath2") @@ -1301,8 +1301,8 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite
with PrivateMethodT verifyLocalBlocksFromFallback(iterator) } - test("SPARK-32922: failure to fetch local push-merged meta of a single merged block " + - "should not drop the fetch of other local push-merged blocks") { + test("SPARK-32922: failure to fetch push-merged-local meta of a single merged block " + + "should not drop the fetch of other push-merged-local blocks") { val blockManager = mock(classOf[BlockManager]) val localDirs = Array("testPath1", "testPath2") prepareForFallbackToLocalBlocks( @@ -1375,7 +1375,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT intercept[FetchFailedException] { iterator.next() } } - test("SPARK-32922: failure to fetch local push-merged block should fallback to fetch " + + test("SPARK-32922: failure to fetch push-merged-local block should fallback to fetch " + "original shuffle blocks which contain host-local blocks") { val blockManager = mock(classOf[BlockManager]) // BlockManagerId from another executor on the same host
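
A closing illustration of the shape left by the "Removed passing numChunks" commit: the chunk count is implied by the bitmap array, so chunk ids and approximate sizes can be derived from the meta alone. A sketch with an illustrative case class standing in for Spark's ShuffleBlockChunkId:

final case class ChunkId(shuffleId: Int, reduceId: Int, chunkId: Int)

// One (id, approximate size) pair per chunk; numChunks == bitmaps.length.
def chunksFromMeta(
    shuffleId: Int,
    reduceId: Int,
    blockSize: Long,
    numBitmaps: Int): Seq[(ChunkId, Long)] = {
  val approxChunkSize = blockSize / numBitmaps
  (0 until numBitmaps).map(i => (ChunkId(shuffleId, reduceId, i), approxChunkSize))
}

// A 100-byte push-merged block for (shuffleId = 0, reduceId = 2) with two chunk
// bitmaps yields ChunkId(0, 2, 0) and ChunkId(0, 2, 1), roughly 50 bytes each.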