
Commit 5016375
Address TD's comments
1 parent 7ed72fb · commit 5016375

File tree: 14 files changed, +181 -84 lines

core/src/main/scala/org/apache/spark/ContextCleaner.scala

Lines changed: 3 additions & 4 deletions

@@ -169,18 +169,17 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
 
   // Used for testing
-  private[spark] def cleanupRDD(rdd: RDD[_]) {
+  def cleanupRDD(rdd: RDD[_]) {
     doCleanupRDD(rdd.id)
   }
 
-  private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
+  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
     doCleanupShuffle(shuffleDependency.shuffleId)
   }
 
-  private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) {
+  def cleanupBroadcast[T](broadcast: Broadcast[T]) {
    doCleanupBroadcast(broadcast.id)
   }
-
 }
 
 private object ContextCleaner {
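
With the private[spark] modifiers dropped, these test hooks can be driven directly from test code. A hedged sketch of what such a call might look like; constructing the cleaner inline and the names sc and rdd are assumptions for illustration, not part of this diff:

    // Illustrative test usage only; how a real test obtains the cleaner
    // instance is an assumption here.
    val cleaner = new ContextCleaner(sc)
    val rdd = sc.parallelize(1 to 100).persist()
    cleaner.cleanupRDD(rdd)  // synchronously invokes doCleanupRDD(rdd.id)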

core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala

Lines changed: 22 additions & 9 deletions

@@ -19,6 +19,8 @@ package org.apache.spark.broadcast
 
 import java.io.Serializable
 
+import org.apache.spark.SparkException
+
 /**
  * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
  * cached on each machine rather than shipping a copy of it with tasks. They can be used, for
@@ -49,25 +51,36 @@ import java.io.Serializable
  */
 abstract class Broadcast[T](val id: Long) extends Serializable {
 
+  protected var _isValid: Boolean = true
+
   /**
    * Whether this Broadcast is actually usable. This should be false once persisted state is
    * removed from the driver.
    */
-  protected var isValid: Boolean = true
+  def isValid: Boolean = _isValid
 
   def value: T
 
   /**
-   * Remove all persisted state associated with this broadcast. Overriding implementations
-   * should set isValid to false if persisted state is also removed from the driver.
-   *
-   * @param removeFromDriver Whether to remove state from the driver.
-   * If true, the resulting broadcast should no longer be valid.
+   * Remove all persisted state associated with this broadcast on the executors. The next use
+   * of this broadcast on the executors will trigger a remote fetch.
    */
-  def unpersist(removeFromDriver: Boolean)
+  def unpersist()
 
-  // We cannot define abstract readObject and writeObject here due to some weird issues
-  // with these methods having to be 'private' in sub-classes.
+  /**
+   * Remove all persisted state associated with this broadcast on both the executors and the
+   * driver. Overriding implementations should set isValid to false.
+   */
+  private[spark] def destroy()
+
+  /**
+   * If this broadcast is no longer valid, throw an exception.
+   */
+  protected def assertValid() {
+    if (!_isValid) {
+      throw new SparkException("Attempted to use %s when it is no longer valid!".format(toString))
+    }
+  }
 
   override def toString = "Broadcast(" + id + ")"
 }
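
The net effect is to split the old unpersist(removeFromDriver) into a two-level lifecycle: a public unpersist() that only evicts executor-side copies, and a Spark-internal destroy() that also invalidates the broadcast. A minimal caller-side sketch, assuming an existing SparkContext sc (illustrative, not part of this diff):

    // Hypothetical lifecycle walkthrough; sc is assumed to exist.
    val data = Seq(1, 2, 3)
    val bc = sc.broadcast(data)
    assert(bc.isValid)    // freshly created broadcasts are valid
    bc.unpersist()        // executors drop their copies; next use re-fetches
    val stillUsable = bc.value  // still valid: only executor state was removed
    // destroy() is private[spark]; once it runs, _isValid is false and any
    // further value access or serialization fails via assertValid().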

core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala

Lines changed: 1 addition & 2 deletions

@@ -25,7 +25,7 @@ private[spark] class BroadcastManager(
     val isDriver: Boolean,
     conf: SparkConf,
     securityManager: SecurityManager)
-  extends Logging with Serializable {
+  extends Logging {
 
   private var initialized = false
   private var broadcastFactory: BroadcastFactory = null
@@ -63,5 +63,4 @@ private[spark] class BroadcastManager(
   def unbroadcast(id: Long, removeFromDriver: Boolean) {
     broadcastFactory.unbroadcast(id, removeFromDriver)
   }
-
 }

core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala

Lines changed: 17 additions & 8 deletions

@@ -31,7 +31,10 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
 private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
-  override def value = value_
+  def value: T = {
+    assertValid()
+    value_
+  }
 
   val blockId = BroadcastBlockId(id)
 
@@ -45,17 +48,24 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
   }
 
   /**
-   * Remove all persisted state associated with this HTTP broadcast.
-   * @param removeFromDriver Whether to remove state from the driver.
+   * Remove all persisted state associated with this HTTP broadcast on the executors.
+   */
+  def unpersist() {
+    HttpBroadcast.unpersist(id, removeFromDriver = false)
+  }
+
+  /**
+   * Remove all persisted state associated with this HTTP broadcast on both the executors
+   * and the driver.
    */
-  override def unpersist(removeFromDriver: Boolean) {
-    isValid = !removeFromDriver
-    HttpBroadcast.unpersist(id, removeFromDriver)
+  private[spark] def destroy() {
+    _isValid = false
+    HttpBroadcast.unpersist(id, removeFromDriver = true)
   }
 
   // Used by the JVM when serializing this object
   private def writeObject(out: ObjectOutputStream) {
-    assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!")
+    assertValid()
     out.defaultWriteObject()
   }
 
@@ -231,5 +241,4 @@ private[spark] object HttpBroadcast extends Logging {
        logError("Exception while deleting broadcast file: %s".format(file), e)
     }
   }
-
 }
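
Both methods now funnel into the companion object's unpersist, differing only in the removeFromDriver flag; passing it as a named argument keeps the two call sites self-documenting. A stripped-down sketch of that delegation shape (FakeStore is an illustrative stand-in, not Spark's companion object):

    // Illustrative delegation pattern behind unpersist() and destroy().
    object FakeStore {
      def unpersist(id: Long, removeFromDriver: Boolean) {
        println("removing broadcast " + id + ", driver included: " + removeFromDriver)
      }
    }
    FakeStore.unpersist(0L, removeFromDriver = false)  // what unpersist() does
    FakeStore.unpersist(0L, removeFromDriver = true)   // what destroy() does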

core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala

Lines changed: 22 additions & 13 deletions

@@ -29,7 +29,10 @@ import org.apache.spark.util.Utils
 private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
-  override def value = value_
+  def value = {
+    assertValid()
+    value_
+  }
 
   val broadcastId = BroadcastBlockId(id)
 
@@ -47,7 +50,23 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     sendBroadcast()
   }
 
-  def sendBroadcast() {
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on the executors.
+   */
+  def unpersist() {
+    TorrentBroadcast.unpersist(id, removeFromDriver = false)
+  }
+
+  /**
+   * Remove all persisted state associated with this Torrent broadcast on both the executors
+   * and the driver.
+   */
+  private[spark] def destroy() {
+    _isValid = false
+    TorrentBroadcast.unpersist(id, removeFromDriver = true)
+  }
+
+  private def sendBroadcast() {
     val tInfo = TorrentBroadcast.blockifyObject(value_)
     totalBlocks = tInfo.totalBlocks
     totalBytes = tInfo.totalBytes
@@ -71,18 +90,9 @@ private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boo
     }
   }
 
-  /**
-   * Remove all persisted state associated with this Torrent broadcast.
-   * @param removeFromDriver Whether to remove state from the driver.
-   */
-  override def unpersist(removeFromDriver: Boolean) {
-    isValid = !removeFromDriver
-    TorrentBroadcast.unpersist(id, removeFromDriver)
-  }
-
   // Used by the JVM when serializing this object
   private def writeObject(out: ObjectOutputStream) {
-    assert(isValid, "Attempted to serialize a broadcast variable that has been destroyed!")
+    assertValid()
     out.defaultWriteObject()
   }
 
@@ -240,7 +250,6 @@ private[spark] object TorrentBroadcast extends Logging {
   def unpersist(id: Long, removeFromDriver: Boolean) = synchronized {
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver)
   }
-
 }
 
 private[spark] case class TorrentBlock(

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 0 additions & 1 deletion

@@ -1128,5 +1128,4 @@ abstract class RDD[T: ClassTag](
   def toJavaRDD() : JavaRDD[T] = {
     new JavaRDD(this)(elementClassTag)
   }
-
 }

core/src/main/scala/org/apache/spark/storage/BlockId.scala

Lines changed: 4 additions & 17 deletions

@@ -53,9 +53,7 @@ private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: I
   def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
 }
 
-// Leave field as an instance variable to avoid matching on it
-private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
-  var field = ""
+private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId {
   def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field)
 }
 
@@ -77,19 +75,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId {
   def name = "test_" + id
 }
 
-private[spark] object BroadcastBlockId {
-  def apply(broadcastId: Long, field: String) = {
-    val blockId = new BroadcastBlockId(broadcastId)
-    blockId.field = field
-    blockId
-  }
-}
-
 private[spark] object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
   val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
-  val BROADCAST = "broadcast_([0-9]+)".r
-  val BROADCAST_FIELD = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
+  val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
   val TEST = "test_(.*)".r
@@ -100,10 +89,8 @@ private[spark] object BlockId {
       RDDBlockId(rddId.toInt, splitIndex.toInt)
     case SHUFFLE(shuffleId, mapId, reduceId) =>
       ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
-    case BROADCAST(broadcastId) =>
-      BroadcastBlockId(broadcastId.toLong)
-    case BROADCAST_FIELD(broadcastId, field) =>
-      BroadcastBlockId(broadcastId.toLong, field)
+    case BROADCAST(broadcastId, field) =>
+      BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>
       TaskResultBlockId(taskId.toLong)
     case STREAM(streamId, uniqueId) =>
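
Folding BROADCAST_FIELD into a single BROADCAST pattern relies on the second capture group including its leading underscore, which the match then strips. A quick round-trip sketch under the definitions above, assuming BlockId's apply parses names exactly as in the match shown:

    // Illustrative round-trip of broadcast block names.
    val plain  = BroadcastBlockId(42L)            // name == "broadcast_42"
    val pieced = BroadcastBlockId(42L, "piece0")  // name == "broadcast_42_piece0"
    // "broadcast_42_piece0" matches BROADCAST("42", "_piece0"), and
    // stripPrefix("_") recovers field == "piece0", so both round-trip:
    assert(BlockId(pieced.name) == pieced)
    assert(BlockId(plain.name) == plain)  // empty field: second group is ""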

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 2 additions & 2 deletions

@@ -832,7 +832,7 @@ private[spark] class BlockManager(
   def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
     logInfo("Removing broadcast " + broadcastId)
     val blocksToRemove = blockInfo.keys.collect {
-      case bid: BroadcastBlockId if bid.broadcastId == broadcastId => bid
+      case bid @ BroadcastBlockId(`broadcastId`, _) => bid
     }
     blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) }
   }
@@ -897,7 +897,7 @@ private[spark] class BlockManager(
 
   def shouldCompress(blockId: BlockId): Boolean = blockId match {
     case ShuffleBlockId(_, _, _) => compressShuffle
-    case BroadcastBlockId(_) => compressBroadcast
+    case BroadcastBlockId(_, _) => compressBroadcast
     case RDDBlockId(_, _) => compressRdds
     case TempBlockId(_) => compressShuffleSpill
     case _ => false
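
The rewritten collect uses a backquoted stable identifier: BroadcastBlockId(`broadcastId`, _) compares against the value of the enclosing method's parameter instead of binding a fresh variable, and with the wildcard it now matches field-suffixed blocks too. A minimal illustration of the idiom, given BroadcastBlockId as defined above:

    // Backquotes make the pattern compare against an existing value.
    val target = 7L
    val ids = Seq(BroadcastBlockId(7L), BroadcastBlockId(7L, "piece0"), BroadcastBlockId(8L))
    val hits = ids.collect { case bid @ BroadcastBlockId(`target`, _) => bid }
    // hits contains both blocks of broadcast 7, whatever their field suffix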

core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala

Lines changed: 9 additions & 14 deletions

@@ -106,9 +106,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     askDriverWithReply(RemoveBlock(blockId))
   }
 
-  /**
-   * Remove all blocks belonging to the given RDD.
-   */
+  /** Remove all blocks belonging to the given RDD. */
   def removeRdd(rddId: Int, blocking: Boolean) {
     val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId))
     future onFailure {
@@ -119,16 +117,12 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
     }
   }
 
-  /**
-   * Remove all blocks belonging to the given shuffle.
-   */
+  /** Remove all blocks belonging to the given shuffle. */
   def removeShuffle(shuffleId: Int) {
     askDriverWithReply(RemoveShuffle(shuffleId))
   }
 
-  /**
-   * Remove all blocks belonging to the given broadcast.
-   */
+  /** Remove all blocks belonging to the given broadcast. */
   def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) {
     askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster))
   }
@@ -148,20 +142,21 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log
   }
 
   /**
-   * Return the block's local status on all block managers, if any.
+   * Return the block's status on all block managers, if any.
    *
    * If askSlaves is true, this invokes the master to query each block manager for the most
    * updated block statuses. This is useful when the master is not informed of the given block
    * by all block managers.
-   *
-   * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
-   * should not block on waiting for a block manager, which can in turn be waiting for the
-   * master actor for a response to a prior message.
    */
   def getBlockStatus(
       blockId: BlockId,
       askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = {
     val msg = GetBlockStatus(blockId, askSlaves)
+    /*
+     * To avoid potential deadlocks, the use of Futures is necessary, because the master actor
+     * should not block on waiting for a block manager, which can in turn be waiting for the
+     * master actor for a response to a prior message.
+     */
     val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg)
     val (blockManagerIds, futures) = response.unzip
     val result = Await.result(Future.sequence(futures), timeout)
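
The relocated comment describes the gather step right below it: the master hands back unfinished Futures, and only the calling thread ever blocks. A self-contained sketch of that aggregation pattern, outside of the actor machinery (illustrative only):

    // Illustrative Future.sequence + Await pattern; not the actual actor code.
    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    val futures: Seq[Future[Option[Int]]] = Seq(Future(Some(1)), Future(None))
    val statuses = Await.result(Future.sequence(futures), 30.seconds)
    // statuses == Seq(Some(1), None); the producer side never blocked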

core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala

Lines changed: 6 additions & 5 deletions

@@ -255,21 +255,22 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
   }
 
   /**
-   * Return the block's local status for all block managers, if any.
+   * Return the block's status for all block managers, if any.
    *
    * If askSlaves is true, the master queries each block manager for the most updated block
    * statuses. This is useful when the master is not informed of the given block by all block
    * managers.
-   *
-   * Rather than blocking on the block status query, master actor should simply return a
-   * Future to avoid potential deadlocks. This can arise if there exists a block manager
-   * that is also waiting for this master actor's response to a previous message.
    */
   private def blockStatus(
       blockId: BlockId,
       askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
     import context.dispatcher
     val getBlockStatus = GetBlockStatus(blockId)
+    /*
+     * Rather than blocking on the block status query, master actor should simply return
+     * Futures to avoid potential deadlocks. This can arise if there exists a block manager
+     * that is also waiting for this master actor's response to a previous message.
+     */
     blockManagerInfo.values.map { info =>
       val blockStatusFuture =
         if (askSlaves) {
