@@ -33,19 +33,55 @@ extends Broadcast[T](id) with Logging with Serializable {
   def value = value_

   def unpersist(removeSource: Boolean) {
-    SparkEnv.get.blockManager.master.removeBlock(broadcastId)
-    SparkEnv.get.blockManager.removeBlock(broadcastId)
+    TorrentBroadcast.synchronized {
+      SparkEnv.get.blockManager.master.removeBlock(broadcastId)
+      SparkEnv.get.blockManager.removeBlock(broadcastId)
+    }
+
+    if (!removeSource) {
+      // We can't tell the BlockManager master to remove blocks from all nodes except the driver,
+      // so we need to save them here in order to store them on disk later.
+      // This may be inefficient if the blocks were already dropped to disk,
+      // but since unpersist is supposed to be called right after working with
+      // a broadcast, this should not happen (and getting them from memory is cheap).
+      arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
+
+      for (pid <- 0 until totalBlocks) {
+        val pieceId = pieceBlockId(pid)
+        TorrentBroadcast.synchronized {
+          SparkEnv.get.blockManager.getSingle(pieceId) match {
+            case Some(x) =>
+              arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
+            case None =>
+              throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+          }
+        }
+      }
+    }
+
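+    // Removing each piece block through the BlockManager master propagates the
+    // delete to every node that currently holds that piece.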
+    for (pid <- 0 until totalBlocks) {
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid))
+      }
+    }

     if (removeSource) {
-      for (pid <- pieceIds) {
-        SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid))
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.removeBlock(metaId)
       }
-      SparkEnv.get.blockManager.removeBlock(metaId)
     } else {
-      for (pid <- pieceIds) {
-        SparkEnv.get.blockManager.dropFromMemory(pieceBlockId(pid))
+      TorrentBroadcast.synchronized {
+        SparkEnv.get.blockManager.dropFromMemory(metaId)
       }
-      SparkEnv.get.blockManager.dropFromMemory(metaId)
+
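+      // Put the saved pieces back with DISK_ONLY (and tellMaster = true) so they
+      // remain available on disk without taking up memory.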
+      for (i <- 0 until totalBlocks) {
+        val pieceId = pieceBlockId(i)
+        TorrentBroadcast.synchronized {
+          SparkEnv.get.blockManager.putSingle(
+            pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true)
+        }
+      }
+      arrayOfBlocks = null
     }
   }

@@ -128,11 +164,6 @@ extends Broadcast[T](id) with Logging with Serializable {
   }

   private def resetWorkerVariables() {
-    if (arrayOfBlocks != null) {
-      for (pid <- pieceIds) {
-        SparkEnv.get.blockManager.removeBlock(pieceBlockId(pid))
-      }
-    }
     arrayOfBlocks = null
     totalBytes = -1
     totalBlocks = -1