Skip to content

Commit 6123d0f

Browse files
committed
track broadcasts for each worker
1 parent 8d2f08c commit 6123d0f

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import java.nio.charset.Charset
2323
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
2424

2525
import scala.collection.JavaConversions._
26+
import scala.collection.mutable
2627
import scala.language.existentials
2728
import scala.reflect.ClassTag
2829
import scala.util.{Try, Success, Failure}
@@ -193,11 +194,26 @@ private[spark] class PythonRDD(
193194
PythonRDD.writeUTF(include, dataOut)
194195
}
195196
// Broadcast variables
196-
dataOut.writeInt(broadcastVars.length)
197+
val bids = PythonRDD.getWorkerBroadcasts(worker)
198+
val nbids = broadcastVars.map(_.id).toSet
199+
// number of different broadcasts
200+
val cnt = bids.diff(nbids).size + nbids.diff(bids).size
201+
dataOut.writeInt(cnt)
202+
for (bid <- bids) {
203+
if (!nbids.contains(bid)) {
204+
// remove the broadcast from worker
205+
dataOut.writeLong(-bid)
206+
bids.remove(bid)
207+
}
208+
}
197209
for (broadcast <- broadcastVars) {
198-
dataOut.writeLong(broadcast.id)
199-
dataOut.writeInt(broadcast.value.length)
200-
dataOut.write(broadcast.value)
210+
if (!bids.contains(broadcast.id)) {
211+
// send new broadcast
212+
dataOut.writeLong(broadcast.id)
213+
dataOut.writeInt(broadcast.value.length)
214+
dataOut.write(broadcast.value)
215+
bids.add(broadcast.id)
216+
}
201217
}
202218
dataOut.flush()
203219
// Serialized command:
@@ -275,6 +291,12 @@ private object SpecialLengths {
275291
private[spark] object PythonRDD extends Logging {
276292
val UTF8 = Charset.forName("UTF-8")
277293

294+
// remember the broadcasts sent to each worker
295+
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
296+
private def getWorkerBroadcasts(worker: Socket) = {
297+
workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
298+
}
299+
278300
/**
279301
* Adapter for calling SparkContext#runJob from Python.
280302
*

python/pyspark/worker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,11 @@ def main(infile, outfile):
6969
ser = CompressedSerializer(pickleSer)
7070
for _ in range(num_broadcast_variables):
7171
bid = read_long(infile)
72-
value = ser._read_with_length(infile)
73-
_broadcastRegistry[bid] = Broadcast(bid, value)
72+
if bid > 0:
73+
value = ser._read_with_length(infile)
74+
_broadcastRegistry[bid] = Broadcast(bid, value)
75+
else:
76+
_broadcastRegistry.pop(-bid, None)
7477

7578
command = pickleSer._read_with_length(infile)
7679
(func, deserializer, serializer) = command

0 commit comments

Comments
 (0)