57 changes: 53 additions & 4 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -17,7 +17,7 @@

package org.apache.spark

import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream, Serializable}
import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.locks.ReentrantReadWriteLock

@@ -367,6 +367,16 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
startPartition: Int,
endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]

/**
* Get all map output status metadata for the given shuffle id. This can be used by a custom
* shuffle manager to retrieve map output information. For example, a remote shuffle service
* shuffle manager could use this method to figure out where the shuffle data is located.
* @param shuffleId the shuffle id whose map output metadata should be returned
* @return An array of all map status metadata objects.
*/
def getAllMapOutputStatusMetadata(shuffleId: Int): Array[Serializable]

/**
* Deletes map output status information for the specified shuffle stage.
*/
@@ -774,6 +784,18 @@ private[spark] class MapOutputTrackerMaster(
}
}

def getAllMapOutputStatusMetadata(shuffleId: Int): Array[Serializable] = {
logDebug(s"Fetching all output status metadata for shuffle $shuffleId")
shuffleStatuses.get(shuffleId) match {
case Some(shuffleStatus) =>
shuffleStatus.withMapStatuses { statuses =>
MapOutputTracker.checkMapStatuses(statuses, shuffleId)
statuses.flatMap(_.metadata)
}
case None => Array.empty
}
}

override def stop(): Unit = {
mapOutputRequests.offer(PoisonPill)
threadpool.shutdown()
@@ -827,6 +849,20 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
}
}

override def getAllMapOutputStatusMetadata(shuffleId: Int): Array[Serializable] = {
logDebug(s"Fetching all output status metadata for shuffle $shuffleId")
val statuses = getStatuses(shuffleId, conf)
try {
MapOutputTracker.checkMapStatuses(statuses, shuffleId)
} catch {
case e: MetadataFetchFailedException =>
// We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
mapStatuses.clear()
throw e
}
statuses.flatMap(_.metadata)
}

/**
* Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
* on this array when reading it, because on the driver, we may be changing it in place.
@@ -995,9 +1031,7 @@ private[spark] object MapOutputTracker extends Logging {
val iter = statuses.iterator.zipWithIndex
for ((status, mapIndex) <- iter.slice(startMapIndex, endMapIndex)) {
if (status == null) {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
throwMetadataFetchFailedException(shuffleId, startPartition)
} else {
for (part <- startPartition until endPartition) {
val size = status.getSizeForBlock(part)
@@ -1011,4 +1045,19 @@

splitsByAddress.mapValues(_.toSeq).iterator
}

def checkMapStatuses(statuses: Array[MapStatus], shuffleId: Int): Unit = {
assert (statuses != null)
for (status <- statuses) {
if (status == null) {
throwMetadataFetchFailedException(shuffleId, 0)
}
}
}

private def throwMetadataFetchFailedException(shuffleId: Int, startPartition: Int): Unit = {
val errorMessage = s"Missing an output location for shuffle $shuffleId"
logError(errorMessage)
throw new MetadataFetchFailedException(shuffleId, startPartition, errorMessage)
}
}
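
To make the intent of the new getAllMapOutputStatusMetadata API concrete, here is a minimal, hypothetical sketch of how the reader side of a remote-shuffle ShuffleManager might use it to discover where shuffle data lives. RemoteShuffleServerInfo and shuffleServersFor are illustrative names invented for this sketch (they are not part of the PR), and the code assumes it is compiled under the org.apache.spark package, since MapOutputTracker is private[spark].

package org.apache.spark.shuffle

import java.io.Serializable

import org.apache.spark.{MapOutputTracker, SparkEnv}

// Hypothetical per-map-task metadata that a remote shuffle service might store in MapStatus.
case class RemoteShuffleServerInfo(host: String, port: Int) extends Serializable

object RemoteShuffleReaderExample {
  /** Collect the distinct remote shuffle servers that hold output for the given shuffle. */
  def shuffleServersFor(shuffleId: Int): Set[RemoteShuffleServerInfo] = {
    // Works on both the driver (MapOutputTrackerMaster) and executors (MapOutputTrackerWorker).
    val tracker: MapOutputTracker = SparkEnv.get.mapOutputTracker
    tracker.getAllMapOutputStatusMetadata(shuffleId)
      .collect { case info: RemoteShuffleServerInfo => info }
      .toSet
  }
}

On executors this goes through MapOutputTrackerWorker, which fetches and caches the statuses from the driver, so the call fails with MetadataFetchFailedException when outputs are missing, just like the existing getMapSizesByExecutorId path.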
53 changes: 48 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -17,7 +17,7 @@

package org.apache.spark.scheduler

import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.io.{Externalizable, ObjectInput, ObjectOutput, Serializable}

import scala.collection.mutable

@@ -52,6 +52,13 @@ private[spark] sealed trait MapStatus {
* partitionId of the task or taskContext.taskAttemptId is used.
*/
def mapId: Long

/**
* Extra metadata for this map status. This can be used by different ShuffleManager
* implementations to store the information they need. For example, a remote shuffle service
* ShuffleManager could store shuffle server information so that reducer tasks know where to
* fetch shuffle data from.
*/
def metadata: Option[Serializable]
Member:
hmm... what's the relationship between SPARK-33114 and SPARK-25299? According to the JIRA description, SPARK-33114 seems to enhance the support for custom shuffle managers, while SPARK-25299 only customizes the storage with the default SortShuffleManager.

So if we are only talking about SPARK-33114, adding metadata may be a good choice for its own scenario. But if we bring in SPARK-25299 as well (IIUC, what this PR is doing would also benefit SPARK-25299), I personally think we need a more general design here. For example, I'd prefer to redesign the location field of MapStatus so that it can support the different scenarios (e.g., Spark BlockManager, Spark external shuffle service, custom remote storage, etc.) mentioned in SPARK-25299. That way, the different scenarios would be able to reuse existing features, e.g., decommissioning (which may update the MapStatus location at runtime), and reuse the existing code paths, e.g., we wouldn't need the extra getAllMapOutputStatuses and everything would work the same as the existing shuffle-read path.

WDYT?

@attilapiros (Contributor), Mar 7, 2021:
@Ngone51 I agree with you that finishing the design laid out in SPARK-25299 would be much better.
This is why I opened #30763 as a copy of Matthew Cheah's original PR for SPARK-31801 (because he is busy with other projects) and have kept it up to date with master several times.

But it hasn't received enough reviews, and I wouldn't want to block @hiboyang further, #30004 (comment).

I am sure that with your help we can complete SPARK-31801 and be on the road to SPARK-25299.

So next week I will do the conflict resolution and ping you when the PR is ready for review. Is this okay?

Author:
Cool, thanks @attilapiros for continuing to work on SPARK-25299 while unblocking this PR. @Ngone51 SPARK-33114 is a small change to support a remote shuffle service/storage by adding a metadata object to MapStatus. It can be viewed as a subset of SPARK-25299's work.

Member:
> So next week I will do the conflict resolution and ping you when the PR is ready for review. Is this okay?

Sure, please. @attilapiros

Member:
@hiboyang Thanks for your explanation. I agree that #30763 is too big for review. But I think we can discuss there first to make sure we are heading in the same direction before we dig into the details. And when we're on the same page, we can split the big PR into smaller pieces and start working on it together. Does that sound good to you?

@hiboyang (Author), Mar 10, 2021:
SPARK-31801 is very big and may take a very long time to finish (it has already been open for 10 months). Could we merge this PR first?

If SPARK-31801 finds a better way to support this and breaks getAllMapOutputStatusMetadata, that is actually a good thing :) We could have multiple iterations. This PR is the first iteration, with a very small change; SPARK-31801 is the iteration after that. The latter does not need to block the former.

Member:
I don't think we should develop this way... As you mentioned above, SPARK-33114 can be considered a subtask of SPARK-25299. So how can we treat this PR as a first iteration when SPARK-25299 is still under discussion and development, especially when people haven't reached agreement on the solution and there is a possible alternative solution at the same time? Also, I think a custom shuffle manager isn't officially supported by Spark because the ShuffleManager interface is private. So it doesn't make sense for Spark to add an internal API for unofficial use cases if there's no strong reason.

SPARK-31801 is surely big. But as I mentioned earlier, we can split it. Once the solution is finalized, we can start with refactoring MapStatus first. I think that would be a much smaller task and enough for your case. After that, we'll start on the remaining work (e.g., using the new MapStatus wherever it is referenced), which you wouldn't need to care about.

I understand you have put a lot of effort into this work, and I'm sorry we cannot get it in quickly. Unfortunately, I don't have permission to merge. You could try to persuade committers to merge the PR if you insist on it.

Author:
Yeah, it would also be a good idea to split SPARK-31801 and start with refactoring MapStatus first. Do you or the community have ideas about how to split SPARK-31801?

Member:
We're still discussing the solution in #30763, so I can't give you a concrete split plan yet. But I think we'd be able to start with refactoring MapStatus either way.

Author:
I see, I missed the latest discussion in #30763, will check there as well, thanks!

}


@@ -76,6 +83,18 @@ private[spark] object MapStatus {
}
}

def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long,
metadata: Option[Serializable]): MapStatus = {
if (uncompressedSizes.length > minPartitionsToUseHighlyCompressMapStatus) {
HighlyCompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
} else {
new CompressedMapStatus(loc, uncompressedSizes, mapTaskId, metadata)
}
}

private[this] val LOG_BASE = 1.1

/**
@@ -117,7 +136,8 @@ private[spark] object MapStatus {
private[spark] class CompressedMapStatus(
private[this] var loc: BlockManagerId,
private[this] var compressedSizes: Array[Byte],
private[this] var _mapTaskId: Long)
private[this] var _mapTaskId: Long,
private[this] var _metadata: Option[Serializable] = None)
extends MapStatus with Externalizable {

// For deserialization only
@@ -127,6 +147,11 @@
this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId)
}

def this(loc: BlockManagerId, uncompressedSizes: Array[Long], mapTaskId: Long,
metadata: Option[Serializable]) {
this(loc, uncompressedSizes.map(MapStatus.compressSize), mapTaskId, metadata)
}

override def location: BlockManagerId = loc

override def updateLocation(newLoc: BlockManagerId): Unit = {
@@ -139,11 +164,15 @@

override def mapId: Long = _mapTaskId

override def metadata: Option[Serializable] = _metadata

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
out.writeLong(_mapTaskId)
out.writeBoolean(_metadata.isDefined)
_metadata.foreach(out.writeObject)
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -152,6 +181,10 @@
compressedSizes = new Array[Byte](len)
in.readFully(compressedSizes)
_mapTaskId = in.readLong()
val hasMetadata = in.readBoolean()
if (hasMetadata) {
_metadata = Some(in.readObject().asInstanceOf[Serializable])
}
}
}

@@ -173,7 +206,8 @@
private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long,
private[this] var hugeBlockSizes: scala.collection.Map[Int, Byte],
private[this] var _mapTaskId: Long)
private[this] var _mapTaskId: Long,
private[this] var _metadata: Option[Serializable] = None)
extends MapStatus with Externalizable {

// loc could be null when the default constructor is called during deserialization
@@ -203,6 +237,8 @@

override def mapId: Long = _mapTaskId

override def metadata: Option[Serializable] = _metadata

override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
emptyBlocks.serialize(out)
@@ -213,6 +249,8 @@
out.writeByte(kv._2)
}
out.writeLong(_mapTaskId)
out.writeBoolean(_metadata.isDefined)
_metadata.foreach(out.writeObject)
}

override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
@@ -230,14 +268,19 @@
}
hugeBlockSizes = hugeBlockSizesImpl
_mapTaskId = in.readLong()
val hasMetadata = in.readBoolean()
if (hasMetadata) {
_metadata = Some(in.readObject().asInstanceOf[Serializable])
}
}
}

private[spark] object HighlyCompressedMapStatus {
def apply(
loc: BlockManagerId,
uncompressedSizes: Array[Long],
mapTaskId: Long): HighlyCompressedMapStatus = {
mapTaskId: Long,
metadata: Option[Serializable] = None): HighlyCompressedMapStatus = {
// We must keep track of which blocks are empty so that we don't report a zero-sized
// block as being non-empty (or vice-versa) when using the average block size.
var i = 0
@@ -278,6 +321,6 @@ private[spark] object HighlyCompressedMapStatus {
emptyBlocks.trim()
emptyBlocks.runOptimize()
new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize,
hugeBlockSizes, mapTaskId)
hugeBlockSizes, mapTaskId, metadata)
}
}
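
For the writer side, the following sketch (again hypothetical, not part of the PR) shows how a custom shuffle writer might attach metadata through the new MapStatus.apply overload. The server address, object name, and method name are invented for illustration, and the snippet assumes it lives under the org.apache.spark package because MapStatus and BlockManagerId's companion object are private[spark].

package org.apache.spark.shuffle

import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId

object RemoteShuffleWriterExample {
  /**
   * Build the MapStatus a hypothetical remote-shuffle writer would return after pushing its
   * partitions to an external shuffle server. The attached metadata travels with the status and
   * is later visible via MapStatus.metadata or MapOutputTracker.getAllMapOutputStatusMetadata.
   */
  def reportMapOutput(partitionLengths: Array[Long], mapTaskId: Long): MapStatus = {
    // Illustrative metadata: where the shuffle data actually lives. Any java.io.Serializable
    // value works; a case class carrying host/port would be more typical in practice.
    val serverAddress: java.io.Serializable = "shuffle-server-1:19000"
    MapStatus(
      loc = BlockManagerId("exec-1", "hostA", 7337),
      uncompressedSizes = partitionLengths,
      mapTaskId = mapTaskId,
      metadata = Some(serverAddress))
  }
}

Because CompressedMapStatus and HighlyCompressedMapStatus now write a presence flag followed by the object itself in writeExternal, and read it back symmetrically, the metadata survives the Java-serialization path the driver already uses to ship map statuses to executors.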
76 changes: 66 additions & 10 deletions core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -62,17 +62,38 @@ class MapOutputTrackerSuite extends SparkFunSuite {
assert(tracker.containsShuffle(10))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(1000L, 10000L), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L), 6))
val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L, 10000L), 5)
val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000), Array(10000L, 1000L), 6)
tracker.registerMapOutput(10, 0, mapStatus1)
tracker.registerMapOutput(10, 1, mapStatus2)
val statuses = tracker.getMapSizesByExecutorId(10, 0)
assert(statuses.toSet ===
Seq((BlockManagerId("a", "hostA", 1000),
ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0))),
(BlockManagerId("b", "hostB", 1000),
ArrayBuffer((ShuffleBlockId(10, 6, 0), size10000, 1)))).toSet)
assert(0 == tracker.getNumCachedSerializedBroadcast)
val allStatusMetadata = tracker.getAllMapOutputStatusMetadata(10)
assert(0 == allStatusMetadata.size)
tracker.stop()
rpcEnv.shutdown()
}

test("master register shuffle with map status metadata") {
val rpcEnv = createRpcEnv("test")
val tracker = newTrackerMaster()
tracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, tracker, conf))
tracker.registerShuffle(10, 2)
val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000),
Array(1000L, 10000L), 5, Some("metadata1"))
val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000),
Array(10000L, 1000L), 6, Some(1001))
tracker.registerMapOutput(10, 0, mapStatus1)
tracker.registerMapOutput(10, 1, mapStatus2)
assert(0 == tracker.getNumCachedSerializedBroadcast)
val allStatusMetadata = tracker.getAllMapOutputStatusMetadata(10)
assert(allStatusMetadata === Array(mapStatus1.metadata.get, mapStatus2.metadata.get))
tracker.stop()
rpcEnv.shutdown()
}
@@ -92,9 +113,11 @@ class MapOutputTrackerSuite extends SparkFunSuite {
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).nonEmpty)
assert(0 == tracker.getNumCachedSerializedBroadcast)
assert(0 == tracker.getAllMapOutputStatusMetadata(10).size)
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0).isEmpty)
assert(0 == tracker.getAllMapOutputStatusMetadata(11).size)

tracker.stop()
rpcEnv.shutdown()
@@ -122,6 +145,7 @@ class MapOutputTrackerSuite extends SparkFunSuite {
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
intercept[FetchFailedException] { tracker.getMapSizesByExecutorId(10, 1) }
intercept[FetchFailedException] { tracker.getAllMapOutputStatusMetadata(10) }

tracker.stop()
rpcEnv.shutdown()
@@ -146,13 +170,15 @@ class MapOutputTrackerSuite extends SparkFunSuite {
intercept[FetchFailedException] { mapWorkerTracker.getMapSizesByExecutorId(10, 0) }

val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("a", "hostA", 1000), Array(1000L), 5))
val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5)
masterTracker.registerMapOutput(10, 0, mapStatus)
mapWorkerTracker.updateEpoch(masterTracker.getEpoch)
assert(mapWorkerTracker.getMapSizesByExecutorId(10, 0).toSeq ===
Seq((BlockManagerId("a", "hostA", 1000),
ArrayBuffer((ShuffleBlockId(10, 5, 0), size1000, 0)))))
assert(0 == masterTracker.getNumCachedSerializedBroadcast)
Contributor:
Please do not remove this assert:

    assert(0 == masterTracker.getNumCachedSerializedBroadcast)

Author:
It seems this was caused by a merge issue; I will add it back.

val allMapOutputStatusMetadata = mapWorkerTracker.getAllMapOutputStatusMetadata(10)
assert(0 == allMapOutputStatusMetadata.size)

val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
@@ -170,6 +196,33 @@ class MapOutputTrackerSuite extends SparkFunSuite {
mapWorkerRpcEnv.shutdown()
}

test("remote get all map output statuses with metadata") {
val hostname = "localhost"
val rpcEnv = createRpcEnv("spark", hostname, 0, new SecurityManager(conf))

val masterTracker = newTrackerMaster()
masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME,
new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf))

val mapWorkerRpcEnv = createRpcEnv("spark-worker", hostname, 0, new SecurityManager(conf))
val mapWorkerTracker = new MapOutputTrackerWorker(conf)
mapWorkerTracker.trackerEndpoint =
mapWorkerRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)

masterTracker.registerShuffle(10, 1)
val mapStatus = MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L), 5,
Some("metadata1"))
masterTracker.registerMapOutput(10, 0, mapStatus)
val allMapOutputStatusMetadata = mapWorkerTracker.getAllMapOutputStatusMetadata(10)
assert(allMapOutputStatusMetadata.size === 1)
assert(allMapOutputStatusMetadata(0) === mapStatus.metadata.get)

masterTracker.stop()
mapWorkerTracker.stop()
rpcEnv.shutdown()
mapWorkerRpcEnv.shutdown()
}

test("remote fetch below max RPC message size") {
val newConf = new SparkConf
newConf.set(RPC_MESSAGE_MAX_SIZE, 1)
@@ -311,10 +364,12 @@ class MapOutputTrackerSuite extends SparkFunSuite {
val size0 = MapStatus.decompressSize(MapStatus.compressSize(0L))
val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
val size10000 = MapStatus.decompressSize(MapStatus.compressSize(10000L))
tracker.registerMapOutput(10, 0, MapStatus(BlockManagerId("a", "hostA", 1000),
Array(size0, size1000, size0, size10000), 5))
tracker.registerMapOutput(10, 1, MapStatus(BlockManagerId("b", "hostB", 1000),
Array(size10000, size0, size1000, size0), 6))
val mapStatus1 = MapStatus(BlockManagerId("a", "hostA", 1000),
Array(size0, size1000, size0, size10000), 5)
val mapStatus2 = MapStatus(BlockManagerId("b", "hostB", 1000),
Array(size10000, size0, size1000, size0), 6)
tracker.registerMapOutput(10, 0, mapStatus1)
tracker.registerMapOutput(10, 1, mapStatus2)
assert(tracker.containsShuffle(10))
assert(tracker.getMapSizesByExecutorId(10, 0, 2, 0, 4).toSeq ===
Seq(
@@ -326,6 +381,7 @@
(ShuffleBlockId(10, 6, 2), size1000, 1)))
)
)
assert(0 == tracker.getAllMapOutputStatusMetadata(10).size)

tracker.unregisterShuffle(10)
tracker.stop()