core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -40,6 +40,20 @@ private[spark] trait BroadcastFactory {
*/
def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]

/**
* Creates a new broadcast variable with a specified id. The difference from the original
* interface is that the resulting broadcast is an executor-side broadcast, which must
* consider recovery when fetching block data fails.
*/
def newExecutorBroadcast[T: ClassTag](
value: T,
id: Long,
nBlocks: Int,
cSums: Array[Int]): Broadcast[T]

// Called from an executor to put broadcast data into the block manager.
def uploadBroadcast[T: ClassTag](value_ : T, id: Long): Seq[Int]

def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit

def stop(): Unit
core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -17,12 +17,17 @@

package org.apache.spark.broadcast

import java.io.IOException
import java.util.concurrent.atomic.AtomicLong

import scala.reflect.ClassTag

import org.apache.hadoop.fs.Path

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.internal.Logging
import org.apache.spark.util.ShutdownHookManager

private[spark] class BroadcastManager(
val isDriver: Boolean,
@@ -32,6 +37,9 @@ private[spark] class BroadcastManager(

private var initialized = false
private var broadcastFactory: BroadcastFactory = null
private val shutdownHook = addShutdownHook()
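// Backup directory on HDFS used to recover executor-side broadcast pieces. The default
// below is an app-scoped path under /tmp/spark; it can be overridden, for example with
// --conf spark.broadcast.backup.dir=hdfs:///apps/spark/broadcast-backup (path illustrative).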
private[spark] lazy val hdfsBackupDir =
Option(new Path(conf.get("spark.broadcast.backup.dir", s"/tmp/spark/${conf.getAppId}_blocks")))

initialize()

@@ -47,16 +55,76 @@ private[spark] class BroadcastManager(
}

def stop() {
// Remove the shutdown hook. It causes memory leaks if we leave it around.
try {
ShutdownHookManager.removeShutdownHook(shutdownHook)
} catch {
case e: Exception =>
logError(s"Exception while removing shutdown hook.", e)
}
// Only delete the backup path from the driver when the application stops.
if (isDriver) {
hdfsBackupDir.foreach { dirPath =>
try {
val fs = dirPath.getFileSystem(SparkHadoopUtil.get.conf)
if (fs.exists(dirPath)) {
fs.delete(dirPath, true)
}
} catch {
case e: IOException =>
logWarning(s"Failed to delete broadcast temp dir $dirPath.", e)
}
}
}

broadcastFactory.stop()
}

private val nextBroadcastId = new AtomicLong(0)

// Called from the driver to create a new broadcast id.
def newBroadcastId: Long = nextBroadcastId.getAndIncrement()

def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}

// Called from an executor to upload broadcast data to the block manager.
def uploadBroadcast[T: ClassTag](
value_ : T,
id: Long): Seq[Int] = {
broadcastFactory.uploadBroadcast[T](value_, id)
}

// Called from the driver to create a broadcast with a specified id.
def newExecutorBroadcast[T: ClassTag](
value_ : T,
id: Long,
nBlocks: Int,
cSums: Array[Int]): Broadcast[T] = {
broadcastFactory.newExecutorBroadcast[T](value_, id, nBlocks, cSums)
}

def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) {
broadcastFactory.unbroadcast(id, removeFromDriver, blocking)
}

private def addShutdownHook(): AnyRef = {
logDebug("Adding shutdown hook") // force eager creation of logger
ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY) { () =>
logInfo("Shutdown hook called")
BroadcastManager.this.stop()
}
}

}

/**
* Trait describing the shape in which tuples are broadcast. This is used for
* executor-side broadcast; typical examples are identity (tuples remain unchanged)
* or hashed (tuples are converted into some hash index).
*/
trait TransFunc[T, U] extends Serializable {
def transform(rows: Array[T]): U
}
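
// For illustration only (not part of this patch): a minimal TransFunc realizing the
// "hashed" example from the doc comment above. The class name and the key-extractor
// parameter are hypothetical.
class HashedTransFunc[T](key: T => Long) extends TransFunc[T, Map[Long, T]] {
override def transform(rows: Array[T]): Map[Long, T] = rows.map(r => key(r) -> r).toMap
}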
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -38,6 +38,7 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
*
* The mechanism is as follows:
*
* 1. For driver-side broadcast (when isExecutorSide is false):
* The driver divides the serialized object into small chunks and
* stores those chunks in the BlockManager of the driver.
*
@@ -51,10 +52,30 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream}
*
* When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
*
* 2. For executor-side broadcast (when isExecutorSide is true):
* One executor divides the serialized object into small chunks and
* stores those chunks in the BlockManager of that executor.
*
* On other executors, each executor first attempts to fetch the object from its own
* BlockManager. If it does not exist there, it uses remote fetches to fetch the small
* chunks from other executors if available. Once it gets the chunks, it puts the chunks
* in its own BlockManager, ready for other executors to fetch from.
*
* With executor-side broadcast, the driver never holds the broadcast data.
*
* When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
*
* @param obj object to broadcast
* @param id A unique identifier for the broadcast variable.
* @param isExecutorSide Whether this is an executor-side broadcast.
* @param nBlocks Number of blocks, for an executor-side broadcast.
* @param cSums Checksums of the blocks, for an executor-side broadcast.
*/
private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
private[spark] class TorrentBroadcast[T: ClassTag](
@transient val obj: T,
id: Long,
isExecutorSide: Boolean = false,
nBlocks: Option[Int] = None,
cSums: Option[Array[Int]] = None)
extends Broadcast[T](id) with Logging with Serializable {

/**
@@ -70,6 +91,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
/** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
@transient private var blockSize: Int = _

/** Whether to generate checksum for blocks or not. */
private var checksumEnabled: Boolean = false
/** The checksum for all the blocks. */
private var checksums: Array[Int] = cSums.getOrElse(null)

private def setConf(conf: SparkConf) {
compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
Some(CompressionCodec.createCodec(conf))
@@ -84,13 +110,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)

private val broadcastId = BroadcastBlockId(id)

/** Total number of blocks this broadcast variable contains. */
private val numBlocks: Int = writeBlocks(obj)
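/** Block count followed by per-block checksums; reported to the driver for executor-side broadcast. */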
def getNumBlocksAndChecksums: Seq[Int] = if (checksumEnabled) {
Seq(numBlocks) ++ checksums
} else {
Seq(numBlocks)
}

/** Whether to generate checksum for blocks or not. */
private var checksumEnabled: Boolean = false
/** The checksum for all the blocks. */
private var checksums: Array[Int] = _
/** Total number of blocks this broadcast variable contains. */
private val numBlocks: Int = nBlocks.getOrElse(writeBlocks(obj)) // must be initialized after `checksums`
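
// For reference, `calcChecksum` (defined elsewhere in this file, unchanged by this patch)
// produces the per-block checksums verified in readBlocks. A minimal sketch of such a
// checksum, assuming java.util.zip.Adler32; the body here is illustrative:
//
//   private def calcChecksum(block: ByteBuffer): Int = {
//     val adler = new java.util.zip.Adler32()
//     val bytes = new Array[Byte](block.remaining())
//     block.duplicate().get(bytes)
//     adler.update(bytes, 0, bytes.length)
//     adler.getValue.toInt
//   }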

override protected def getValue() = {
_value
@@ -132,6 +159,9 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
checksums(i) = calcChecksum(block)
}
val pieceId = BroadcastBlockId(id, "piece" + i)
if (isExecutorSide) {
blockManager.persistBroadcastPiece(pieceId, block)
}
val bytes = new ChunkedByteBuffer(block.duplicate())
if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager")
@@ -158,7 +188,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
blocks(pid) = block
releaseLock(pieceId)
case None =>
bm.getRemoteBytes(pieceId) match {
bm.getRemoteBytes(pieceId).orElse(bm.getBroadcastPiece(pieceId)) match {
case Some(b) =>
if (checksumEnabled) {
val sum = calcChecksum(b.chunks(0))
@@ -169,7 +199,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
}
// We found the block from remote executors/driver's BlockManager, so put the block
// in this executor's BlockManager.
if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
if (!bm.putBytes(
pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) {
throw new SparkException(
s"Failed to store $pieceId of $broadcastId in local BlockManager")
}
@@ -194,7 +225,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
* and driver.
*/
override protected def doDestroy(blocking: Boolean) {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
if (isExecutorSide) {
TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking)
} else {
TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
}
}

/** Used by the JVM when serializing this object. */
@@ -301,5 +336,6 @@ private object TorrentBroadcast extends Logging {
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = {
logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
SparkEnv.get.blockManager.cleanBroadcastPieces(id)
}
}
core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -34,6 +34,19 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory {
new TorrentBroadcast[T](value_, id)
}

override def newExecutorBroadcast[T: ClassTag](
value: T,
id: Long,
nBlocks: Int,
cSums: Array[Int]): Broadcast[T] = {
new TorrentBroadcast[T](value, id, true, Option(nBlocks), Option(cSums))
}

override def uploadBroadcast[T: ClassTag](value_ : T, id: Long): Seq[Int] = {
val executorBroadcast = new TorrentBroadcast[T](value_, id, true)
executorBroadcast.getNumBlocksAndChecksums
}

override def stop() { }

/**
62 changes: 60 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -34,6 +34,7 @@ import org.apache.spark._
import org.apache.spark.Partitioner._
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.broadcast.{Broadcast, TorrentBroadcast, TransFunc}
import org.apache.spark.internal.Logging
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
@@ -42,8 +43,7 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils}

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -937,6 +937,64 @@ abstract class RDD[T: ClassTag](
}

/**
* Broadcast this RDD to the cluster from the executors, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each executor only once.
*
* Users should pass in a transform function to compute the broadcast value from the RDD.
*/
@Since("2.1.0")
private[spark] def broadcast[U: ClassTag](transFunc: TransFunc[T, U]): Broadcast[U] = withScope {
val bc = if (partitions.size > 0) {
val id = sc.env.broadcastManager.newBroadcastId

// First: write the blocks to the block manager from an executor.
val numBlocksAndChecksums = coalesce(1).mapPartitions { iter =>
SparkEnv.get.broadcastManager
.uploadBroadcast(transFunc.transform(iter.toArray), id).iterator
}.collect()

// Then: create the broadcast on the driver; this does not write any blocks.
val res = SparkEnv.get.broadcastManager.newExecutorBroadcast(
transFunc.transform(Array.empty[T]),
id,
numBlocksAndChecksums.head,
numBlocksAndChecksums.tail)

val callSite = sc.getCallSite
logInfo("Created executor side broadcast " + res.id + " from " + callSite.shortForm)
res
} else {
// The RDD may have 0 partitions; in that case, fall back to driver-side broadcast.
val res = SparkEnv.get.broadcastManager.newBroadcast(
transFunc.transform(Array.empty[T]), sc.isLocal)
val callSite = sc.getCallSite
logInfo("Created broadcast " + res.id + " from " + callSite.shortForm)
res
}

sc.cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
}

/**
* Executor-side broadcast API. Broadcasts this RDD to the cluster from the executors,
* returning a [[org.apache.spark.broadcast.Broadcast]] object for reading it in
* distributed functions. The variable will be sent to each executor only once.
*
* @param f a transform function to compute the broadcast value from the RDD.
*/
@Since("2.1.0")
def broadcast[U: ClassTag](f: Iterator[T] => U): Broadcast[U] = withScope {
val transFunc = new TransFunc[T, U] {
override def transform(rows: Array[T]): U = {
f(rows.toIterator)
}
}
broadcast(transFunc)
}
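
// A usage sketch (hypothetical): build a lookup map on the executors and broadcast it
// without ever collecting the rows to the driver. Assumes a SparkContext `sc` in scope.
//
//   val users = sc.parallelize(Seq((1L, "alice"), (2L, "bob")))
//   val byId: Broadcast[Map[Long, String]] = users.broadcast(iter => iter.toMap)
//   val names = users.map { case (id, _) => byId.value(id) }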

/**
* Return an iterator that contains all of the elements in this RDD.
*
* The iterator will consume as much memory as the largest partition in this RDD.
21 changes: 21 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -750,6 +750,27 @@ private[spark] class BlockManager(
syncWrites, writeMetrics, blockId)
}

/**
* Persist a broadcast piece of serialized bytes to HDFS.
*/
def persistBroadcastPiece(id: BlockId, block: ByteBuffer): Unit = {
diskBlockManager.persistBroadcastPiece(id, block)
}

/**
* Get a broadcast piece from HDFS, as serialized bytes.
*/
def getBroadcastPiece(id: BlockId): Option[ChunkedByteBuffer] = {
diskBlockManager.getBroadcastPiece(id)
}

/**
* Clean up HDFS files for an executor-side broadcast.
*/
def cleanBroadcastPieces(id: Long): Unit = {
diskBlockManager.cleanBroadcastPieces(id)
}
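
// A hypothetical sketch (not part of this patch) of how the DiskBlockManager side of
// persistBroadcastPiece could write a piece under the HDFS backup dir. The helper name
// and the `backupDir` value are assumptions:
//
//   def persistBroadcastPiece(id: BlockId, block: ByteBuffer): Unit = {
//     val path = new Path(backupDir, id.name)
//     val fs = path.getFileSystem(SparkHadoopUtil.get.conf)
//     val out = fs.create(path, true)
//     try {
//       val bytes = new Array[Byte](block.remaining())
//       block.duplicate().get(bytes)
//       out.write(bytes)
//     } finally {
//       out.close()
//     }
//   }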

/**
* Put a new block of serialized bytes to the block manager.
*