diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index cdec1982b4487..3f688c1e776c8 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
+import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, Serializable}
 import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
 import java.util.concurrent.locks.ReentrantReadWriteLock
 
@@ -367,6 +367,16 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
       startPartition: Int,
       endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
+  /**
+   * Get all map output status metadata for the given shuffle id. This can be used by a custom
+   * shuffle manager to obtain map output information. For example, a remote shuffle service
+   * shuffle manager could use this method to determine where the shuffle data for a given
+   * stage is located.
+   * @param shuffleId the id of the shuffle to look up
+   * @return an array of the metadata objects of all registered map statuses
+   */
+  def getAllMapOutputStatusMetadata(shuffleId: Int): Array[Serializable]
+
   /**
    * Deletes map output status information for the specified shuffle stage.
    */
@@ -774,6 +784,18 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
+  override def getAllMapOutputStatusMetadata(shuffleId: Int): Array[Serializable] = {
+    logDebug(s"Fetching all output status metadata for shuffle $shuffleId")
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.withMapStatuses { statuses =>
+          MapOutputTracker.checkMapStatuses(statuses, shuffleId)
+          statuses.flatMap(_.metadata)
+        }
+      case None => Array.empty
+    }
+  }
+
   override def stop(): Unit = {
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
@@ -827,6 +849,20 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }
 
+  override def getAllMapOutputStatusMetadata(shuffleId: Int): Array[Serializable] = {
+    logDebug(s"Fetching all output status metadata for shuffle $shuffleId")
+    val statuses = getStatuses(shuffleId, conf)
+    try {
+      MapOutputTracker.checkMapStatuses(statuses, shuffleId)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
+    statuses.flatMap(_.metadata)
+  }
+
   /**
    * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
    * on this array when reading it, because on the driver, we may be changing it in place.
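A minimal sketch of how a caller might consume the new API (not part of the patch: `RemoteShuffleInfo` and `shuffleServersFor` are hypothetical names, and only `getAllMapOutputStatusMetadata` comes from this change; note that `MapOutputTracker` is private[spark], so a custom ShuffleManager would invoke this from within Spark's package space):

    import org.apache.spark.SparkEnv

    // Hypothetical metadata type a remote shuffle service might store per map task.
    case class RemoteShuffleInfo(host: String, port: Int)

    // Collect the distinct shuffle servers holding output for `shuffleId`, so a
    // reducer can open readers against them instead of executor block managers.
    def shuffleServersFor(shuffleId: Int): Seq[RemoteShuffleInfo] = {
      SparkEnv.get.mapOutputTracker
        .getAllMapOutputStatusMetadata(shuffleId)
        .collect { case info: RemoteShuffleInfo => info }
        .distinct
        .toSeq
    }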
@@ -995,9 +1031,7 @@ private[spark] object MapOutputTracker extends Logging {
     val iter = statuses.iterator.zipWithIndex
     for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
       if (status == null) {
-        val errorMessage = s"Missing an output location for shuffle $shuffleId"
-        logError(errorMessage)
-        throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
+        throwMetadataFetchFailedException(shuffleId, startPartition)
       } else {
         for (part <- startPartition until endPartition) {
           val size = status.getSizeForBlock(part)
@@ -1011,4 +1045,19 @@
     splitsByAddress.mapValues(_.toSeq).iterator
   }
+
+  /**
+   * Checks that no entry in `statuses` is missing, throwing a MetadataFetchFailedException
+   * if any map output is unavailable.
+   */
+  def checkMapStatuses(statuses: Array[MapStatus], shuffleId: Int): Unit = {
+    assert(statuses != null)
+    for (status <- statuses) {
+      if (status == null) {
+        throwMetadataFetchFailedException(shuffleId, 0)
+      }
+    }
+  }
+
+  private def throwMetadataFetchFailedException(shuffleId: Int, startPartition: Int): Unit = {
+    val errorMessage = s"Missing an output location for shuffle $shuffleId"
+    logError(errorMessage)
+    throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 1239c32cee3ab..c42a52a9ff057 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.scheduler
 
-import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import java.io.{Externalizable, ObjectInput, ObjectOutput, Serializable}
 
 import scala.collection.mutable
 
@@ -52,6 +52,13 @@ private[spark] sealed trait MapStatus {
    * partitionId of the task or taskContext.taskAttemptId is used.
    */
  def mapId: Long
+
+  /**
+   * Extra metadata for a map status. This can be used by ShuffleManager implementations to
+   * store information they need. For example, a remote shuffle service ShuffleManager could
+   * store shuffle server information to let reducer tasks know where to fetch shuffle data.
+   */
+  def metadata: Option[Serializable]
 }
@@ -76,6 +83,18 @@ private[spark] object MapStatus {
     }
   }
 
+  def apply(
+      loc: BlockManagerId,
+      uncompressedSizes: Array[Long],
+      mapTaskId: Long,
+      metadata: Option[Serializable]): MapStatus = {
+    if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
+      HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
+    } else {
+      new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
+    }
+  }
+
   private[this] val LOG_BASE = 1.1
 
   /**
@@ -117,7 +136,8 @@ private[spark] class CompressedMapStatus(
     private[this] var loc: BlockManagerId,
     private[this] var compressedSizes: Array[Byte],
-    private[this] var _mapTaskId: Long)
+    private[this] var _mapTaskId: Long,
+    private[this] var _metadata: Option[Serializable] = None)
   extends MapStatus with Externalizable {
 
   // For deserialization only
@@ -127,6 +147,11 @@ private[spark] class CompressedMapStatus(
     this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
   }
 
+  def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long,
+      metadata: Option[Serializable]) {
+    this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, metadata)
+  }
+
   override def location: BlockManagerId = loc
 
   override def updateLocation(newLoc: BlockManagerId): Unit = {
@@ -139,11 +164,15 @@ private[spark] class CompressedMapStatus(
 
   override def mapId: Long = _mapTaskId
 
+  override def metadata: Option[Serializable] = _metadata
+
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
     out.writeInt(compressedSizes.length)
     out.write(compressedSizes)
     out.writeLong(_mapTaskId)
+    out.writeBoolean(_metadata.isDefined)
+    _metadata.foreach(out.writeObject)
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -152,6 +181,10 @@ private[spark] class CompressedMapStatus(
     compressedSizes = new Array[Byte](len)
     in.readFully(compressedSizes)
     _mapTaskId = in.readLong()
+    val hasMetadata = in.readBoolean()
+    if (hasMetadata) {
+      _metadata = Some(in.readObject().asInstanceOf[Serializable])
+    }
   }
 }
 
@@ -173,7 +206,8 @@ private[spark] class HighlyCompressedMapStatus private (
     private[this] var emptyBlocks: RoaringBitmap,
     private[this] var avgSize: Long,
     private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
-    private[this] var _mapTaskId: Long)
+    private[this] var _mapTaskId: Long,
+    private[this] var _metadata: Option[Serializable] = None)
   extends MapStatus with Externalizable {
 
   // loc could be null when the default constructor is called during deserialization
@@ -203,6 +237,8 @@ private[spark] class HighlyCompressedMapStatus private (
 
   override def mapId: Long = _mapTaskId
 
+  override def metadata: Option[Serializable] = _metadata
+
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     loc.writeExternal(out)
     emptyBlocks.serialize(out)
@@ -213,6 +249,8 @@ private[spark] class HighlyCompressedMapStatus private (
       out.writeByte(kv._2)
     }
     out.writeLong(_mapTaskId)
+    out.writeBoolean(_metadata.isDefined)
+    _metadata.foreach(out.writeObject)
   }
 
   override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -230,6 +268,10 @@ private[spark] class HighlyCompressedMapStatus private (
     }
     hugeBlockSizes = hugeBlockSizesImpl
     _mapTaskId = in.readLong()
+    val hasMetadata = in.readBoolean()
+    if (hasMetadata) {
+      _metadata = Some(in.readObject().asInstanceOf[Serializable])
+    }
   }
 }
 
@@ -237,7 +279,8 @@ private[spark] object HighlyCompressedMapStatus {
   def apply(
       loc: BlockManagerId,
       uncompressedSizes: Array[Long],
-      mapTaskId: Long): HighlyCompressedMapStatus = {
+      mapTaskId: Long,
+      metadata: Option[Serializable] = None): HighlyCompressedMapStatus = {
     // We must keep track of which blocks are empty so that we don't report a zero-sized
     // block as being non-empty (or vice-versa) when using the average block size.
     var i = 0
@@ -278,6 +321,6 @@ private[spark] object HighlyCompressedMapStatus {
     emptyBlocks.trim()
     emptyBlocks.runOptimize()
     new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
-      hugeBlockSizes, mapTaskId)
+      hugeBlockSizes, mapTaskId, metadata)
   }
 }
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 20b040f7c810d..43f1d8de59008 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -62,10 +62,10 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     assert(tracker.containsShuffle(10))
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
-    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-      Array(1000L, 10000L), 5))
-    tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
-      Array(10000L, 1000L), 6))
+    val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L, 10000L), 5)
+    val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L), 6)
+    tracker.registerMapOutput(10, 0, mapStatus1)
+    tracker.registerMapOutput(10, 1, mapStatus2)
     val statuses = tracker.getMapSizesByExecutorId(10, 0)
     assert(statuses.toSet ===
       Seq((BlockManagerId("a", "hostA", 1000),
@@ -73,6 +73,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       (BlockManagerId("b", "hostB", 1000),
         ArrayBuffer((ShuffleBlockId(10, 6, 0), size10000, 1)))).toSet)
     assert(0 == tracker.getNumCachedSerializedBroadcast)
+    val allStatusMetadata = tracker.getAllMapOutputStatusMetadata(10)
+    assert(0 == allStatusMetadata.size)
+    tracker.stop()
+    rpcEnv.shutdown()
+  }
+
+  test("master register shuffle with map status metadata") {
+    val rpcEnv = createRpcEnv("test")
+    val tracker = newTrackerMaster()
+    tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
+    tracker.registerShuffle(10, 2)
+    val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(1000L, 10000L), 5, Some("metadata1"))
+    val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(10000L, 1000L), 6, Some(1001))
+    tracker.registerMapOutput(10, 0, mapStatus1)
+    tracker.registerMapOutput(10, 1, mapStatus2)
+    assert(0 == tracker.getNumCachedSerializedBroadcast)
+    val allStatusMetadata = tracker.getAllMapOutputStatusMetadata(10)
+    assert(allStatusMetadata === Array(mapStatus1.metadata.get, mapStatus2.metadata.get))
     tracker.stop()
     rpcEnv.shutdown()
   }
@@ -92,9 +113,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     assert(tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
     assert(0 == tracker.getNumCachedSerializedBroadcast)
+    assert(0 == tracker.getAllMapOutputStatusMetadata(10).size)
     tracker.unregisterShuffle(10)
     assert(!tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
+    assert(0 == tracker.getAllMapOutputStatusMetadata(10).size)
     tracker.stop()
     rpcEnv.shutdown()
@@ -122,6 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     // this should cause it to fail, and the scheduler will ignore the failure due to the
     // stage already being aborted.
     intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) }
+    intercept[FetchFailedException] { tracker.getAllMapOutputStatusMetadata(10) }
 
     tracker.stop()
     rpcEnv.shutdown()
@@ -146,13 +170,15 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) }
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
-    masterTracker.registerMapOutput(10, 0, MapStatus(
-      BlockManagerId("a", "hostA", 1000), Array(1000L), 5))
+    val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5)
+    masterTracker.registerMapOutput(10, 0, mapStatus)
     mapWorkerTracker.updateEpoch(masterTracker.getEpoch)
     assert(mapWorkerTracker.getMapSizesByExecutorId(10, 0).toSeq ===
       Seq((BlockManagerId("a", "hostA", 1000),
         ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0)))))
     assert(0 == masterTracker.getNumCachedSerializedBroadcast)
+    val allMapOutputStatusMetadata = mapWorkerTracker.getAllMapOutputStatusMetadata(10)
+    assert(0 == allMapOutputStatusMetadata.size)
 
     val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -170,6 +196,33 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     mapWorkerRpcEnv.shutdown()
   }
 
+  test("remote get all map output statuses with metadata") {
+    val hostname = "localhost"
+    val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))
+
+    val masterTracker = newTrackerMaster()
+    masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))
+
+    val mapWorkerRpcEnv = createRpcEnv("spark-worker", hostname, 0, new SecurityManager(conf))
+    val mapWorkerTracker = new MapOutputTrackerWorker(conf)
+    mapWorkerTracker.trackerEndpoint =
+      mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
+
+    masterTracker.registerShuffle(10, 1)
+    val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5,
+      Some("metadata1"))
+    masterTracker.registerMapOutput(10, 0, mapStatus)
+    val allMapOutputStatusMetadata = mapWorkerTracker.getAllMapOutputStatusMetadata(10)
+    assert(allMapOutputStatusMetadata.size === 1)
+    assert(allMapOutputStatusMetadata(0) === mapStatus.metadata.get)
+
+    masterTracker.stop()
+    mapWorkerTracker.stop()
+    rpcEnv.shutdown()
+    mapWorkerRpcEnv.shutdown()
+  }
+
   test("remote fetch below max RPC message size") {
     val newConf = new SparkConf
     newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
@@ -311,10 +364,12 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L))
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
-    tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
-      Array(size0, size1000, size0, size10000), 5))
-    tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
-      Array(size10000, size0, size1000, size0), 6))
+    val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000),
+      Array(size0, size1000, size0, size10000), 5)
+    val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000),
+      Array(size10000, size0, size1000, size0), 6)
+    tracker.registerMapOutput(10, 0, mapStatus1)
+    tracker.registerMapOutput(10, 1, mapStatus2)
     assert(tracker.containsShuffle(10))
     assert(tracker.getMapSizesByExecutorId(10, 0, 2, 0, 4).toSeq ===
       Seq(
            (ShuffleBlockId(10, 6, 2), size1000, 1)))
       )
     )
+    assert(0 == tracker.getAllMapOutputStatusMetadata(10).size)
     tracker.unregisterShuffle(10)
 
     tracker.stop()
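On the map side, a sketch of how a writer could attach metadata when reporting its output, using the new `MapStatus.apply` overload added above (the values are illustrative; `MapStatus` is private[spark], so this only compiles inside Spark's package space, and any java.io.Serializable value works as metadata):

    import org.apache.spark.scheduler.MapStatus
    import org.apache.spark.storage.BlockManagerId

    // Per-partition byte counts produced by this map task.
    val partitionLengths: Array[Long] = Array(1000L, 0L, 10000L)

    // A remote shuffle service would typically record where it stored this
    // task's output; a plain String is used here for illustration.
    val serverInfo: java.io.Serializable = "shuffle-server-1:7337"

    // The metadata travels inside the MapStatus (via writeExternal/readExternal)
    // and is surfaced to reducers through getAllMapOutputStatusMetadata.
    val status = MapStatus(
      BlockManagerId("exec-1", "hostA", 7337),
      partitionLengths,
      mapTaskId = 5L,
      metadata = Some(serverInfo))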