From b703f8359cee3c65018fe43f45e26c50771866c7 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 20 Jul 2017 14:36:09 -0700
Subject: [PATCH 1/6] Added thread-safe broadcast pickle registry

---
 python/pyspark/broadcast.py | 23 +++++++++++++++++++++++
 python/pyspark/context.py   |  6 +++---
 python/pyspark/rdd.py       | 15 ++++++++-------
 3 files changed, 34 insertions(+), 10 deletions(-)

diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index b1b59f73d671..24b1f6bb41b2 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -139,6 +139,29 @@ def __reduce__(self):
         return _from_id, (self._jbroadcast.id(),)
 
 
+class BroadcastPickleRegistry(object):
+    """ Thread-safe registry for broadcast variables that have been pickled
+    """
+
+    def __init__(self, lock):
+        self._registry = set()
+        self._lock = lock
+
+    @property
+    def lock(self):
+        return self._lock
+
+    def add(self, bcast):
+        with self._lock:
+            self._registry.add(bcast)
+
+    def get_and_clear(self):
+        with self._lock:
+            registry_copy = self._registry.copy()
+            self._registry.clear()
+            return registry_copy
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 80cb48fb8209..4b2c80fed196 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -30,7 +30,7 @@
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
-from pyspark.broadcast import Broadcast
+from pyspark.broadcast import Broadcast, BroadcastPickleRegistry
 from pyspark.conf import SparkConf
 from pyspark.files import SparkFiles
 from pyspark.java_gateway import launch_gateway
@@ -195,7 +195,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_vars = set()
+        self._pickled_broadcast_registry = BroadcastPickleRegistry(self._lock)
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
@@ -793,7 +793,7 @@ def broadcast(self, value):
         object for reading it in distributed functions. The variable will
         be sent to each cluster only once.
""" - return Broadcast(self, value, self._pickled_broadcast_vars) + return Broadcast(self, value, self._pickled_broadcast_registry) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 3325b65f8b60..c9b5da4dc79b 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2370,13 +2370,14 @@ def toLocalIterator(self): def _prepare_for_python_RDD(sc, command): # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() - pickled_command = ser.dumps(command) - if len(pickled_command) > (1 << 20): # 1M - # The broadcast will have same life cycle as created PythonRDD - broadcast = sc.broadcast(pickled_command) - pickled_command = ser.dumps(broadcast) - broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_vars] - sc._pickled_broadcast_vars.clear() + with sc._pickled_broadcast_registry.lock: + pickled_command = ser.dumps(command) + if len(pickled_command) > (1 << 20): # 1M + # The broadcast will have same life cycle as created PythonRDD + broadcast = sc.broadcast(pickled_command) + pickled_command = ser.dumps(broadcast) + pickled_broadcast_vars = sc._pickled_broadcast_registry.get_and_clear() + broadcast_vars = [x._jbroadcast for x in pickled_broadcast_vars] return pickled_command, broadcast_vars, sc.environment, sc._python_includes From 375906d109adac1b0c40a17a0f0702f280741dea Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Jul 2017 13:37:12 -0700 Subject: [PATCH 2/6] changed to use thread local storage --- python/pyspark/broadcast.py | 28 ++++++++++++---------------- python/pyspark/context.py | 2 +- python/pyspark/rdd.py | 15 +++++++-------- 3 files changed, 20 insertions(+), 25 deletions(-) diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py index 24b1f6bb41b2..02fc515fb824 100644 --- a/python/pyspark/broadcast.py +++ b/python/pyspark/broadcast.py @@ -19,6 +19,7 @@ import sys import gc from tempfile import NamedTemporaryFile +import threading from pyspark.cloudpickle import print_exec from pyspark.util import _exception_message @@ -139,27 +140,22 @@ def __reduce__(self): return _from_id, (self._jbroadcast.id(),) -class BroadcastPickleRegistry(object): - """ Thread-safe registry for broadcast variables that have been pickled +class BroadcastPickleRegistry(threading.local): + """ Thread-local registry for broadcast variables that have been pickled """ - def __init__(self, lock): - self._registry = set() - self._lock = lock + def __init__(self): + self.__dict__.setdefault("_registry", set()) - @property - def lock(self): - return self._lock + def __iter__(self): + for bcast in self._registry: + yield bcast def add(self, bcast): - with self._lock: - self._registry.add(bcast) - - def get_and_clear(self): - with self._lock: - registry_copy = self._registry.copy() - self._registry.clear() - return registry_copy + self._registry.add(bcast) + + def clear(self): + self._registry.clear() if __name__ == "__main__": diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 4b2c80fed196..d3f820e11136 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -195,7 +195,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, # This allows other code to determine which Broadcast instances have # been pickled, so it can determine which Java broadcast objects to # send. 
-        self._pickled_broadcast_registry = BroadcastPickleRegistry(self._lock)
+        self._pickled_broadcast_registry = BroadcastPickleRegistry()
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c9b5da4dc79b..d36749235f4c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -2370,14 +2370,13 @@ def toLocalIterator(self):
 def _prepare_for_python_RDD(sc, command):
     # the serialized command will be compressed by broadcast
     ser = CloudPickleSerializer()
-    with sc._pickled_broadcast_registry.lock:
-        pickled_command = ser.dumps(command)
-        if len(pickled_command) > (1 << 20):  # 1M
-            # The broadcast will have same life cycle as created PythonRDD
-            broadcast = sc.broadcast(pickled_command)
-            pickled_command = ser.dumps(broadcast)
-        pickled_broadcast_vars = sc._pickled_broadcast_registry.get_and_clear()
-        broadcast_vars = [x._jbroadcast for x in pickled_broadcast_vars]
+    pickled_command = ser.dumps(command)
+    if len(pickled_command) > (1 << 20):  # 1M
+        # The broadcast will have same life cycle as created PythonRDD
+        broadcast = sc.broadcast(pickled_command)
+        pickled_command = ser.dumps(broadcast)
+    broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_registry]
+    sc._pickled_broadcast_registry.clear()
     return pickled_command, broadcast_vars, sc.environment, sc._python_includes
 
 

From 54e8357c1fa362c41fd4e0dfdc7b8a67a86c3a65 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Wed, 26 Jul 2017 13:39:45 -0700
Subject: [PATCH 3/6] renamed back to original

---
 python/pyspark/context.py | 4 ++--
 python/pyspark/rdd.py     | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d3f820e11136..a7046043e037 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -195,7 +195,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         # This allows other code to determine which Broadcast instances have
         # been pickled, so it can determine which Java broadcast objects to
         # send.
-        self._pickled_broadcast_registry = BroadcastPickleRegistry()
+        self._pickled_broadcast_vars = BroadcastPickleRegistry()
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
@@ -793,7 +793,7 @@ def broadcast(self, value):
         object for reading it in distributed functions. The variable will
         be sent to each cluster only once.
""" - return Broadcast(self, value, self._pickled_broadcast_registry) + return Broadcast(self, value, self._pickled_broadcast_vars) def accumulator(self, value, accum_param=None): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d36749235f4c..3325b65f8b60 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -2375,8 +2375,8 @@ def _prepare_for_python_RDD(sc, command): # The broadcast will have same life cycle as created PythonRDD broadcast = sc.broadcast(pickled_command) pickled_command = ser.dumps(broadcast) - broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_registry] - sc._pickled_broadcast_registry.clear() + broadcast_vars = [x._jbroadcast for x in sc._pickled_broadcast_vars] + sc._pickled_broadcast_vars.clear() return pickled_command, broadcast_vars, sc.environment, sc._python_includes From 0f444f6e357fc000f4316a85b9555ac229cd252e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 31 Jul 2017 16:05:50 -0700 Subject: [PATCH 4/6] added regression test for multithreaded broadcast pickle --- python/pyspark/tests.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 73ab442dfd79..21562ce665b6 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -858,6 +858,47 @@ def test_multiple_broadcasts(self): self.assertEqual(N, size) self.assertEqual(checksum, csum) + def test_multithread_broadcast_pickle(self): + import threading + + b1 = self.sc.broadcast(list(range(3))) + b2 = self.sc.broadcast(list(range(3))) + + def f1(): return b1.value + + def f2(): return b2.value + + funcs_num_pickled = {f1: None, f2: None} + + def do_pickle(f, sc): + command = (f, None, sc.serializer, sc.serializer) + ser = CloudPickleSerializer() + ser.dumps(command) + + def process_vars(sc): + broadcast_vars = [x for x in sc._pickled_broadcast_vars] + num_pickled = len(broadcast_vars) + sc._pickled_broadcast_vars.clear() + return num_pickled + + def run(f, sc): + do_pickle(f, sc) + funcs_num_pickled[f] = process_vars(sc) + + # pickle f1, adds b1 to sc._pickled_broadcast_vars in main thread local storage + do_pickle(f1, self.sc) + + # run all for f2, should only add/count/clear b2 from worker thread local storage + t = threading.Thread(target=run, args=(f2, self.sc)) + t.start() + t.join() + + # count number of vars pickled in main thread, only b1 should be counted and cleared + funcs_num_pickled[f1] = process_vars(self.sc) + + self.assertEqual(funcs_num_pickled[f1], 1) + self.assertEqual(funcs_num_pickled[f2], 1) + def test_large_closure(self): N = 200000 data = [float(i) for i in xrange(N)] From 6710f13fc53b9934b46f209a844f40794374777e Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 31 Jul 2017 16:27:20 -0700 Subject: [PATCH 5/6] fixed python style --- python/pyspark/tests.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 21562ce665b6..6a5e6fa72077 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -864,9 +864,11 @@ def test_multithread_broadcast_pickle(self): b1 = self.sc.broadcast(list(range(3))) b2 = self.sc.broadcast(list(range(3))) - def f1(): return b1.value + def f1(): + return b1.value - def f2(): return b2.value + def f2(): + return b2.value funcs_num_pickled = {f1: None, f2: None} From d4d1fed3979209412070a964fa9951043f5a58bd Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 31 Jul 2017 20:14:41 -0700 Subject: [PATCH 6/6] added check that 

---
 python/pyspark/tests.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 6a5e6fa72077..000dd1eb8e48 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -878,7 +878,7 @@ def do_pickle(f, sc):
             ser.dumps(command)
 
         def process_vars(sc):
-            broadcast_vars = [x for x in sc._pickled_broadcast_vars]
+            broadcast_vars = list(sc._pickled_broadcast_vars)
             num_pickled = len(broadcast_vars)
             sc._pickled_broadcast_vars.clear()
             return num_pickled
@@ -900,6 +900,7 @@ def run(f, sc):
 
         self.assertEqual(funcs_num_pickled[f1], 1)
         self.assertEqual(funcs_num_pickled[f2], 1)
+        self.assertEqual(len(list(self.sc._pickled_broadcast_vars)), 0)
 
     def test_large_closure(self):
         N = 200000
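
The series settles on the `threading.local` approach introduced in PATCH 2/6 and restores the original `_pickled_broadcast_vars` name in PATCH 3/6: every thread that touches the registry gets its own `_registry` set, so a broadcast variable pickled by one thread can no longer be cleared out from under it by `_prepare_for_python_RDD` running concurrently in another thread. The following standalone sketch is not one of the patches; the `__main__` demo and the placeholder strings standing in for `Broadcast` objects are illustrative only. It shows why subclassing `threading.local` and calling `self.__dict__.setdefault("_registry", set())` in `__init__` gives each thread an independent set.

```python
import threading


class BroadcastPickleRegistry(threading.local):
    """Thread-local registry, mirroring the pattern adopted in PATCH 2/6."""

    def __init__(self):
        # threading.local runs __init__ again for each new thread that uses
        # the instance, so every thread starts with its own empty set.
        self.__dict__.setdefault("_registry", set())

    def add(self, bcast):
        self._registry.add(bcast)

    def __iter__(self):
        return iter(self._registry)

    def clear(self):
        self._registry.clear()


if __name__ == "__main__":
    registry = BroadcastPickleRegistry()
    registry.add("b1")  # registered by the main thread

    seen_by_worker = []

    def worker():
        registry.add("b2")              # worker thread sees only its own set
        seen_by_worker.extend(registry)
        registry.clear()                # does not disturb the main thread

    t = threading.Thread(target=worker)
    t.start()
    t.join()

    print(sorted(seen_by_worker))  # ['b2']
    print(sorted(registry))        # ['b1'], still registered in the main thread
```

Under the PATCH 1/6 design, the same isolation required holding a shared lock across the whole pickle-and-collect sequence in `_prepare_for_python_RDD`; the thread-local version needs no lock at all, which is why PATCH 2/6 also removes the `with sc._pickled_broadcast_registry.lock:` block.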