19 changes: 19 additions & 0 deletions python/pyspark/broadcast.py
@@ -19,6 +19,7 @@
import sys
import gc
from tempfile import NamedTemporaryFile
import threading

from pyspark.cloudpickle import print_exec
from pyspark.util import _exception_message
@@ -139,6 +140,24 @@ def __reduce__(self):
        return _from_id, (self._jbroadcast.id(),)


class BroadcastPickleRegistry(threading.local):
Member: Hmm, actually, I prefer the locking way from before. I'd guess there wouldn't be a big performance difference due to the GIL, and the simple lock was easier to read.

Member: BTW, I am okay with the current way too.

Member: I'm OK with both ways. :)

My only concern with the previous locking approach is that we held the lock while dumping the command. I'm not sure whether dumping a big command can take a long time, in which case we'd prevent other threads from preparing their commands.

Member: Yeah, this solves the issue either way and looks safe from your concern. I'll probably make a follow-up after testing it (quite) later.

Member Author: Using the lock made it a little more obvious what was going on, but it's better not to hold a lock while pickling a large command, as @viirya said. Also, this way doesn't require changing any of the pickling code, so I prefer it too.

""" Thread-local registry for broadcast variables that have been pickled
"""

def __init__(self):
self.__dict__.setdefault("_registry", set())

def __iter__(self):
for bcast in self._registry:
yield bcast

def add(self, bcast):
self._registry.add(bcast)

def clear(self):
self._registry.clear()


if __name__ == "__main__":
    import doctest
    (failure_count, test_count) = doctest.testmod()
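The review thread above weighs this thread-local design against an earlier lock-based iteration of the patch. For comparison, here is a rough, purely illustrative sketch of what a lock-guarded shared registry could look like (`LockedRegistry` and its method names are hypothetical, not code from the earlier revision):

```python
import threading

class LockedRegistry(object):
    """Illustrative alternative: one shared set guarded by a lock.

    Note the tradeoff raised in the review: to stay consistent, the earlier
    iteration had to hold a lock around pickling an entire command, so one
    thread dumping a large command could block the others. The thread-local
    registry avoids that contention and needs no changes to the pickling code.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._registry = set()

    def add(self, bcast):
        with self._lock:
            self._registry.add(bcast)

    def snapshot_and_clear(self):
        # take a consistent copy, then reset the shared state
        with self._lock:
            items = set(self._registry)
            self._registry.clear()
            return items
```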
4 changes: 2 additions & 2 deletions python/pyspark/context.py
@@ -30,7 +30,7 @@

from pyspark import accumulators
from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
from pyspark.conf import SparkConf
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
        # This allows other code to determine which Broadcast instances have
        # been pickled, so it can determine which Java broadcast objects to
        # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()

        SparkFiles._sc = self
        root_dir = SparkFiles.getRootDirectory()
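For context on how this registry is consumed, the driver-side pattern (paraphrased from memory of PySpark's `_prepare_for_python_RDD`; `prepare_command` is a hypothetical name and the details may differ from the actual source) is: pickle the command, then read and clear the calling thread's registry to learn which broadcasts the command captured:

```python
# Hypothetical paraphrase of the driver-side consumer; illustrative only.
def prepare_command(sc, serializer, command):
    # Pickling the command triggers Broadcast.__reduce__ for every broadcast
    # variable the command closes over, which registers it as a side effect.
    pickled_command = serializer.dumps(command)
    # With BroadcastPickleRegistry, both the read and the clear below touch
    # only the calling thread's set, so concurrent threads cannot interfere.
    broadcast_vars = [bc._jbroadcast for bc in sc._pickled_broadcast_vars]
    sc._pickled_broadcast_vars.clear()
    return pickled_command, broadcast_vars
```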
44 changes: 44 additions & 0 deletions python/pyspark/tests.py
@@ -858,6 +858,50 @@ def test_multiple_broadcasts(self):
        self.assertEqual(N, size)
        self.assertEqual(checksum, csum)

    def test_multithread_broadcast_pickle(self):
        import threading

        b1 = self.sc.broadcast(list(range(3)))
        b2 = self.sc.broadcast(list(range(3)))

        def f1():
            return b1.value

        def f2():
            return b2.value

        funcs_num_pickled = {f1: None, f2: None}

        def do_pickle(f, sc):
            command = (f, None, sc.serializer, sc.serializer)
            ser = CloudPickleSerializer()
            ser.dumps(command)

        def process_vars(sc):
            broadcast_vars = list(sc._pickled_broadcast_vars)
            num_pickled = len(broadcast_vars)
            sc._pickled_broadcast_vars.clear()
            return num_pickled

        def run(f, sc):
            do_pickle(f, sc)
            funcs_num_pickled[f] = process_vars(sc)

        # pickling f1 adds b1 to sc._pickled_broadcast_vars in the main thread's local storage
        do_pickle(f1, self.sc)

        # run all steps for f2; this should only add/count/clear b2 in the worker thread's local storage
        t = threading.Thread(target=run, args=(f2, self.sc))
        t.start()
        t.join()

        # count the vars pickled in the main thread; only b1 should be counted and cleared
        funcs_num_pickled[f1] = process_vars(self.sc)

        self.assertEqual(funcs_num_pickled[f1], 1)
        self.assertEqual(funcs_num_pickled[f2], 1)
        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
    def test_large_closure(self):
        N = 200000
        data = [float(i) for i in xrange(N)]
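Finally, the `threading.local` semantics the new test relies on can be checked in plain Python, without Spark. A minimal sketch (all names here are hypothetical):

```python
import threading

class Registry(threading.local):
    def __init__(self):
        # runs on first access from each thread, yielding a fresh per-thread set
        self.__dict__.setdefault("_registry", set())

registry = Registry()
seen = {}

def worker(name, value):
    registry._registry.add(value)
    # each thread observes only its own additions
    seen[name] = set(registry._registry)

registry._registry.add("main")
threads = [threading.Thread(target=worker, args=("t%d" % i, i)) for i in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(seen)                # {'t0': {0}, 't1': {1}}
print(registry._registry)  # {'main'} -- untouched by the worker threads
```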