@@ -23,6 +23,7 @@ import java.nio.charset.Charset
2323import java .util .{List => JList , ArrayList => JArrayList , Map => JMap , Collections }
2424
2525import scala .collection .JavaConversions ._
26+ import scala .collection .mutable
2627import scala .language .existentials
2728import scala .reflect .ClassTag
2829import scala .util .{Try , Success , Failure }
@@ -193,11 +194,26 @@ private[spark] class PythonRDD(
193194 PythonRDD .writeUTF(include, dataOut)
194195 }
195196 // Broadcast variables
196- dataOut.writeInt(broadcastVars.length)
197+ val bids = PythonRDD .getWorkerBroadcasts(worker)
198+ val nbids = broadcastVars.map(_.id).toSet
199+ // number of different broadcasts
200+ val cnt = bids.diff(nbids).size + nbids.diff(bids).size
201+ dataOut.writeInt(cnt)
202+ for (bid <- bids) {
203+ if (! nbids.contains(bid)) {
204+ // remove the broadcast from worker
205+ dataOut.writeLong(- bid)
206+ bids.remove(bid)
207+ }
208+ }
197209 for (broadcast <- broadcastVars) {
198- dataOut.writeLong(broadcast.id)
199- dataOut.writeInt(broadcast.value.length)
200- dataOut.write(broadcast.value)
210+ if (! bids.contains(broadcast.id)) {
211+ // send new broadcast
212+ dataOut.writeLong(broadcast.id)
213+ dataOut.writeInt(broadcast.value.length)
214+ dataOut.write(broadcast.value)
215+ bids.add(broadcast.id)
216+ }
201217 }
202218 dataOut.flush()
203219 // Serialized command:
@@ -275,6 +291,12 @@ private object SpecialLengths {
275291private [spark] object PythonRDD extends Logging {
276292 val UTF8 = Charset .forName(" UTF-8" )
277293
294+ // remember the broadcasts sent to each worker
295+ private val workerBroadcasts = new mutable.WeakHashMap [Socket , mutable.Set [Long ]]()
296+ private def getWorkerBroadcasts (worker : Socket ) = {
297+ workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet [Long ]())
298+ }
299+
278300 /**
279301 * Adapter for calling SparkContext#runJob from Python.
280302 *
0 commit comments