diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 2b06c4980515..e605eea80229 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -1453,7 +1453,8 @@ private[spark] object MapOutputTracker extends Logging { // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate this is // a merged shuffle block. splitsByAddress.getOrElseUpdate(mergeStatus.location, ListBuffer()) += - ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, -1)) + ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), mergeStatus.totalSize, + SHUFFLE_PUSH_MAP_ID)) // For the "holes" in this pre-merged shuffle partition, i.e., unmerged mapper // shuffle partition blocks, fetch the original map produced shuffle partition blocks val mapStatusesWithIndex = mapStatuses.zipWithIndex diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 623db9d00ab5..640396a69526 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -110,6 +110,7 @@ private[spark] class SerializerManager( private def shouldCompress(blockId: BlockId): Boolean = { blockId match { case _: ShuffleBlockId => compressShuffle + case _: ShuffleBlockChunkId => compressShuffle case _: BroadcastBlockId => compressBroadcast case _: RDDBlockId => compressRdds case _: TempLocalBlockId => compressShuffleSpill diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 6782c748aff7..818aa2ef75a9 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -35,6 +35,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, shouldBatchFetch: Boolean = false) extends ShuffleReader[K, C] with Logging { @@ -71,6 +72,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( context, blockManager.blockStoreClient, blockManager, + mapOutputTracker, blocksByAddress, serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 47c1b9664103..dc70a9af7e9c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -43,6 +43,7 @@ sealed abstract class BlockId { (isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId] || isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId]) } + def isShuffleChunk: Boolean = isInstanceOf[ShuffleBlockChunkId] def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name @@ -72,6 +73,15 @@ case class ShuffleBlockBatchId( } } +@Since("3.2.0") +@DeveloperApi +case class ShuffleBlockChunkId( + shuffleId: Int, + reduceId: Int, + chunkId: Int) extends BlockId { 
+ override def name: String = "shuffleChunk_" + shuffleId + "_" + reduceId + "_" + chunkId +} + @DeveloperApi case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" @@ -152,7 +162,7 @@ class UnrecognizedBlockId(name: String) @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r - val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r + val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r @@ -160,6 +170,7 @@ object BlockId { val SHUFFLE_MERGED_DATA = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r val SHUFFLE_MERGED_INDEX = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r val SHUFFLE_MERGED_META = "shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r + val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r @@ -186,6 +197,8 @@ object BlockId { ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt) case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) => ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt) + case SHUFFLE_CHUNK(shuffleId, reduceId, chunkId) => + ShuffleBlockChunkId(shuffleId.toInt, reduceId.toInt, chunkId.toInt) case BROADCAST(broadcastId, field) => BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index df449fba24e9..98d094939cd4 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -132,6 +132,11 @@ private[spark] class HostLocalDirManager( executorIdToLocalDirsCache.asMap().asScala.toMap } + private[spark] def getCachedHostLocalDirsFor(executorId: String): Option[Array[String]] = + executorIdToLocalDirsCache.synchronized { + Option(executorIdToLocalDirsCache.getIfPresent(executorId)) + } + private[spark] def getHostLocalDirs( host: String, port: Int, diff --git a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala new file mode 100644 index 000000000000..63f42a0024e3 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala @@ -0,0 +1,320 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks, where each shuffle chunk contains
+ * multiple shuffle blocks that belong to the same reduce partition and were merged by the ESS
+ * into that chunk.
+ */
+private class PushBasedFetchHelper(
+    private val iterator: ShuffleBlockFetcherIterator,
+    private val shuffleClient: BlockStoreClient,
+    private val blockManager: BlockManager,
+    private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmaps.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block, false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a push-merged-local block, false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = {
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when `iterator.next()` is invoked and the iterator
+   * processes a response of type [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param bitmaps chunk bitmaps, where each bitmap contains all the mapIds that were merged
+   *                to that chunk.
+   * @return shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / bitmaps.length
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- bitmaps.indices) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for ($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for ($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It fetches all the
+   * outstanding push-merged-local blocks.
+   *
+   * @param pushMergedLocalBlocks set of identified push-merged-local blocks.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, pushMergedLocalBlocks))
+    }
+  }
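
A note on the sizing in createChunkBlockInfosFromMetaResponse above: the meta response only carries the total size of the push-merged block, so every shuffle chunk is given the same estimated size until it is actually fetched. A minimal, self-contained sketch of that estimate (the helper name and the numbers are illustrative, not part of the patch):

    import org.roaringbitmap.RoaringBitmap

    // One bitmap per chunk; the total block size is spread evenly across them.
    def estimateChunkSizes(blockSize: Long, bitmaps: Array[RoaringBitmap]): Seq[Long] =
      Seq.fill(bitmaps.length)(blockSize / bitmaps.length)

    // e.g. a 12 MB push-merged block with 3 chunk bitmaps => a ~4 MB estimate per chunk
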
+  /**
+   * Fetch the dirs of the push-merged-local blocks if they are not in the cache and eventually
+   * fetch the push-merged-local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedPushedMergedDirs = hostLocalDirManager.getCachedHostLocalDirsFor(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedPushedMergedDirs.isDefined) {
+      logDebug(s"Fetch the push-merged-local blocks with cached merged dirs: " +
+        s"${cachedPushedMergedDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedPushedMergedDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronously fetch the push-merged-local blocks without cached merged dirs")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          logDebug(s"Fetched merged dirs in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startTimeNs)} ms")
+          pushMergedLocalBlocks.foreach {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchPushMergedLocalBlock(blockId, dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+        case Failure(throwable) =>
+          // If we see an exception while getting the local dirs for push-merged-local blocks,
+          // we fall back to fetching the original blocks. We do not report a block fetch failure.
+          logWarning(s"Error while fetching the merged dirs for push-merged-local " +
+            s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the original blocks instead",
+            throwable)
+          pushMergedLocalBlocks.foreach {
+            blockId =>
+              iterator.addToResultsQueue(FallbackOnPushMergedFailureResult(
+                blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = false))
+          }
+      }
+    }
+  }
+
+  /**
+   * Fetch a single push-merged-local block. This can be executed by the task thread as well as
+   * the netty thread.
+   *
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the push-merged shuffle files are stored
+   * @param blockManagerId BlockManagerId
+   */
+  private[this] def fetchPushMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Unit = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, localDirs)
+      iterator.addToResultsQueue(PushMergedLocalMetaFetchResult(
+        shuffleBlockId.shuffleId, shuffleBlockId.reduceId, chunksMeta.readChunkBitmaps(),
+        localDirs))
+    } catch {
+      case e: Exception =>
+        // If we see an exception while reading the push-merged-local meta, we fall back to
+        // fetching the original blocks. We do not report a block fetch failure and will
+        // continue with the remaining local block reads.
+        logWarning(s"Error occurred while fetching push-merged-local meta, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, isNetworkReqDone = false))
+    }
+  }
+
+  /**
+   * This is executed by the task thread when `iterator.next()` is invoked and the iterator
+   * processes a response of type:
+   * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]]
+   * 2) [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]]
+   * 3) [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]]
+   *
+   * This initiates fetching fallback blocks for a push-merged block or a shuffle chunk that
+   * failed to fetch.
+   * It makes a call to the map output tracker to get the list of original blocks for the
+   * given push-merged block/shuffle chunk, splits them into remote and local blocks, and
+   * processes them accordingly.
+   * It also updates numBlocksToFetch in the iterator: the count is decremented as the failed
+   * response is processed, and incremented again by the additional requests created for the
+   * original blocks.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle chunks from a push-merged-local shuffle
+   *    block. See fetchPushMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle chunks.
+   * 3. There is a failure when processing a SuccessFetchResult which is for a shuffle chunk
+   *    (local or remote).
+   */
+  def initiateFallbackFetchForPushMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Unit = {
+    assert(blockId.isInstanceOf[ShuffleBlockId] || blockId.isInstanceOf[ShuffleBlockChunkId])
+    logWarning(s"Falling back to fetch the original blocks for push-merged block $blockId")
+    // Increase the blocks processed since we will process another block in the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] =
+      blockId match {
+        case shuffleBlockId: ShuffleBlockId =>
+          iterator.decreaseNumBlocksToFetch(1)
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+        case _ =>
+          val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+          val chunkBitmap: RoaringBitmap = chunksMetaMap.remove(shuffleChunkId).get
+          var blocksProcessed = 1
+          // When there is a failure to fetch a remote shuffle chunk, we try to fall back not
+          // only for that particular remote shuffle chunk but also for all the pending chunks
+          // that belong to the same host. The reason for doing so is that it is very likely
+          // that the subsequent requests for shuffle chunks from this host will fail as well.
+          // Since push-based shuffle is best effort and we try not to increase the delay of the
+          // fetches, we immediately fall back for all the pending shuffle chunks in the
+          // fetchRequests queue.
+          if (isRemotePushMergedBlockAddress(address)) {
+            // Fallback for all the pending fetch requests
+            val pendingShuffleChunks = iterator.removePendingChunks(shuffleChunkId, address)
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logInfo(s"Falling back immediately for shuffle chunk $pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap = chunksMetaMap.remove(pendingBlockId).get
+              chunkBitmap.or(bitmapOfPendingChunk)
+            }
+            // These blocks were added to numBlocksToFetch so we increment blocksProcessed
+            blocksProcessed += pendingShuffleChunks.size
+          }
+          iterator.decreaseNumBlocksToFetch(blocksProcessed)
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
+      }
+    iterator.fallbackFetch(fallbackBlocksByAddr)
+  }
+}
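
Before moving on to the iterator changes: the fallback path above unions the chunk bitmaps so that a single call to the map output tracker covers every map output that still needs to be fetched as original blocks. A small standalone illustration of that RoaringBitmap union (the mapIds here are made up):

    import org.roaringbitmap.RoaringBitmap

    val failedChunk = new RoaringBitmap()   // mapIds merged into the failed chunk
    failedChunk.add(1); failedChunk.add(4)
    val pendingChunk = new RoaringBitmap()  // mapIds of a pending chunk on the same host
    pendingChunk.add(2); pendingChunk.add(7)
    failedChunk.or(pendingChunk)            // in-place union: failedChunk now holds {1, 2, 4, 7}
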
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 4465d76e3127..094c3b5fc7c2 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -24,13 +24,15 @@ import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 import scala.util.{Failure, Success}
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.commons.io.IOUtils
+import org.roaringbitmap.RoaringBitmap
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.shuffle._
@@ -57,6 +59,8 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
  *                        block, which indicate the index in the map stage.
  *                        Note that zero-sized blocks are already excluded, which happened in
  *                        [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching the original blocks if
+ *                         we fail to fetch shuffle chunks when push-based shuffle is enabled.
  * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
@@ -75,6 +79,7 @@ final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: BlockStoreClient,
     blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker,
     blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
@@ -108,13 +113,6 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val startTimeNs = System.nanoTime()
 
-  /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
-
-  /** Host local blockIds to fetch by executors, excluding zero-sized blocks. */
-  private[this] val hostLocalBlocksByExecutor =
-    LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
-
   /** Host local blocks to fetch, excluding zero-sized blocks. */
   private[this] val hostLocalBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
 
@@ -179,6 +177,9 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val onCompleteCallback = new ShuffleFetchCompletionListener(this)
 
+  private[this] val pushBasedFetchHelper = new PushBasedFetchHelper(
+    this, shuffleClient, blockManager, mapOutputTracker)
+
   initialize()
 
   // Decrements the buffer reference count.
@@ -329,7 +330,14 @@ final class ShuffleBlockFetcherIterator(
             }
 
           case _ =>
-            results.put(FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
+            val block = BlockId(blockId)
+            if (block.isShuffleChunk) {
+              remainingBlocks -= blockId
+              results.put(FallbackOnPushMergedFailureResult(
+                block, address, infoMap(blockId)._1, remainingBlocks.isEmpty))
+            } else {
+              results.put(FailureFetchResult(block, infoMap(blockId)._2, address, e))
+            }
         }
       }
     }
@@ -347,20 +355,42 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: $maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the amount of data in flight
+    // Partition to local, host-local, push-merged-local, remote (includes push-merged-remote)
+    // blocks. Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+    // in order to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
     var localBlockBytes = 0L
    var hostLocalBlockBytes = 0L
+    var numHostLocalBlocks = 0
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
+    val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
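
As orientation for the partitioning loop that follows: the split keys off the reserved merger executor id plus a host comparison. A hedged, standalone sketch of the classification (the function is illustrative, not the patch's code; it only assumes that SHUFFLE_MERGER_IDENTIFIER is the reserved executorId used for merger addresses, as in PushBasedFetchHelper above):

    // Push-merged addresses carry SHUFFLE_MERGER_IDENTIFIER instead of a real executor id.
    def classify(address: BlockManagerId, localHost: String): String =
      if (address.executorId == SHUFFLE_MERGER_IDENTIFIER) {
        if (address.host == localHost) "push-merged-local" else "push-merged-remote"
      } else if (address.host == localHost) "local or host-local" else "remote"
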
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (localExecIds.contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
@@ -368,14 +398,13 @@ final class ShuffleBlockFetcherIterator(
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        numHostLocalBlocks += blocksForAddress.size
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         val (_, timeCost) = Utils.timeTakenMs[Unit] {
@@ -386,40 +415,54 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 + y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks ${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) non-empty blocks " +
-      s"including ${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
-      s"host-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't equal the sum of " +
+        s"the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${numHostLocalBlocks} + " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} + " +
+        s"the number of remote blocks ${numRemoteBlocks}.")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${numHostLocalBlocks} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"push-merged-local and $numRemoteBlocks (${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
+      .flatMap {
infos => infos.map(info => (info._1, info._3)) } collectedRemoteRequests } private def createFetchRequest( blocks: Seq[FetchBlockInfo], - address: BlockManagerId): FetchRequest = { + address: BlockManagerId, + forMergedMetas: Boolean): FetchRequest = { logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address " + s"with ${blocks.size} blocks") - FetchRequest(address, blocks) + FetchRequest(address, blocks, forMergedMetas) } private def createFetchRequests( curBlocks: Seq[FetchBlockInfo], address: BlockManagerId, isLast: Boolean, - collectedRemoteRequests: ArrayBuffer[FetchRequest]): ArrayBuffer[FetchBlockInfo] = { - val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch) + collectedRemoteRequests: ArrayBuffer[FetchRequest], + enableBatchFetch: Boolean, + forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = { + val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, enableBatchFetch) numBlocksToFetch += mergedBlocks.size val retBlocks = new ArrayBuffer[FetchBlockInfo] if (mergedBlocks.length <= maxBlocksInFlightPerAddress) { - collectedRemoteRequests += createFetchRequest(mergedBlocks, address) + collectedRemoteRequests += createFetchRequest(mergedBlocks, address, forMergedMetas) } else { mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks => if (blocks.length == maxBlocksInFlightPerAddress || isLast) { - collectedRemoteRequests += createFetchRequest(blocks, address) + collectedRemoteRequests += createFetchRequest(blocks, address, forMergedMetas) } else { // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back // to `curBlocks`. @@ -441,20 +484,45 @@ final class ShuffleBlockFetcherIterator( while (iterator.hasNext) { val (blockId, size, mapIndex) = iterator.next() - assertPositiveBlockSize(blockId, size) curBlocks += FetchBlockInfo(blockId, size, mapIndex) curRequestSize += size - // For batch fetch, the actual block in flight should count for merged block. - val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress - if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { - curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, - collectedRemoteRequests) - curRequestSize = curBlocks.map(_.size).sum + blockId match { + // Either all blocks are push-merged blocks, shuffle chunks, or original blocks. + // Based on these types, we decide to do batch fetch and create FetchRequests with + // forMergedMetas set. + case ShuffleBlockChunkId(_, _, _) => + if (curRequestSize >= targetRemoteRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, enableBatchFetch = false) + curRequestSize = curBlocks.map(_.size).sum + } + case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => + if (curBlocks.size >= maxBlocksInFlightPerAddress) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, enableBatchFetch = false, forMergedMetas = true) + } + case _ => + // For batch fetch, the actual block in flight should count for merged block. 
+ val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress + if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) { + curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false, + collectedRemoteRequests, doBatchFetch) + curRequestSize = curBlocks.map(_.size).sum + } } } // Add in the final request if (curBlocks.nonEmpty) { - createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests) + val (enableBatchFetch, forMergedMetas) = { + curBlocks.head.blockId match { + case ShuffleBlockChunkId(_, _, _) => (false, false) + case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true) + case _ => (doBatchFetch, false) + } + } + createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests, + enableBatchFetch = enableBatchFetch, forMergedMetas = forMergedMetas) } } @@ -475,7 +543,8 @@ final class ShuffleBlockFetcherIterator( * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we * track in-memory are the ManagedBuffer references themselves. */ - private[this] def fetchLocalBlocks(): Unit = { + private[this] def fetchLocalBlocks( + localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = { logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}") val iter = localBlocks.iterator while (iter.hasNext) { @@ -529,7 +598,10 @@ final class ShuffleBlockFetcherIterator( * `ManagedBuffer`'s memory is allocated lazily when we create the input stream, so all we * track in-memory are the ManagedBuffer references themselves. */ - private[this] def fetchHostLocalBlocks(hostLocalDirManager: HostLocalDirManager): Unit = { + private[this] def fetchHostLocalBlocks( + hostLocalDirManager: HostLocalDirManager, + hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]): + Unit = { val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = { val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case (hostLocalBmId, _) => @@ -602,9 +674,15 @@ final class ShuffleBlockFetcherIterator( private[this] def initialize(): Unit = { // Add a task completion callback (called in both success case and failure case) to cleanup. context.addTaskCompletionListener(onCompleteCallback) - - // Partition blocks by the different fetch modes: local, host-local and remote blocks. - val remoteRequests = partitionBlocksByFetchMode() + // Local blocks to fetch, excluding zero-sized blocks. + val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]() + val hostLocalBlocksByExecutor = + mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]() + val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]() + // Partition blocks by the different fetch modes: local, host-local, push-merged-local and + // remote blocks. 
+    val remoteRequests = partitionBlocksByFetchMode(
+      blocksByAddress, localBlocks, hostLocalBlocksByExecutor, pushMergedLocalBlocks)
     // Add the remote requests into our queue in a random order
     fetchRequests ++= Utils.randomize(remoteRequests)
     assert ((0 == reqsInFlight) == (0 == bytesInFlight),
@@ -620,11 +698,18 @@ final class ShuffleBlockFetcherIterator(
       (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" else ""))
 
     // Get Local Blocks
-    fetchLocalBlocks()
+    fetchLocalBlocks(localBlocks)
     logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
+    // Get host local blocks if any
+    fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
+    pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
+  }
 
-    if (hostLocalBlocks.nonEmpty) {
-      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks)
+  private def fetchAllHostLocalBlocks(
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]):
+    Unit = {
+    if (hostLocalBlocksByExecutor.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, hostLocalBlocksByExecutor))
     }
   }
 
@@ -661,7 +746,9 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+            if (hostLocalBlocks.contains(blockId -> mapIndex) ||
+              pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
+              // It is a host local block or a local shuffle chunk
               shuffleMetrics.incLocalBlocksFetched(1)
               shuffleMetrics.incLocalBytesRead(buf.size)
             } else {
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
             case e: IOException => logError("Failed to create input stream from local block", e)
           }
           buf.release()
-          throwFetchFailedException(blockId, mapIndex, address, e)
-        }
-        try {
-          input = streamWrapper(blockId, in)
-          // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
-          // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
-          // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
-          // the corruption is later, we'll still detect the corruption later in the stream.
-          streamCompressedOrEncrypted = !input.eq(in)
-          if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-            // TODO: manage the memory used here, and spill it into disk in case of OOM.
-            input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-          }
-        } catch {
-          case e: IOException =>
-            buf.release()
-            if (buf.isInstanceOf[FileSegmentManagedBuffer]
-              || corruptedBlocks.contains(blockId)) {
-              throwFetchFailedException(blockId, mapIndex, address, e)
-            } else {
-              logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
-              corruptedBlocks += blockId
-              fetchRequests += FetchRequest(
-                address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+          if (blockId.isShuffleChunk) {
+            pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+            // Set result to null to trigger another iteration of the while loop to get either
+            // a SuccessFetchResult or a FailureFetchResult.
+            result = null
+            null
+          } else {
+            throwFetchFailedException(blockId, mapIndex, address, e)
+          }
+        }
+        if (in != null) {
+          try {
+            input = streamWrapper(blockId, in)
+            // If the stream is compressed or wrapped, then we optionally decompress/unwrap the
+            // first maxBytesInFlight/3 bytes into memory, to check for corruption in that portion
+            // of the data. But even if 'detectCorruptUseExtraMemory' configuration is off, or if
+            // the corruption is later, we'll still detect the corruption later in the stream.
+            streamCompressedOrEncrypted = !input.eq(in)
+            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+              // TODO: manage the memory used here, and spill it into disk in case of OOM.
+              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+            }
+          } catch {
+            case e: IOException =>
+              buf.release()
+              if (blockId.isShuffleChunk) {
+                // Retrying a corrupt block may result again in a corrupt block. For shuffle
+                // chunks, we opt to fall back to the original shuffle blocks that belong to that
+                // corrupt shuffle chunk immediately instead of retrying to fetch the corrupt
+                // chunk. This also makes the code simpler because the chunkMeta corresponding to
+                // a shuffle chunk is always removed from chunksMetaMap whenever a shuffle chunk
+                // gets processed. If we tried to re-fetch a corrupt shuffle chunk, it would have
+                // to be added back to the chunksMetaMap.
+                pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop.
+                result = null
+              } else {
+                if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                  || corruptedBlocks.contains(blockId)) {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                } else {
+                  logWarning(s"got a corrupted block $blockId from $address, fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(
+                    address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+                  result = null
+                }
+              }
+          } finally {
+            if (blockId.isShuffleChunk) {
+              pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId])
+            }
+            // TODO: release the buf here to free memory earlier
+            if (input == null) {
+              // Close the underlying stream if there was an issue in wrapping the stream using
+              // streamWrapper
+              in.close()
+            }
-        } finally {
-          // TODO: release the buf here to free memory earlier
-          if (input == null) {
-            // Close the underlying stream if there was an issue in wrapping the stream using
-            // streamWrapper
-            in.close()
-          }
           }
         }
 
@@ -767,6 +879,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
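
To summarize the corrupt-chunk handling in the hunk above as a stand-alone decision table (the trait, objects, and function below are illustrative only, not part of the patch):

    sealed trait CorruptionAction
    case object FallbackToOriginalBlocks extends CorruptionAction
    case object RefetchOnce extends CorruptionAction
    case object FailTask extends CorruptionAction

    // Shuffle chunks are never re-fetched: their bitmap is already removed from
    // chunksMetaMap, so the iterator falls back to the original blocks instead.
    def onCorruption(isShuffleChunk: Boolean, isFileSegment: Boolean,
        seenBefore: Boolean): CorruptionAction =
      if (isShuffleChunk) FallbackToOriginalBlocks
      else if (isFileSegment || seenBefore) FailTask
      else RefetchOnce
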
+      case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
+        // We get this result in 3 cases:
+        // 1. Failure to fetch the data of a remote shuffle chunk. In this case, the
+        //    blockId is a ShuffleBlockChunkId.
+        // 2. Failure to read the push-merged-local meta. In this case, the blockId is a
+        //    ShuffleBlockId.
+        // 3. Failure to get the push-merged-local directories from the ESS. In this case, the
+        //    blockId is a ShuffleBlockId.
+        if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+          numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+          bytesInFlight -= size
+        }
+        if (isNetworkReqDone) {
+          reqsInFlight -= 1
+          logDebug("Number of requests in flight " + reqsInFlight)
+        }
+        pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+        // Set result to null to trigger another iteration of the while loop to get either
+        // a SuccessFetchResult or a FailureFetchResult.
+        result = null
+
+      case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) =>
+        // Fetch push-merged-local shuffle block data as multiple shuffle chunks
+        val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+        try {
+          val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+            localDirs)
+          // Since the request for the local block meta completed successfully, numBlocksToFetch
+          // is decremented.
+          numBlocksToFetch -= 1
+          // Update the total number of blocks to fetch, reflecting the multiple local shuffle
+          // chunks.
+          numBlocksToFetch += bufs.size
+          bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+            buf.retain()
+            val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+            pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+            results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+              pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+              isNetworkReqDone = false))
+          }
+        } catch {
+          case e: Exception =>
+            // If we see an exception while reading the push-merged-local index file, we fall
+            // back to fetching the original blocks. We do not report a block fetch failure
+            // and will continue with the remaining local block reads.
+            logWarning(s"Error occurred while reading push-merged-local index, " +
+              s"prepare to fetch the original blocks", e)
+            pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+              shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId)
+        }
+        result = null
+
+      case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address) =>
+        // The original meta request is processed, so we decrease numBlocksToFetch and
+        // numBlocksInFlightPerAddress by 1. We collect the new shuffle chunk requests, and
+        // their count is added back to numBlocksToFetch via collectFetchRequests.
+        numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+        numBlocksToFetch -= 1
+        val blocksToFetch = pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+          shuffleId, reduceId, blockSize, bitmaps)
+        val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+        collectFetchRequests(address, blocksToFetch.toSeq, additionalRemoteReqs)
+        fetchRequests ++= additionalRemoteReqs
+        // Set result to null to force another iteration.
+        result = null
+
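
A quick sanity check of the numBlocksToFetch bookkeeping in the two meta-result cases above, with made-up numbers:

    var numBlocksToFetch = 5  // the push-merged block was counted once up front
    numBlocksToFetch -= 1     // its meta request has now been processed
    numBlocksToFetch += 3     // replaced by one pending entry per shuffle chunk
    assert(numBlocksToFetch == 7)
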
+      case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address) =>
+        // The original meta request failed, so we decrease numBlocksInFlightPerAddress by 1.
+        numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
+        // If we fail to fetch the meta of a push-merged block, we fall back to fetching the
+        // original blocks.
+        pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+          ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address)
+        // Set result to null to force another iteration.
+        result = null
       }
 
     // Send fetch requests up to maxBytesInFlight
@@ -834,7 +1023,11 @@ final class ShuffleBlockFetcherIterator(
     }
 
     def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
-      sendRequest(request)
+      if (request.forMergedMetas) {
+        pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
+      } else {
+        sendRequest(request)
+      }
       numBlocksInFlightPerAddress(remoteAddress) =
         numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size
     }
@@ -871,6 +1064,82 @@ final class ShuffleBlockFetcherIterator(
         "Failed to get block " + blockId + ", which is not a shuffle block", e)
     }
   }
+
+  /**
+   * All the methods below are used by [[PushBasedFetchHelper]] to communicate with the iterator.
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = {
+    numBlocksToFetch -= blocksFetched
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when there is a fetch
+   * failure related to a push-merged block or shuffle chunk.
+   * This is executed by the task thread when `iterator.next()` is invoked and the processed
+   * result initiates a fallback.
+   */
+  private[storage] def fallbackFetch(
+      originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = {
+    val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val originalHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr,
+      originalLocalBlocks, originalHostLocalBlocksByExecutor, originalMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(originalRemoteReqs)
+    logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for push-merged blocks")
+    // Fetch all the fallback blocks that are local.
+    fetchLocalBlocks(originalLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(originalMergedLocalBlocks.isEmpty,
+      "There should be zero push-merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host and have the same reduceId
+   * as the current chunk that had a fetch failure.
+   * This is executed by the task thread when `iterator.next()` is invoked and the processed
+   * result initiates a fallback.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleReducePartition(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll { req =>
+        val firstBlock = req.blocks.head
+        firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleReducePartition(firstBlock.blockId)
+      }
+      removedChunkIds ++=
+        fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]))
+    }
+
+    filterRequests(fetchRequests)
+    deferredFetchRequests.get(address).foreach { defRequests =>
+      filterRequests(defRequests)
+      if (defRequests.isEmpty) deferredFetchRequests.remove(address)
+    }
+    removedChunkIds
+  }
 }
 
 /**
@@ -1074,8 +1343,13 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
    * @param blocks Sequence of the information for blocks to fetch from the same address.
+   * @param forMergedMetas true if this request is for requesting push-merged meta information;
+   *                       false if it is for regular blocks or shuffle chunks.
    */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) {
+  case class FetchRequest(
+      address: BlockManagerId,
+      blocks: Seq[FetchBlockInfo],
+      forMergedMetas: Boolean = false) {
     val size = blocks.map(_.size).sum
   }
 
@@ -1124,4 +1398,64 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage] case class DeferFetchRequestResult(fetchRequest: FetchRequest)
     extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of either of these:
+   * 1) Remote shuffle chunk.
+   * 2) Local push-merged block.
+   *
+   * Instead of treating this as a [[FailureFetchResult]], we fall back to fetch the original
+   * blocks.
+   *
+   * @param blockId block id
+   * @param address BlockManager that the push-merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class FallbackOnPushMergedFailureResult(blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param blockSize size of each push-merged block.
+   * @param bitmaps bitmaps for every chunk.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId) extends FetchResult
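
removePendingChunks above leans on mutable.Queue#dequeueAll, which removes and returns every element matching a predicate in a single pass. A tiny standalone demonstration of that semantics, separate from the patch:

    import scala.collection.mutable

    val queue = mutable.Queue(1, 7, 2, 9)
    val removed = queue.dequeueAll(_ > 5)
    assert(removed == Seq(7, 9))           // removed, in queue order
    assert(queue == mutable.Queue(1, 2))   // survivors stay queued
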
+
+  /**
+   * Result of a failure while fetching the meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a push-merged-local block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param bitmaps bitmaps for every chunk.
+   * @param localDirs local directories where the push-merged shuffle files are stored.
+   */
+  private[storage] case class PushMergedLocalMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      bitmaps: Array[RoaringBitmap],
+      localDirs: Array[String]) extends FetchResult
 }
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index b3138d7f126d..e8c3c2df261c 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -210,4 +210,29 @@ class BlockIdSuite extends SparkFunSuite {
     assert(!id.isShuffle)
     assertSame(id, BlockId(id.toString))
   }
+
+  test("merged shuffle id") {
+    val id = ShuffleBlockId(1, -1, 0)
+    assertSame(id, ShuffleBlockId(1, -1, 0))
+    assertDifferent(id, ShuffleBlockId(1, 1, 1))
+    assert(id.name === "shuffle_1_-1_0")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.mapId === -1)
+    assert(id.reduceId === 0)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("shuffle chunk") {
+    val id = ShuffleBlockChunkId(1, 1, 0)
+    assertSame(id, ShuffleBlockChunkId(1, 1, 0))
+    assertDifferent(id, ShuffleBlockChunkId(1, 1, 1))
+    assert(id.name === "shuffleChunk_1_1_0")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.reduceId === 1)
+    assert(id.chunkId === 0)
+    assertSame(id, BlockId(id.toString))
+  }
 }
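
Between the two test files, a quick recap of what the BlockIdSuite additions above pin down, as plain Scala using only identifiers from the patch (SHUFFLE_PUSH_MAP_ID is the reserved mapId that the earlier MapOutputTracker hunk substitutes for the literal -1):

    // The relaxed SHUFFLE regex accepts a negative mapId, so a merged block's name
    // round-trips, and the new shuffleChunk_* names parse to ShuffleBlockChunkId.
    assert(BlockId("shuffle_1_-1_0") == ShuffleBlockId(1, -1, 0))
    assert(BlockId("shuffleChunk_1_1_0") == ShuffleBlockChunkId(1, 1, 0))
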
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 9b633479cb3a..a5143cd95ead 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -22,32 +22,41 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.log4j.Level
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
 import org.apache.spark.util.Utils
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 
   private var transfer: BlockTransferService = _
+  private var mapOutputTracker: MapOutputTracker = _
 
   override def beforeEach(): Unit = {
     transfer = mock(classOf[BlockTransferService])
+    mapOutputTracker = mock(classOf[MapOutputTracker])
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(Seq.empty.iterator)
   }
 
   private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
@@ -178,6 +187,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       tContext,
       transfer,
       blockManager.getOrElse(createMockBlockManager()),
+      mapOutputTracker,
       blocksByAddress.toIterator,
       (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in),
       maxBytesInFlight,
@@ -1017,4 +1027,670 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     }
     assert(e.getMessage.contains("fetch failed after 10 retries due to Netty OOM"))
   }
+
+  /**
+   * Prepares the transfer to trigger success for all the blocks present in blockChunks. It will
+   * trigger a failure for any block that is not part of blockChunks.
+   */
+  private def configureMockTransferForPushShuffle(
+      blocksSem: Semaphore,
+      blockChunks: Map[BlockId, ManagedBuffer]): Unit = {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val regularBlocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
+        val blockFetchListener =
+          invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          regularBlocks.foreach(blockId => {
+            val shuffleBlock = BlockId(blockId)
+            if (!blockChunks.contains(shuffleBlock)) {
+              // force failure
+              blockFetchListener.onBlockFetchFailure(
+                blockId, new RuntimeException("failed to fetch"))
+            } else {
+              blockFetchListener.onBlockFetchSuccess(blockId, blockChunks(shuffleBlock))
+            }
+            blocksSem.release()
+          })
+        }
+      })
+  }
+
+  test("SPARK-32922: fetch remote push-merged block meta") {
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+      (BlockManagerId("remote-client-1", "remote-host-1", 1),
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))
+    )
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+
+    val metaSem = new Semaphore(0)
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(2)
+    when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap)
+    when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        Future {
+          val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+          val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+          logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.acquire()
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
+    blocksSem.acquire(2)
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    metaSem.release()
+    val (id3, _) = iterator.next()
+    blocksSem.acquire()
+    assert(id3 === ShuffleBlockChunkId(0, 2, 0))
+    val (id4, _) = iterator.next()
+    blocksSem.acquire()
+    assert(id4 === ShuffleBlockChunkId(0, 2, 1))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failed to fetch remote push-merged block meta so fall back to " +
+    "original blocks") {
+    val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)))
+
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer()
+    )
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((remoteBmId, toBlockList(
+        Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+        }
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
+    blocksSem.acquire(2)
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id3 === ShuffleBlockId(0, 1, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: iterator has just 1 push-merged block and fails to fetch the meta") {
+    val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer()
+    )
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((remoteBmId, toBlockList(
+        Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+        }
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  private def createMockPushMergedBlockMeta(
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): MergedBlockMeta = {
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(numChunks)
+    if (bitmaps == null) {
+      when(pushMergedBlockMeta.readChunkBitmaps()).thenThrow(new IOException("forced error"))
+    } else {
+      when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(bitmaps)
+    }
+    doReturn(createMockManagedBuffer()).when(pushMergedBlockMeta).getChunksBitmapBuffer
+    pushMergedBlockMeta
+  }
+
+  private def prepareForFallbackToLocalBlocks(
+      blockManager: BlockManager,
+      localDirsMap: Map[String, Array[String]],
+      failReadingLocalChunksMeta: Boolean = false):
+      Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = {
+    val localHost = "test-local-host"
+    val localBmId = BlockManagerId("test-client", localHost, 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    initHostLocalDirManager(blockManager, localDirsMap)
+
+    val blockBuffers = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer()
+    )
+
+    doReturn(blockBuffers(ShuffleBlockId(0, 0, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 0, 2))
+    doReturn(blockBuffers(ShuffleBlockId(0, 1, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 1, 2))
+    doReturn(blockBuffers(ShuffleBlockId(0, 2, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 2, 2))
+    doReturn(blockBuffers(ShuffleBlockId(0, 3, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 3, 2))
+
+    val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
+    doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData)
+
+    // Get a valid chunk meta for this test
+    val bitmaps = Array(new RoaringBitmap)
+    bitmaps(0).add(1) // chunk 0 has mapId 1
+    bitmaps(0).add(2) // chunk 0 has mapId 2
+    val pushMergedBlockMeta: MergedBlockMeta = if (failReadingLocalChunksMeta) {
+      createMockPushMergedBlockMeta(bitmaps.length, null)
+    } else {
+      createMockPushMergedBlockMeta(bitmaps.length, bitmaps)
+    }
+    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+      dirsForMergedData)).thenReturn(pushMergedBlockMeta)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2, bitmaps(0)))
+      .thenReturn(Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
+    val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
+    Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
+      (pushMergedBmId, toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+  }
+
+  private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = {
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 1, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local meta should fall back to " +
+    "fetching original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: failure to read chunkBitmaps of push-merged-local meta should " +
+    "fall back to original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs),
+      failReadingLocalChunksMeta = true)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local data should fall back to " +
+    "fetching original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local meta of a single merged block " +
+    "should not drop the fetch of other push-merged-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    val localHost = "test-local-host"
+    val localBmId = BlockManagerId("test-client", localHost, 1)
+    val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
+      (pushMergedBmId, toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+        ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    // Create a valid chunk meta for partition 3
+    val bitmaps = Array(new RoaringBitmap)
+    bitmaps(0).add(1) // chunk 0 has mapId 1
+    doReturn(createMockPushMergedBlockMeta(bitmaps.length, bitmaps)).when(blockManager)
+      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+    // Return a valid buffer for the chunk in partition 3
+    doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 1, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    val (id5, _) = iterator.next()
+    assert(id5 === ShuffleBlockChunkId(0, 3, 0))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged block as well as fallback block should throw " +
+    "a FetchFailedException") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val localBmId = BlockManagerId("test-client", "test-local-host", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    val localDirsMap = Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)
+    initHostLocalDirManager(blockManager, localDirsMap)
+
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 0, 2))
+    // Force reading of the original block (0, 1, 2) to fail, which will throw a
+    // FetchFailedException.
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 1, 2))
+
+    val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
+    // Since bitmaps are null, reading the push-merged block meta will fail, causing the
+    // fallback to initiate.
+    val pushMergedBlockMeta: MergedBlockMeta = createMockPushMergedBlockMeta(2, null)
+    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+      dirsForMergedData)).thenReturn(pushMergedBlockMeta)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1), toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    // The first call to iterator.next() returns the original shuffle block (0, 0, 2)
+    assert(iterator.next()._1 === ShuffleBlockId(0, 0, 2))
+    // The second call to iterator.next() throws a FetchFailedException
+    intercept[FetchFailedException] { iterator.next() }
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local block should fall back to fetching " +
+    "original shuffle blocks which contain host-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    // BlockManagerId from another executor on the same host
+    val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1)
+    val hostLocalDirs = Map("test-client-1" -> Array("local-dir"),
+      SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"))
+    val blocksByAddress = prepareForFallbackToLocalBlocks(blockManager, hostLocalDirs)
+
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+    // host-local read for a shuffle block
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer(
+      (_: InvocationOnMock) => {
+        Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1)),
+          (hostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: fetch host-local blocks with push-merged block during initialization " +
+    "and fall back to host-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    // BlockManagerId of another executor on the same host
+    val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1)
+    val originalHostLocalBmId = BlockManagerId("test-client-2", "test-local-host", 1)
+    val hostLocalDirs = Map(hostLocalBmId.executorId -> Array("local-dir"),
+      SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"),
+      originalHostLocalBmId.executorId -> Array("local-dir"))
+
+    val hostLocalBlocks = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (hostLocalBmId, Seq((ShuffleBlockId(0, 5, 2), 1L, 1))))
+
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, hostLocalDirs) ++ hostLocalBlocks
+
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+    // host-local read for this original shuffle block
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 5, 2), Array("local-dir"))
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer(
+      (_: InvocationOnMock) => {
+        Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1)),
+          (originalHostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 5, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    val (id5, _) = iterator.next()
+    assert(id5 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure while reading local shuffle chunks should fall back to original " +
+    "shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    // This will throw an IOException when an input stream is created from the ManagedBuffer
+    doReturn(Seq({
+      new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
+    })).when(blockManager).getLocalMergedBlockData(
+      ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: fall back to original shuffle block when a push-merged shuffle chunk " +
+    "is corrupt") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    val corruptBuffer = createMockManagedBuffer(2)
+    doReturn(Seq(corruptBuffer)).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    doReturn(corruptStream).when(corruptBuffer).createInputStream()
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: fall back to original blocks when failing to fetch a remote shuffle chunk") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val bitmaps = Array(new RoaringBitmap, new RoaringBitmap)
+    bitmaps(1).add(3)
+    bitmaps(1).add(4)
+    bitmaps(1).add(5)
+    val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, bitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (BlockManagerId("remote-client", "remote-host-2", 1),
+        toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+          ShuffleBlockId(0, 5, 2)), 4L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) ->
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID)))
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    // Chunk 1 fails to fetch, so the iterator falls back to the original blocks
+    val (id2, _) = iterator.next()
+    blocksSem.acquire(3)
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 4, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 5, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: fall back to original blocks when failing to parse remote merged block meta") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer()
+    )
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((BlockManagerId("remote-client-1", "remote-host-1", 1),
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, null)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val remoteMergedBlockMgrId = BlockManagerId(
+      SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteMergedBlockMgrId -> toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
+    "pending shuffle chunks immediately") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      // ShuffleBlockChunkId(0, 2, 1) will cause a failure as it is not in blockChunks.
+      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 3) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 6, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+
+    val metaSem = new Semaphore(0)
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(4)
+    when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap)
+    when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          logInfo(s"releasing semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.release()
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+        ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)), 1L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)),
+      maxBytesInFlight = 4)
+    metaSem.acquire(1)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    val regularBlocks = new mutable.HashSet[BlockId]()
+    val (id2, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id2)
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id3)
+    val (id4, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id4)
+    val (id5, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id5)
+    assert(!iterator.hasNext)
+    assert(regularBlocks === Set(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+      ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)))
+  }
+
+  test("SPARK-32922: failure to fetch a remote shuffle chunk immediately initiates the " +
+    "fallback of pending shuffle chunks, including deferred ones") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      // ShuffleBlockChunkId(0, 2, 3) will cause a failure as it is not in blockChunks.
+      ShuffleBlockChunkId(0, 2, 4) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 5) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 6, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val metaSem = new Semaphore(0)
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(6)
+    when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap)
+    when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          logInfo(s"releasing semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.release()
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+        ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)), 1L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)),
+      maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1)
+    metaSem.acquire(1)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockChunkId(0, 2, 1))
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id3 === ShuffleBlockChunkId(0, 2, 2))
+    val regularBlocks = new mutable.HashSet[BlockId]()
+    val (id4, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id4)
+    val (id5, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id5)
+    val (id6, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id6)
+    val (id7, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id7)
+    assert(!iterator.hasNext)
+    assert(regularBlocks === Set[ShuffleBlockId](ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+      ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)))
+  }
+
 }
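
Note for reviewers: every blocksByAddress entry in the tests above carries (BlockId, Long, Int) triples built by a toBlockList helper that this excerpt does not show (it lives elsewhere in the suite). The sketch below is a hypothetical reconstruction inferred purely from its call sites here, where one block size and one map index are applied to every id in the list; treat the exact signature as an assumption, not the PR's actual definition.

    // Hypothetical reconstruction (not part of this diff) of the suite's toBlockList
    // helper, inferred from call sites such as
    //   toBlockList(Seq(ShuffleBlockId(0, 0, 2)), 1L, 1)
    //   toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)
    private def toBlockList(
        blockIds: Seq[BlockId],
        blockSize: Long,
        blockMapIndex: Int): Seq[(BlockId, Long, Int)] = {
      // Every block in the list shares one size and one map index; push-merged
      // blocks are marked by passing SHUFFLE_PUSH_MAP_ID as the map index.
      blockIds.map(blockId => (blockId, blockSize, blockMapIndex))
    }

The addressing convention the tests exercise follows from the diff itself: a push-merged partition is requested as ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId) from a BlockManagerId whose executor id is SHUFFLE_MERGER_IDENTIFIER; once its meta (a chunk count plus one RoaringBitmap of map ids per chunk) arrives, the data is fetched as ShuffleBlockChunkId(shuffleId, reduceId, chunkId) blocks, and any failure at the meta, chunk, or data stage falls back to the original per-mapper ShuffleBlockIds supplied by MapOutputTracker.getMapSizesForMergeResult.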