diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 60141792d499..0bfb9b8a9717 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -28,7 +28,7 @@ import socket from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile -from threading import Thread +from threading import Thread, Lock from collections import defaultdict from itertools import chain from functools import reduce @@ -55,6 +55,9 @@ __all__ = ["RDD"] +# Lock which will make sure that dependent broadcast variables are pickled along +# with their PythonRDD wrapped function when using multiple threads (SPARK-12717). +_lock = Lock() def portable_hash(x): """ @@ -2451,10 +2454,12 @@ def _jrdd(self): else: profiler = None - wrapped_func = _wrap_function(self.ctx, self.func, self._prev_jrdd_deserializer, - self._jrdd_deserializer, profiler) - python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, - self.preservesPartitioning) + with _lock: + wrapped_func = _wrap_function(self.ctx, self.func, + self._prev_jrdd_deserializer, self._jrdd_deserializer, profiler) + python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), wrapped_func, + self.preservesPartitioning) + self._jrdd_val = python_rdd.asJavaRDD() if profiler: