39 changes: 39 additions & 0 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -367,6 +375,15 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

/**
* Get all map output statuses for the given shuffle id. This can be used by a custom shuffle
* manager to obtain map output information. For example, a remote shuffle service shuffle
* manager could use this method to figure out where the shuffle data is located.
* @param shuffleId the shuffle id
* @return an array of all map status objects.
*/
def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus]
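
To illustrate the intended use of getAllMapOutputStatuses, a minimal sketch of a custom shuffle reader that asks the tracker where shuffle data lives; the RemoteShuffleReader class and its wiring are assumptions for illustration, not part of this PR:

    // Hypothetical reader side of a remote shuffle service integration.
    private[spark] class RemoteShuffleReader(
        shuffleId: Int,
        mapOutputTracker: MapOutputTracker) {

      // Collect the distinct hosts advertised by every map task of this shuffle.
      def shuffleServerHosts(): Seq[String] = {
        val statuses = mapOutputTracker.getAllMapOutputStatuses(shuffleId)
        statuses.map(_.location.host).distinct.toSeq
      }
    }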

/**
* Deletes map output status information for the specified shuffle stage.
*/
@@ -774,6 +783,18 @@ private[spark] class MapOutputTrackerMaster(
}
}

def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus] = {
logDebug(s"Fetching all output statuses for shuffle $shuffleId")
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
MapOutputTracker.checkMapStatuses(statuses, shuffleId)
statuses.clone
Contributor:

I see your intention in calling clone here, but I do not think it is enough.
MapStatus is a trait, not a case class, and its implementations are mutable with a lot of var fields.

The clone on the Array is not a deep copy.

Author:

Yes, I got your point. How about changing this method to getAllMapOutputStatusMetadata so it only returns the metadata?

Contributor:

That could still not be enough if the metadata itself can mutate. But as I see it, we can solve all of these problems easily with immutable metadata. So to be on the safe side, please document that we require the metadata to be immutable, and introduce an updateMetadata(meta: Option[Serializable]) method in MapStatus. Then we will be safe and all the use cases are covered.

(And you can use a case class for Uber RSS's MapTaskRssInfo.)
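
A minimal sketch of what the suggested immutable metadata could look like; the MapTaskRssInfo fields and the updateMetadata signature are assumptions drawn from this discussion, not code in the PR:

    // Immutable, serializable metadata attached to a MapStatus (fields are illustrative).
    case class MapTaskRssInfo(
        rssServerHost: String,
        rssServerPort: Int,
        numRssServers: Int) extends Serializable

    // Possible trait addition per the suggestion above (hypothetical signature):
    // def updateMetadata(meta: Option[Serializable]): Unit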

}
case None => Array.empty
}
}

override def stop(): Unit = {
mapOutputRequests.offer(PoisonPill)
threadpool.shutdown()
@@ -827,6 +848,13 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
}
}

override def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus] = {
logDebug(s"Fetching all output statuses for shuffle $shuffleId")
val statuses = getStatuses(shuffleId, conf)
Contributor (@attilapiros, Mar 6, 2021):

Please clear the mapStatuses in case of MetadataFetchFailedException!

Reasoning:
Before this PR, the getStatuses method was only used in getMapSizesByExecutorId, where MetadataFetchFailedException (the case when a missing output location is detected) is handled by clearing the mapStatuses cache, since it is probably outdated.

I am sure the clearing would not be missed if it were done where that exception is thrown.
Could you please check whether it can be moved there?

Contributor:

Now I see why you cannot move the clearing there!
Still, the clearing itself needs to be done.

Author:

Yes, we only need to clear mapStatuses in MapOutputTrackerWorker; will add that.

MapOutputTracker.checkMapStatuses(statuses, shuffleId)
statuses
}
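
Following the request above to clear the worker's cache on failure, a minimal sketch of one possible shape for the worker-side method, mirroring how getMapSizesByExecutorId handles the same failure; this is an illustration of the suggested fix, not the PR's final code:

    override def getAllMapOutputStatuses(shuffleId: Int): Array[MapStatus] = {
      logDebug(s"Fetching all output statuses for shuffle $shuffleId")
      val statuses = getStatuses(shuffleId, conf)
      try {
        MapOutputTracker.checkMapStatuses(statuses, shuffleId)
        statuses
      } catch {
        case e: MetadataFetchFailedException =>
          // The cached statuses are likely stale; drop them so the next call refetches.
          mapStatuses.clear()
          throw e
      }
    }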

/**
* Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
* on this array when reading it, because on the driver, we may be changing it in place.
@@ -1011,4 +1039,15 @@ private[spark] object MapOutputTracker extends Logging {

splitsByAddress.mapValues(_.toSeq).iterator
}

def checkMapStatuses(statuses: Array[MapStatus], shuffleId: Int): Unit = {
assert(statuses != null)
for (status <- statuses) {
if (status == null) {
Contributor:

I think you can extract this if into a new method and reuse that method in convertMapStatuses.

Author:

Yes, good suggestion!

val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
throw new MetadataFetchFailedException(shuffleId, 0, errorMessage)
}
}
}
}
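
A minimal sketch of the extraction suggested in the review above, so the null check can be shared with convertMapStatuses; the helper name validateStatus is an assumption:

    private def validateStatus(status: MapStatus, shuffleId: Int, partitionId: Int): Unit = {
      if (status == null) {
        val errorMessage = s"Missing an output location for shuffle $shuffleId partition $partitionId"
        logError(errorMessage)
        throw new MetadataFetchFailedException(shuffleId, partitionId, errorMessage)
      }
    }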
61 changes: 56 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -17,7 +17,7 @@

package org.apache.spark.scheduler

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.io.{Externalizable, ObjectInput, ObjectOutput, Serializable}

import scala.collection.mutable

@@ -52,6 +59,13 @@ private[spark] sealed trait MapStatus {
* partitionId of the task or taskContext.taskAttemptId is used.
*/
def mapId: Long

/**
* Extra metadata for the map status. This can be used by different ShuffleManager
* implementations to store information they need. For example, a Remote Shuffle Service
* ShuffleManager could store shuffle server information and let reducer tasks know where to
* fetch shuffle data.
*/
def metadata: Option[Serializable]
}
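
For illustration, a sketch of how a map task could attach such metadata when reporting its output, using the new four-argument MapStatus factory added below; partitionLengths, the TaskContext variable, the port numbers, and the MapTaskRssInfo case class (sketched earlier in this review) are all assumptions:

    // Hypothetical map-side call in a remote-shuffle ShuffleWriter.
    val status = MapStatus(
      BlockManagerId("exec-1", "host-1", 7337),
      partitionLengths,               // Array[Long] of per-partition output sizes
      context.taskAttemptId(),        // used as mapTaskId
      Some(MapTaskRssInfo("rss-host-1", 19959, numRssServers = 1)))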


@@ -76,6 +83,18 @@ private[spark] object MapStatus {
}
}

def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long,
metadata: Option[Serializable]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
} else {
new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
}
}

private[this] val LOG_BASE = 1.1

/**
@@ -117,7 +136,8 @@ private[spark] object MapStatus {
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
private[this] var compressedSizes: Array[Byte],
private[this] var _mapTaskId: Long)
private[this] var _mapTaskId: Long,
private[this] var _metadata: Option[Serializable] = None)
extends MapStatus with Externalizable {

// For deserialization only
@@ -127,6 +147,11 @@ private[spark] class CompressedMapStatus(
this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
}

def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long,
metadata: Option[Serializable]) {
this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, metadata)
}

override def location: BlockManagerId = loc

override def updateLocation(newLoc: BlockManagerId): Unit = {
@@ -139,11 +164,19 @@

override def mapId: Long = _mapTaskId

override def metadata: Option[Serializable] = _metadata

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
out.writeLong(_mapTaskId)
if (_metadata.isEmpty) {
Contributor:

Nit:

    out.writeBoolean(_metadata.isDefined)
    _metadata.foreach(out.writeObject)

Author:

Nice suggestion!

out.writeBoolean(false)
} else {
out.writeBoolean(true)
out.writeObject(_metadata.get)
}
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
Expand All @@ -152,6 +185,10 @@ private[spark] class CompressedMapStatus(
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
_mapTaskId = in.readLong()
val hasMetadata = in.readBoolean()
if (hasMetadata) {
_metadata = Some(in.readObject().asInstanceOf[Serializable])
}
}
}

@@ -173,7 +210,8 @@ private[spark] class HighlyCompressedMapStatus private (
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
private[this] var _mapTaskId: Long)
private[this] var _mapTaskId: Long,
private[this] var _metadata: Option[Serializable] = None)
extends MapStatus with Externalizable {

// loc could be null when the default constructor is called during deserialization
@@ -203,6 +241,8 @@

override def mapId: Long = _mapTaskId

override def metadata: Option[Serializable] = _metadata

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
emptyBlocks.serialize(out)
@@ -213,6 +253,12 @@
out.writeByte(kv._2)
}
out.writeLong(_mapTaskId)
if (_metadata.isEmpty) {
Contributor:

Nit:

    out.writeBoolean(_metadata.isDefined)
    _metadata.foreach(out.writeObject)

Author:

Nice suggestion!

out.writeBoolean(false)
} else {
out.writeBoolean(true)
out.writeObject(_metadata.get)
}
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -230,14 +276,19 @@
}
hugeBlockSizes = hugeBlockSizesImpl
_mapTaskId = in.readLong()
val hasMetadata = in.readBoolean()
if (hasMetadata) {
_metadata = Some(in.readObject().asInstanceOf[Serializable])
}
}
}

private[spark] object HighlyCompressedMapStatus {
def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long): HighlyCompressedMapStatus = {
mapTaskId: Long,
metadata: Option[Serializable] = None): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -278,6 +329,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
hugeBlockSizes, mapTaskId)
hugeBlockSizes, mapTaskId, metadata)
}
}
84 changes: 73 additions & 11 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -63,16 +63,37 @@ class MapOutputTrackerSuite extends SparkFunSuite {
assert(tracker.containsShuffle(10))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(1000L, 10000L), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L), 6))
val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L, 10000L), 5)
val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L), 6)
tracker.registerMapOutput(10, 0, mapStatus1)
tracker.registerMapOutput(10, 1, mapStatus2)
val statuses = tracker.getMapSizesByExecutorId(10, 0)
assert(statuses.toSet ===
Seq((BlockManagerId("a", "hostA", 1000),
ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))),
(BlockManagerId("b", "hostB", 1000),
ArrayBuffer((ShuffleBlockId(10, 6, 0), size10000, 1)))).toSet)
val allStatuses = tracker.getAllMapOutputStatuses(10)
assert(allStatuses.toSet === Set(mapStatus1, mapStatus2))
Contributor:

With toSet you lose the ordering, but ordering can be important (it is the mapIndex), so it should be tested.

Author:

Good suggestion; will use an Array to check the sequence as well.

assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.stop()
rpcEnv.shutdown()
}

test("master register shuffle with map status metadata") {
val rpcEnv = createRpcEnv("test")
val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000),
Array(1000L, 10000L), 5, Some("metadata1"))
val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L), 6, Some(1001))
tracker.registerMapOutput(10, 0, mapStatus1)
tracker.registerMapOutput(10, 1, mapStatus2)
val allStatuses = tracker.getAllMapOutputStatuses(10)
assert(allStatuses.toSet === Set(mapStatus1, mapStatus2))
Contributor (@attilapiros, Mar 6, 2021):

Ditto.

assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.stop()
rpcEnv.shutdown()
@@ -92,10 +113,12 @@ class MapOutputTrackerSuite extends SparkFunSuite {
Array(compressedSize10000, compressedSize1000), 6))
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
assert(tracker.getAllMapOutputStatuses(10).nonEmpty)
assert(0 == tracker.getNumCachedSerializedBroadcast)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
assert(tracker.getAllMapOutputStatuses(11).isEmpty)

tracker.stop()
rpcEnv.shutdown()
@@ -123,6 +146,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) }
intercept[FetchFailedException] { tracker.getAllMapOutputStatuses(10) }

tracker.stop()
rpcEnv.shutdown()
@@ -147,13 +171,18 @@ class MapOutputTrackerSuite extends SparkFunSuite {
intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) }

val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("a", "hostA", 1000), Array(1000L), 5))
val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5)
masterTracker.registerMapOutput(10, 0, mapStatus)
mapWorkerTracker.updateEpoch(masterTracker.getEpoch)
assert(mapWorkerTracker.getMapSizesByExecutorId(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000),
ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0)))))
assert(0 == masterTracker.getNumCachedSerializedBroadcast)
Contributor:

Please put the assert back:

    assert(0 == masterTracker.getNumCachedSerializedBroadcast)

val allMapOutputStatuses = mapWorkerTracker.getAllMapOutputStatuses(10)
assert(allMapOutputStatuses.length === 1)
assert(allMapOutputStatuses(0).location === mapStatus.location)
Contributor:

I would prefer to have an equals method on MapStatus and use equals here, because then, when MapStatus is extended with a new field, this test would also validate the serialization / deserialization of that new field.

Author:

In response to one of the previous comments, I am suggesting returning only the metadata instead of the whole map status object. Will revisit here after that discussion.

assert(allMapOutputStatuses(0).getSizeForBlock(0) === mapStatus.getSizeForBlock(0))
assert(allMapOutputStatuses(0).mapId === mapStatus.mapId)
assert(allMapOutputStatuses(0).metadata === mapStatus.metadata)

val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -171,6 +200,36 @@ class MapOutputTrackerSuite extends SparkFunSuite {
mapWorkerRpcEnv.shutdown()
}

test("remote get all map output statuses with metadata") {
Contributor:

I am thinking about extending the remote fetch test (the one before this) with an extra registered map output that has metadata, and then this test could be deleted.

That way you would test the case where one map status has metadata and one does not.

WDYT?

Author:

The previous test has masterTracker.unregisterMapOutput and some verification for that, so I want to avoid adding too much to it. Also, this test specifically exercises a non-null metadata object, following "separation of concerns" by keeping it as a separate test.

Contributor:

Fine for me.

val hostname = "localhost"
val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))

val masterTracker = newTrackerMaster()
masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))

val mapWorkerRpcEnv = createRpcEnv("spark-worker", hostname, 0, new SecurityManager(conf))
val mapWorkerTracker = new MapOutputTrackerWorker(conf)
mapWorkerTracker.trackerEndpoint =
mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)

masterTracker.registerShuffle(10, 1)
val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5,
Some("metadata1"))
masterTracker.registerMapOutput(10, 0, mapStatus)
val allMapOutputStatuses = mapWorkerTracker.getAllMapOutputStatuses(10)
assert(allMapOutputStatuses.length === 1)
assert(allMapOutputStatuses(0).location === mapStatus.location)
assert(allMapOutputStatuses(0).getSizeForBlock(0) === mapStatus.getSizeForBlock(0))
assert(allMapOutputStatuses(0).mapId === mapStatus.mapId)
assert(allMapOutputStatuses(0).metadata === mapStatus.metadata)

masterTracker.stop()
mapWorkerTracker.stop()
rpcEnv.shutdown()
mapWorkerRpcEnv.shutdown()
}

test("remote fetch below max RPC message size") {
val newConf = new SparkConf
newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
@@ -312,10 +371,12 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(size0, size1000, size0, size10000), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(size10000, size0, size1000, size0), 6))
val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000),
Array(size0, size1000, size0, size10000), 5)
val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000),
Array(size10000, size0, size1000, size0), 6)
tracker.registerMapOutput(10, 0, mapStatus1)
tracker.registerMapOutput(10, 1, mapStatus2)
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0, 2, 0, 4).toSeq ===
Seq(
@@ -327,6 +388,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
(ShuffleBlockId(10, 6, 2), size1000, 1)))
)
)
assert(tracker.getAllMapOutputStatuses(10).toSet === Set(mapStatus1, mapStatus2))
Contributor:

Order check.

Author:

Good catch!


tracker.unregisterShuffle(10)
tracker.stop()