diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index ece4ae6ab031..f4054655fa9c 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -40,6 +40,20 @@ private[spark] trait BroadcastFactory { */ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T] + /** + * Creates a new broadcast variable with a specified id. The difference from the original + * interface is that the returned broadcast is an executor-side broadcast, which should + * attempt recovery when fetching block data fails. + */ + def newExecutorBroadcast[T: ClassTag]( + value: T, + id: Long, + nBlocks: Int, + cSums: Array[Int]): Broadcast[T] + + // Called from an executor to put broadcast data into the block manager. + def uploadBroadcast[T: ClassTag](value_ : T, id: Long): Seq[Int] + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit def stop(): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index e88988fe03b2..42b6a4a33b9a 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -17,12 +17,17 @@ package org.apache.spark.broadcast +import java.io.IOException import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag +import org.apache.hadoop.fs.Path + import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.util.ShutdownHookManager private[spark] class BroadcastManager( val isDriver: Boolean, @@ -32,6 +37,9 @@ private[spark] class BroadcastManager( private var initialized = false private var broadcastFactory: BroadcastFactory = null + private val shutdownHook = addShutdownHook() + private[spark] lazy val hdfsBackupDir = + Option(new Path(conf.get("spark.broadcast.backup.dir", s"/tmp/spark/${conf.getAppId}_blocks"))) initialize() @@ -47,16 +55,76 @@ private[spark] class BroadcastManager( } def stop() { + // Remove the shutdown hook. It causes memory leaks if we leave it around. + try { + ShutdownHookManager.removeShutdownHook(shutdownHook) + } catch { + case e: Exception => + logError(s"Exception while removing shutdown hook.", e) + } + // Only delete the path from the driver when the application stops. + if (isDriver) { + hdfsBackupDir.foreach { dirPath => + try { + val fs = dirPath.getFileSystem(SparkHadoopUtil.get.conf) + if (fs.exists(dirPath)) { + fs.delete(dirPath, true) + } + } catch { + case e: IOException => + logWarning(s"Failed to delete broadcast temp dir $dirPath.", e) + } + } + } + broadcastFactory.stop() } private val nextBroadcastId = new AtomicLong(0) + // Called from the driver to create a new broadcast id. + def newBroadcastId: Long = nextBroadcastId.getAndIncrement() + def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean): Broadcast[T] = { broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) } + // Called from an executor to upload broadcast data to the block manager.
+ def uploadBroadcast[T: ClassTag]( + value_ : T, + id: Long + ): Seq[Int] = { + broadcastFactory.uploadBroadcast[T](value_, id) + } + + // Called from the driver to create a broadcast with a specified id. + def newExecutorBroadcast[T: ClassTag]( + value_ : T, + id: Long, + nBlocks: Int, + cSums: Array[Int]): Broadcast[T] = { + broadcastFactory.newExecutorBroadcast[T](value_, id, nBlocks, cSums) + } + def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { broadcastFactory.unbroadcast(id, removeFromDriver, blocking) } + + private def addShutdownHook(): AnyRef = { + logDebug("Adding shutdown hook") // force eager creation of logger + ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY) { () => + logInfo("Shutdown hook called") + BroadcastManager.this.stop() + } + } + +} + +/** + * Describes how the rows of an RDD are transformed into the value that is broadcast. This is + * used for executor-side broadcast; typical examples are identity (tuples remain unchanged) + * or hashed (tuples are converted into some hash index). + */ +trait TransFunc[T, U] extends Serializable { + def transform(rows: Array[T]): U } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 22d01c47e645..ce52b8efa533 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -38,6 +38,7 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * * The mechanism is as follows: * + * 1. For driver-side broadcast (when isExecutorSide is false): * The driver divides the serialized object into small chunks and * stores those chunks in the BlockManager of the driver. * @@ -51,10 +52,30 @@ import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStrea * * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. * + * 2. For executor-side broadcast (when isExecutorSide is true): + * One executor divides the serialized object into small chunks and + * stores those chunks in the BlockManager of that executor. + * + * Each of the other executors first attempts to fetch the object from its own BlockManager. If + * it does not exist, it then uses remote fetches to fetch the small chunks from + * other executors if available. Once it gets the chunks, it puts the chunks in its own + * BlockManager, ready for other executors to fetch from. + * + * In executor-side broadcast, the driver never holds the broadcast data. + * + * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. + * * @param obj object to broadcast * @param id A unique identifier for the broadcast variable. + * @param isExecutorSide Whether this is an executor-side broadcast variable. + * @param nBlocks The number of blocks for an executor-side broadcast. + * @param cSums The checksums of the blocks for an executor-side broadcast. */ -private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) +private[spark] class TorrentBroadcast[T: ClassTag]( + @transient val obj: T, + id: Long, + isExecutorSide: Boolean = false, + nBlocks: Option[Int] = None, + cSums: Option[Array[Int]] = None) extends Broadcast[T](id) with Logging with Serializable { /** @@ -70,6 +91,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */ @transient private var blockSize: Int = _ + /** Whether to generate checksum for blocks or not.
*/ + private var checksumEnabled: Boolean = false + /** The checksum for all the blocks. */ + private var checksums: Array[Int] = cSums.getOrElse(null) + private def setConf(conf: SparkConf) { compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) { Some(CompressionCodec.createCodec(conf)) @@ -84,13 +110,14 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) private val broadcastId = BroadcastBlockId(id) - /** Total number of blocks this broadcast variable contains. */ - private val numBlocks: Int = writeBlocks(obj) + def getNumBlocksAndChecksums: Seq[Int] = if (checksumEnabled) { + Seq(numBlocks) ++ checksums + } else { + Seq(numBlocks) + } - /** Whether to generate checksum for blocks or not. */ - private var checksumEnabled: Boolean = false - /** The checksum for all the blocks. */ - private var checksums: Array[Int] = _ + /** Total number of blocks this broadcast variable contains. */ + private val numBlocks: Int = nBlocks.getOrElse(writeBlocks(obj)) // this must be after checkSums override protected def getValue() = { _value @@ -132,6 +159,9 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) checksums(i) = calcChecksum(block) } val pieceId = BroadcastBlockId(id, "piece" + i) + if (isExecutorSide) { + blockManager.persistBroadcastPiece(pieceId, block) + } val bytes = new ChunkedByteBuffer(block.duplicate()) if (!blockManager.putBytes(pieceId, bytes, MEMORY_AND_DISK_SER, tellMaster = true)) { throw new SparkException(s"Failed to store $pieceId of $broadcastId in local BlockManager") @@ -158,7 +188,7 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) blocks(pid) = block releaseLock(pieceId) case None => - bm.getRemoteBytes(pieceId) match { + bm.getRemoteBytes(pieceId).orElse(bm.getBroadcastPiece(pieceId)) match { case Some(b) => if (checksumEnabled) { val sum = calcChecksum(b.chunks(0)) @@ -169,7 +199,8 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) } // We found the block from remote executors/driver's BlockManager, so put the block // in this executor's BlockManager. - if (!bm.putBytes(pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { + if (!bm.putBytes( + pieceId, b, StorageLevel.MEMORY_AND_DISK_SER, tellMaster = true)) { throw new SparkException( s"Failed to store $pieceId of $broadcastId in local BlockManager") } @@ -194,7 +225,11 @@ private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long) * and driver. */ override protected def doDestroy(blocking: Boolean) { - TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) + if (isExecutorSide) { + TorrentBroadcast.unpersist(id, removeFromDriver = false, blocking) + } else { + TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking) + } } /** Used by the JVM when serializing this object. 
*/ @@ -301,5 +336,6 @@ private object TorrentBroadcast extends Logging { def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = { logDebug(s"Unpersisting TorrentBroadcast $id") SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) + SparkEnv.get.blockManager.cleanBroadcastPieces(id) } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index b11f9ba171b8..1bb2edb668a4 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -34,6 +34,19 @@ private[spark] class TorrentBroadcastFactory extends BroadcastFactory { new TorrentBroadcast[T](value_, id) } + override def newExecutorBroadcast[T: ClassTag]( + value: T, + id: Long, + nBlocks: Int, + cSums: Array[Int]): Broadcast[T] = { + new TorrentBroadcast[T](value, id, true, Option(nBlocks), Option(cSums)) + } + + override def uploadBroadcast[T: ClassTag](value_ : T, id: Long): Seq[Int] = { + val executorBroadcast = new TorrentBroadcast[T](value_, id, true) + executorBroadcast.getNumBlocksAndChecksums + } + override def stop() { } /** diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 374abccf6ad5..a9cd7e4ab4f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -34,6 +34,7 @@ import org.apache.spark._ import org.apache.spark.Partitioner._ import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.{Broadcast, TorrentBroadcast, TransFunc} import org.apache.spark.internal.Logging import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -42,8 +43,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, - SamplingUtils} +import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} /** * A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable, @@ -937,6 +937,64 @@ abstract class RDD[T: ClassTag]( } /** + * Broadcast this RDD to the cluster from the executors, returning a + * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. + * The variable will be sent to each executor only once. + * + * Users should pass in a transform function that computes the broadcast value from the RDD. + */ + @Since("2.1.0") + private[spark] def broadcast[U: ClassTag](transFunc: TransFunc[T, U]): Broadcast[U] = withScope { + val bc = if (partitions.size > 0) { + val id = sc.env.broadcastManager.newBroadcastId + + // First: write the blocks from an executor to the block manager.
+ val numBlocksAndChecksums = coalesce(1).mapPartitions { iter => + SparkEnv.get.broadcastManager + .uploadBroadcast(transFunc.transform(iter.toArray), id).iterator + }.collect() + + // Then: create the broadcast from the driver; this does not write any blocks. + val res = SparkEnv.get.broadcastManager.newExecutorBroadcast( + transFunc.transform(Array.empty[T]), + id, + numBlocksAndChecksums.head, + numBlocksAndChecksums.tail) + + val callSite = sc.getCallSite + logInfo("Created executor side broadcast " + res.id + " from " + callSite.shortForm) + res + } else { + // The RDD may have 0 partitions; in that case fall back to driver-side broadcast. + val res = SparkEnv.get.broadcastManager.newBroadcast( + transFunc.transform(Array.empty[T]), sc.isLocal) + val callSite = sc.getCallSite + logInfo("Created broadcast " + res.id + " from " + callSite.shortForm) + res + } + + sc.cleaner.foreach(_.registerBroadcastForCleanup(bc)) + bc + } + + /** + * Executor-side broadcast API. It broadcasts this RDD to the cluster from the executors, + * returning a [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed + * functions. The variable will be sent to each executor only once. + * + * @param f A transform function that computes the broadcast value from the RDD. + */ + @Since("2.1.0") + def broadcast[U: ClassTag](f: Iterator[T] => U): Broadcast[U] = withScope { + val transFunc = new TransFunc[T, U] { + override def transform(rows: Array[T]): U = { + f(rows.toIterator) + } + } + broadcast(transFunc) + } + + /** * Return an iterator that contains all of the elements in this RDD. * * The iterator will consume as much memory as the largest partition in this RDD. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 04521c9159ea..27acc5818475 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -750,6 +750,27 @@ private[spark] class BlockManager( syncWrites, writeMetrics, blockId) } + /** + * Persist a broadcast piece of serialized bytes to HDFS. + */ + def persistBroadcastPiece(id: BlockId, block: ByteBuffer): Unit = { + diskBlockManager.persistBroadcastPiece(id, block) + } + + /** + * Get a broadcast block from HDFS, as serialized bytes. + */ + def getBroadcastPiece(id: BlockId): Option[ChunkedByteBuffer] = { + diskBlockManager.getBroadcastPiece(id) + } + + /** + * Clean up HDFS files for an executor-side broadcast. + */ + def cleanBroadcastPieces(id: Long): Unit = { + diskBlockManager.cleanBroadcastPieces(id) + } + /** * Put a new block of serialized bytes to the block manager.
* diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aa..a07211ee599c 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -18,12 +18,17 @@ package org.apache.spark.storage import java.io.{File, IOException} +import java.nio.ByteBuffer import java.util.UUID -import org.apache.spark.SparkConf +import org.apache.hadoop.fs._ + +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.internal.Logging import org.apache.spark.util.{ShutdownHookManager, Utils} +import org.apache.spark.util.io.ChunkedByteBuffer /** * Creates and maintains the logical mapping between logical blocks and physical on-disk @@ -44,6 +49,9 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea logError("Failed to create any local dir.") System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR) } + + private[spark] lazy val broadcastBackupDir = SparkEnv.get.broadcastManager.hdfsBackupDir + // The content of subDirs is immutable but the content of subDirs(i) is mutable. And the content // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) @@ -140,6 +148,114 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea } } + + def persistBroadcastPiece(id: BlockId, block: ByteBuffer): Unit = { + broadcastBackupDir.foreach { dirPath => + val blockFile = id.toString + val filePath = new Path(dirPath, blockFile) + var fs: FileSystem = null + try { + fs = filePath.getFileSystem(SparkHadoopUtil.get.conf) + if (fs.exists(filePath)) { + logWarning(s"File ${filePath.getName} already exists.") + } + } catch { + case e: IOException => + logWarning("Error when listing files in persistBroadcastPiece.", e) + return + } + var shouldDeleteFile: Boolean = false + var outStream: FSDataOutputStream = null + try { + outStream = fs.create(filePath) + outStream.write(block.array()) + outStream.hflush() + logInfo(s"Stored block $blockFile into the underlying fs.") + } catch { + case e: IOException => + logWarning("Error when backing up broadcast to HDFS; trying to clean up the file.", e) + shouldDeleteFile = true + } finally { + if (null != outStream) { + try { + outStream.close() + } catch { + case e: Throwable => + logWarning("Can't close the output stream.", e) + } + } + if (shouldDeleteFile) { + try { + fs.delete(filePath, true) + } catch { + case e: Exception => + logWarning(s"Failed to clean up the broadcast file $blockFile.", e) + } + } + } + } + } + + def cleanBroadcastPieces(id: Long): Unit = { + broadcastBackupDir.foreach { dirPath => + val fileFilter = new PathFilter { + override def accept(pathname: Path): Boolean = { + pathname.getName.startsWith(s"broadcast_${id}_piece") + } + } + try { + val fs = dirPath.getFileSystem(SparkHadoopUtil.get.conf) + if (!fs.exists(dirPath)) { + return + } + val files = fs.listStatus(dirPath, fileFilter).map(_.getPath) + for (file <- files) { + if (fs.exists(file)) { + fs.delete(file, true) + logInfo(s"Underlying fs file ${file.getName} has been cleaned.") + } + } + } catch { + case e: IOException => + logWarning(s"Failed to clean up broadcast files under ${dirPath.toString}.", e) + } + } + } + + def getBroadcastPiece(id: BlockId):
Option[ChunkedByteBuffer] = { + var inputStream: FSDataInputStream = null + try { + broadcastBackupDir.map { dirPath => + val blockFile = id.toString + val filePath = new Path(dirPath, blockFile) + val fs = filePath.getFileSystem(SparkHadoopUtil.get.conf) + (filePath, fs) + }.filter { case (path, fs) => + fs.exists(path) + }.map { case (filePath, fs) => + inputStream = fs.open(filePath) + val status = fs.getFileStatus(filePath) + val buffer = new Array[Byte](status.getLen.toInt) + inputStream.readFully(0, buffer) + logInfo(s"Got bytes from underlying fs file ${filePath.getName}.") + new ChunkedByteBuffer(ByteBuffer.wrap(buffer)) + } + } catch { + case e: Exception => + logError("Error when reading the broadcast value from the underlying fs.", e) + None + } finally { + if (null != inputStream) { + try { + inputStream.close() + } catch { + case e: Throwable => + logWarning("Can't close the input stream.", e) + } + } + } + } + private def addShutdownHook(): AnyRef = { logDebug("Adding shutdown hook") // force eager creation of logger ShutdownHookManager.addShutdownHook(ShutdownHookManager.TEMP_DIR_SHUTDOWN_PRIORITY + 1) { () => diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index ad56715656c8..0f05a83e6c19 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -29,6 +29,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat} import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} +import org.apache.spark.broadcast.TransFunc import org.apache.spark.rdd.RDDSuiteUtils._ import org.apache.spark.util.Utils @@ -48,6 +49,50 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { } } + test("executor broadcast") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val transFun = new TransFunc[Int, Int] { + override def transform(rows: Array[Int]): Int = { + if (rows.size > 0) rows.reduce(_ + _) else 0 + } + } + val b1 = nums.broadcast(transFun) + val b2 = nums.broadcast { iter => + if (iter.hasNext) iter.reduce(_ + _) else 0 + } + assert(b1.value == 10) + assert(b2.value == 10) + } + + test("executor broadcast --- empty rdd") { + val empty = sc.makeRDD(Array.empty[Int], 2) + val transFun = new TransFunc[Int, Int] { + override def transform(rows: Array[Int]): Int = if (rows.size > 0) rows.reduce(_ + _) else 0 + } + val b1 = empty.broadcast(transFun) + assert(b1.value == 0) + val b2 = empty.broadcast { iter => + if (iter.hasNext) iter.reduce(_ + _) else 0 + } + assert(b2.value == 0) + } + + test("executor broadcast --- broadcast data lost from executor") { + val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) + val transFun = new TransFunc[Int, Int] { + override def transform(rows: Array[Int]): Int = if (rows.size > 0) rows.reduce(_ + _) else 0 + } + val b1 = nums.broadcast(transFun) + sc.env.blockManager.removeBroadcast(b1.id, false) + assert(b1.value == 10) + + val b2 = nums.broadcast { iter => + if (iter.hasNext) iter.reduce(_ + _) else 0 + } + sc.env.blockManager.removeBroadcast(b2.id, false) + assert(b2.value == 10) + } + test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.getNumPartitions === 2) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 5e8a854e46a0..298c1d4b52a5 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -103,6 +103,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou import DAGSchedulerSuite._ val conf = new SparkConf + conf.set("spark.app.id", "DAGSchedulerSuite") + /** Set of TaskSets the DAGScheduler has requested executed. */ val taskSets = scala.collection.mutable.Buffer[TaskSet]() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 9dfdf4da78ff..feaabf046c0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql.catalyst.plans.physical +import org.apache.spark.broadcast.TransFunc import org.apache.spark.sql.catalyst.InternalRow /** * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index). */ -trait BroadcastMode { +trait BroadcastMode extends TransFunc[InternalRow, Any]{ def transform(rows: Array[InternalRow]): Any /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 7be5d31d4a76..a061dcfdbdfc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -38,13 +38,18 @@ import org.apache.spark.util.ThreadUtils */ case class BroadcastExchangeExec( mode: BroadcastMode, - child: SparkPlan) extends Exchange { + child: SparkPlan, + conf: SQLConf) extends Exchange { + + private def executorBroadcast: Boolean = conf.executorBroadcastEnabled override lazy val metrics = Map( "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"), "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"), "buildTime" -> SQLMetrics.createMetric(sparkContext, "time to build (ms)"), - "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)")) + "broadcastTime" -> SQLMetrics.createMetric(sparkContext, "time to broadcast (ms)"), + "collect_build_broadcastTime" -> SQLMetrics.createMetric(sparkContext, + "time to collect, build and broadcast (ms)")) override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) @@ -73,30 +78,41 @@ case class BroadcastExchangeExec( // with the correct execution. 
SQLExecution.withExecutionId(sparkContext, executionId) { try { - val beforeCollect = System.nanoTime() - // Note that we use .executeCollect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = child.executeCollect() - if (input.length >= 512000000) { - throw new SparkException( - s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") - } - val beforeBuild = System.nanoTime() - longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum - longMetric("dataSize") += dataSize - if (dataSize >= (8L << 30)) { - throw new SparkException( - s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") + val broadcasted = if (executorBroadcast) { + val before = System.nanoTime() + val res = + child.execute().mapPartitions { iter => + iter.map(_.copy()) + }.broadcast(mode) + longMetric("collect_build_broadcastTime") += (System.nanoTime() - before) / 1000000 + res + } else { + val beforeCollect = System.nanoTime() + // Note that we use .executeCollect() because we don't want to + // convert data to Scala types + val input: Array[InternalRow] = child.executeCollect() + if (input.length >= 512000000) { + throw new SparkException( + s"Cannot broadcast the table with more than " + + s"512 million rows: ${input.length} rows") + } + val beforeBuild = System.nanoTime() + longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 + val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + longMetric("dataSize") += dataSize + if (dataSize >= (8L << 30)) { + throw new SparkException( + s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") + } + // Construct and broadcast the relation. + val relation = mode.transform(input) + val beforeBroadcast = System.nanoTime() + longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + val res = sparkContext.broadcast(relation) + longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 + res } - // Construct and broadcast the relation. - val relation = mode.transform(input) - val beforeBroadcast = System.nanoTime() - longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 - - val broadcasted = sparkContext.broadcast(relation) - longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 - // There are some cases we don't care about the metrics and call `SparkPlan.doExecute` // directly without setting an execution id. We should be tolerant to it.
if (executionId != null) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f17049949aa4..5f223af5bb79 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -160,7 +160,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, distribution) if child.outputPartitioning.satisfies(distribution) => child case (child, BroadcastDistribution(mode)) => - BroadcastExchangeExec(mode, child) + BroadcastExchangeExec(mode, child, conf) case (child, distribution) => ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 304dcb691b32..45a96c3f2f44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -154,6 +154,11 @@ object SQLConf { .booleanConf .createWithDefault(false) + val EXECUTOR_BROADCAST_N_ENABLED = SQLConfigBuilder("spark.sql.executorBroadcast.enabled") + .doc("When true, broadcast joins use executor-side broadcast.") + .booleanConf + .createWithDefault(true) + val SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS = SQLConfigBuilder("spark.sql.adaptive.minNumPostShufflePartitions") .internal() @@ -727,6 +732,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def adaptiveExecutionEnabled: Boolean = getConf(ADAPTIVE_EXECUTION_ENABLED) + def executorBroadcastEnabled: Boolean = getConf(EXECUTOR_BROADCAST_N_ENABLED) + def minNumPostShufflePartitions: Int = getConf(SHUFFLE_MIN_NUM_POSTSHUFFLE_PARTITIONS) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index c27b815dfa08..c09d43c4c6ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1071,7 +1071,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { agg.queryExecution.executedPlan.collectFirst { case ShuffleExchange(_, _: RDDScanExec, _) => - case BroadcastExchangeExec(_, _: RDDScanExec) => + case BroadcastExchangeExec(_, _: RDDScanExec, _) => }.foreach { _ => fail( "No Exchange should be inserted above RDDScanExec since the checkpointed Dataset " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 36cde3233dce..fb1e6fc50ce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -55,12 +55,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val output = plan.output assert(plan sameResult plan) - val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan) + val exchange1 = BroadcastExchangeExec(IdentityBroadcastMode, plan, spark.sessionState.conf) val hashMode = HashedRelationBroadcastMode(output) - val exchange2 = BroadcastExchangeExec(hashMode, plan) + val exchange2 = BroadcastExchangeExec(hashMode, plan, spark.sessionState.conf) val hashMode2 =
HashedRelationBroadcastMode(Alias(output.head, "id2")() :: Nil) - val exchange3 = BroadcastExchangeExec(hashMode2, plan) + val exchange3 = BroadcastExchangeExec(hashMode2, plan, spark.sessionState.conf) val exchange4 = ReusedExchangeExec(output, exchange3) assert(exchange1 sameResult exchange1)
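Usage note (not part of the patch): the snippet below is a minimal sketch of how the executor-side broadcast API added in this diff could be exercised from user code, in the spirit of the RDDSuite tests above. Only RDD.broadcast(f: Iterator[T] => U) and the spark.sql.executorBroadcast.enabled flag come from the patch; the SparkContext setup, application name, master, and sample data are illustrative assumptions.

import org.apache.spark.{SparkConf, SparkContext}

object ExecutorBroadcastSketch {
  def main(args: Array[String]): Unit = {
    // Local master and app name are placeholders for this sketch.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("executor-broadcast-sketch")
    val sc = new SparkContext(conf)

    val nums = sc.parallelize(1 to 4, 2)

    // The transform function runs on an executor over all rows of the
    // (coalesced) RDD; the driver only records the block count and checksums,
    // so the broadcast data itself never passes through the driver.
    val sumBc = nums.broadcast { iter => iter.sum }

    // Read the broadcast value inside distributed functions as usual.
    val shifted = nums.map(_ + sumBc.value).collect()
    println(shifted.mkString(", ")) // 11, 12, 13, 14

    sc.stop()
  }
}

On the SQL side, BroadcastExchangeExec consults spark.sql.executorBroadcast.enabled (default true in this patch) to choose this path instead of collecting the child plan to the driver. Because the pieces written by writeBlocks are also backed up under spark.broadcast.backup.dir, an executor that cannot fetch a piece remotely can still recover it from HDFS, which is what the "broadcast data lost from executor" test exercises.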