diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index c3152d9225107..2074531bccf0a 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -367,6 +367,15 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging startPartition: Int, endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Get all map output statuses for the given shuffle id. This could be used by a custom shuffle + * manager to get map output information. For example, a remote shuffle service shuffle manager + * could use this method to get the information and figure out where the shuffle data is located. + * @param shuffleId The id of the shuffle stage whose map output statuses are requested. + * @return An array of all map status objects. + */ + def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -774,6 +783,18 @@ private[spark] class MapOutputTrackerMaster( } } + def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus] = { + logDebug(s"Fetching all output statuses for shuffle $shuffleId") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.checkMapStatuses(statuses, shuffleId) + statuses.clone + } + case None => Array.empty + } + } + override def stop(): Unit = { mapOutputRequests.offer(PoisonPill) threadpool.shutdown() @@ -827,6 +848,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } + override def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus] = { + logDebug(s"Fetching all output statuses for shuffle $shuffleId") + val statuses = getStatuses(shuffleId, conf) + MapOutputTracker.checkMapStatuses(statuses, shuffleId) + statuses + } + /** * Get or fetch the array of MapStatuses for a given shuffle ID. 
NOTE: clients MUST synchronize * on this array when reading it, because on the driver, we may be changing it in place. @@ -1011,4 +1039,15 @@ private[spark] object MapOutputTracker extends Logging { splitsByAddress.mapValues(_.toSeq).iterator } + + def checkMapStatuses(statuses: Array[MapStatus], shuffleId: Int): Unit = { + assert (statuses != null) + for (status <- statuses) { + if (status == null) { + val errorMessage = s"Missing an output location for shuffle $shuffleId" + logError(errorMessage) + throw new MetadataFetchFailedException(shuffleId, 0, errorMessage) + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala index cfc2e141290c4..29558c660f8ca 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.io.{Externalizable, ObjectInput, ObjectOutput, Serializable} import scala.collection.mutable @@ -52,6 +52,13 @@ private[spark] sealed trait MapStatus { * partitionId of the task or taskContext.taskAttemptId is used. */ def mapId: Long + + /** + * Extra metadata for map status. This could be used by different ShuffleManager implementations + * to store information they need. For example, a Remote Shuffle Service ShuffleManager could + * store shuffle server information and let the reducer task know where to fetch shuffle data. 
+ */ + def metadata: Option[Serializable] } @@ -76,6 +83,18 @@ private[spark] object MapStatus { } } + def apply( + loc: BlockManagerId, + uncompressedSizes: Array[Long], + mapTaskId: Long, + metadata: Option[Serializable]): MapStatus = { + if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) { + HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata) + } else { + new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata) + } + } + private[this] val LOG_BASE = 1.1 /** @@ -117,7 +136,8 @@ private[spark] object MapStatus { private[spark] class CompressedMapStatus( private[this] var loc: BlockManagerId, private[this] var compressedSizes: Array[Byte], - private[this] var _mapTaskId: Long) + private[this] var _mapTaskId: Long, + private[this] var _metadata: Option[Serializable] = None) extends MapStatus with Externalizable { // For deserialization only @@ -127,6 +147,11 @@ private[spark] class CompressedMapStatus( this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId) } + def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long, + metadata: Option[Serializable]) { + this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, metadata) + } + override def location: BlockManagerId = loc override def updateLocation(newLoc: BlockManagerId): Unit = { @@ -139,11 +164,19 @@ private[spark] class CompressedMapStatus( override def mapId: Long = _mapTaskId + override def metadata: Option[Serializable] = _metadata + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) out.writeInt(compressedSizes.length) out.write(compressedSizes) out.writeLong(_mapTaskId) + if (_metadata.isEmpty) { + out.writeBoolean(false) + } else { + out.writeBoolean(true) + out.writeObject(_metadata.get) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -152,6 +185,10 @@ private[spark] class CompressedMapStatus( compressedSizes = new 
Array[Byte](len) in.readFully(compressedSizes) _mapTaskId = in.readLong() + val hasMetadata = in.readBoolean() + if (hasMetadata) { + _metadata = Some(in.readObject().asInstanceOf[Serializable]) + } } } @@ -173,7 +210,8 @@ private[spark] class HighlyCompressedMapStatus private ( private[this] var emptyBlocks: RoaringBitmap, private[this] var avgSize: Long, private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte], - private[this] var _mapTaskId: Long) + private[this] var _mapTaskId: Long, + private[this] var _metadata: Option[Serializable] = None) extends MapStatus with Externalizable { // loc could be null when the default constructor is called during deserialization @@ -203,6 +241,8 @@ private[spark] class HighlyCompressedMapStatus private ( override def mapId: Long = _mapTaskId + override def metadata: Option[Serializable] = _metadata + override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { loc.writeExternal(out) emptyBlocks.serialize(out) @@ -213,6 +253,12 @@ private[spark] class HighlyCompressedMapStatus private ( out.writeByte(kv._2) } out.writeLong(_mapTaskId) + if (_metadata.isEmpty) { + out.writeBoolean(false) + } else { + out.writeBoolean(true) + out.writeObject(_metadata.get) + } } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { @@ -230,6 +276,10 @@ private[spark] class HighlyCompressedMapStatus private ( } hugeBlockSizes = hugeBlockSizesImpl _mapTaskId = in.readLong() + val hasMetadata = in.readBoolean() + if (hasMetadata) { + _metadata = Some(in.readObject().asInstanceOf[Serializable]) + } } } @@ -237,7 +287,8 @@ private[spark] object HighlyCompressedMapStatus { def apply( loc: BlockManagerId, uncompressedSizes: Array[Long], - mapTaskId: Long): HighlyCompressedMapStatus = { + mapTaskId: Long, + metadata: Option[Serializable] = None): HighlyCompressedMapStatus = { // We must keep track of which blocks are empty so that we don't report a zero-sized // block as being non-empty (or vice-versa) 
when using the average block size. var i = 0 @@ -278,6 +329,6 @@ private[spark] object HighlyCompressedMapStatus { emptyBlocks.trim() emptyBlocks.runOptimize() new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize, - hugeBlockSizes, mapTaskId) + hugeBlockSizes, mapTaskId, metadata) } } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b5b68f639ffc9..8a8b6f25e54c2 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -63,16 +63,37 @@ class MapOutputTrackerSuite extends SparkFunSuite { assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) - tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(1000L, 10000L), 5)) - tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(10000L, 1000L), 6)) + val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L, 10000L), 5) + val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L), 6) + tracker.registerMapOutput(10, 0, mapStatus1) + tracker.registerMapOutput(10, 1, mapStatus2) val statuses = tracker.getMapSizesByExecutorId(10, 0) assert(statuses.toSet === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))), (BlockManagerId("b", "hostB", 1000), ArrayBuffer((ShuffleBlockId(10, 6, 0), size10000, 1)))).toSet) + val allStatuses = tracker.getAllMapOutputStatuses(10) + assert(allStatuses.toSet === Set(mapStatus1, mapStatus2)) + assert(0 == tracker.getNumCachedSerializedBroadcast) + tracker.stop() + rpcEnv.shutdown() + } + + test("master register shuffle with map status metadata") { + val rpcEnv = createRpcEnv("test") + val tracker = 
newTrackerMaster() + tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf)) + tracker.registerShuffle(10, 2) + val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000), + Array(1000L, 10000L), 5, Some("metadata1")) + val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000), + Array(10000L, 1000L), 6, Some(1001)) + tracker.registerMapOutput(10, 0, mapStatus1) + tracker.registerMapOutput(10, 1, mapStatus2) + val allStatuses = tracker.getAllMapOutputStatuses(10) + assert(allStatuses.toSet === Set(mapStatus1, mapStatus2)) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.stop() rpcEnv.shutdown() @@ -92,10 +113,12 @@ class MapOutputTrackerSuite extends SparkFunSuite { Array(compressedSize10000, compressedSize1000), 6)) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty) + assert(tracker.getAllMapOutputStatuses(10).nonEmpty) assert(0 == tracker.getNumCachedSerializedBroadcast) tracker.unregisterShuffle(10) assert(!tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty) + assert(tracker.getAllMapOutputStatuses(11).isEmpty) tracker.stop() rpcEnv.shutdown() @@ -123,6 +146,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. 
intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) } + intercept[FetchFailedException] { tracker.getAllMapOutputStatuses(10) } tracker.stop() rpcEnv.shutdown() @@ -147,13 +171,18 @@ class MapOutputTrackerSuite extends SparkFunSuite { intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) } val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L), 5)) + val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5) + masterTracker.registerMapOutput(10, 0, mapStatus) mapWorkerTracker.updateEpoch(masterTracker.getEpoch) assert(mapWorkerTracker.getMapSizesByExecutorId(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))))) - assert(0 == masterTracker.getNumCachedSerializedBroadcast) + val allMapOutputStatuses = mapWorkerTracker.getAllMapOutputStatuses(10) + assert(allMapOutputStatuses.length === 1) + assert(allMapOutputStatuses(0).location === mapStatus.location) + assert(allMapOutputStatuses(0).getSizeForBlock(0) === mapStatus.getSizeForBlock(0)) + assert(allMapOutputStatuses(0).mapId === mapStatus.mapId) + assert(allMapOutputStatuses(0).metadata === mapStatus.metadata) val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000)) @@ -171,6 +200,36 @@ class MapOutputTrackerSuite extends SparkFunSuite { mapWorkerRpcEnv.shutdown() } + test("remote get all map output statuses with metadata") { + val hostname = "localhost" + val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf)) + + val masterTracker = newTrackerMaster() + masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, + new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) + + val mapWorkerRpcEnv = createRpcEnv("spark-worker", 
hostname, 0, new SecurityManager(conf)) + val mapWorkerTracker = new MapOutputTrackerWorker(conf) + mapWorkerTracker.trackerEndpoint = + mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + + masterTracker.registerShuffle(10, 1) + val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5, + Some("metadata1")) + masterTracker.registerMapOutput(10, 0, mapStatus) + val allMapOutputStatuses = mapWorkerTracker.getAllMapOutputStatuses(10) + assert(allMapOutputStatuses.length === 1) + assert(allMapOutputStatuses(0).location === mapStatus.location) + assert(allMapOutputStatuses(0).getSizeForBlock(0) === mapStatus.getSizeForBlock(0)) + assert(allMapOutputStatuses(0).mapId === mapStatus.mapId) + assert(allMapOutputStatuses(0).metadata === mapStatus.metadata) + + masterTracker.stop() + mapWorkerTracker.stop() + rpcEnv.shutdown() + mapWorkerRpcEnv.shutdown() + } + test("remote fetch below max RPC message size") { val newConf = new SparkConf newConf.set(RPC_MESSAGE_MAX_SIZE, 1) @@ -312,10 +371,12 @@ class MapOutputTrackerSuite extends SparkFunSuite { val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L)) - tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000), - Array(size0, size1000, size0, size10000), 5)) - tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000), - Array(size10000, size0, size1000, size0), 6)) + val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000), + Array(size0, size1000, size0, size10000), 5) + val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000), + Array(size10000, size0, size1000, size0), 6) + tracker.registerMapOutput(10, 0, mapStatus1) + tracker.registerMapOutput(10, 1, mapStatus2) assert(tracker.containsShuffle(10)) assert(tracker.getMapSizesByExecutorId(10, 0, 2, 0, 4).toSeq === Seq( @@ 
-327,6 +388,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { (ShuffleBlockId(10, 6, 2), size1000, 1))) ) ) + assert(tracker.getAllMapOutputStatuses(10).toSet === Set(mapStatus1, mapStatus2)) tracker.unregisterShuffle(10) tracker.stop()