diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index c96156b21072..6c6a538808a1 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1111,6 +1111,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
         --------
         >>> import threading
         >>> from time import sleep
+        >>> from pyspark import InheritableThread
         >>> result = "Not Set"
         >>> lock = threading.Lock()
         >>> def map_func(x):
@@ -1128,8 +1129,8 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
         ...     sleep(5)
         ...     sc.cancelJobGroup("job_to_cancel")
         >>> suppress = lock.acquire()
-        >>> suppress = threading.Thread(target=start_job, args=(10,)).start()
-        >>> suppress = threading.Thread(target=stop_job).start()
+        >>> suppress = InheritableThread(target=start_job, args=(10,)).start()
+        >>> suppress = InheritableThread(target=stop_job).start()
         >>> suppress = lock.acquire()
         >>> print(result)
         Cancelled
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 92f4070b822e..79b57d7ed67a 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -23,7 +23,7 @@
 from abc import ABCMeta, abstractmethod, abstractproperty
 from multiprocessing.pool import ThreadPool
 
-from pyspark import keyword_only, since, SparkContext
+from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
 from pyspark.ml import Estimator, Predictor, PredictionModel, Model
 from pyspark.ml.param.shared import HasRawPredictionCol, HasProbabilityCol, HasThresholds, \
     HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, \
@@ -2921,7 +2921,7 @@ def trainSingleClass(index):
 
         pool = ThreadPool(processes=min(self.getParallelism(), numClasses))
 
-        models = pool.map(trainSingleClass, range(numClasses))
+        models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))
 
         if handlePersistence:
             multiclassLabeled.unpersist()
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 74d3642b52a6..04959f73048e 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@
 
 import numpy as np
 
-from pyspark import keyword_only, since, SparkContext
+from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
 from pyspark.ml import Estimator, Transformer, Model
 from pyspark.ml.common import inherit_doc, _py2java, _java2py
 from pyspark.ml.evaluation import Evaluator
@@ -729,7 +729,9 @@ def _fit(self, dataset):
             validation = datasets[i][1].cache()
             train = datasets[i][0].cache()
 
-            tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
+            tasks = map(
+                inheritable_thread_target,
+                _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam))
             for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
                 metrics[j] += (metric / nFolds)
                 if collectSubModelsParam:
@@ -1261,7 +1263,9 @@ def _fit(self, dataset):
         if collectSubModelsParam:
             subModels = [None for i in range(numModels)]
 
-        tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
+        tasks = map(
+            inheritable_thread_target,
+            _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam))
         pool = ThreadPool(processes=min(self.getParallelism(), numModels))
         metrics = [None] * numModels
         for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index af32ba1ab082..ee9aee20fa37 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -287,11 +287,13 @@ def inheritable_thread_target(f):
     -----
     This API is experimental.
 
-    It captures the local properties when you decorate it. Therefore, it is encouraged
-    to decorate it when you want to capture the local properties.
+    It is important to know that it captures the local properties when you decorate it
+    whereas :class:`InheritableThread` captures when the thread is started.
+    Therefore, it is encouraged to decorate it when you want to capture the local
+    properties.
 
     For example, the local properties from the current Spark context is captured
-    when you define a function here:
+    when you define a function here instead of the invocation:
 
     >>> @inheritable_thread_target
     ... def target_func():
@@ -305,35 +307,22 @@ def inheritable_thread_target(f):
     >>> Thread(target=inheritable_thread_target(target_func)).start()  # doctest: +SKIP
     """
     from pyspark import SparkContext
 
-    if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
-        # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
-        sc = SparkContext._active_spark_context
-        # Get local properties from main thread
-        properties = sc._jsc.sc().getLocalProperties().clone()
+    if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+        # NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
+        # copies local properties when the thread starts but `inheritable_thread_target`
+        # copies when the function is wrapped.
+        properties = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
 
         @functools.wraps(f)
-        def wrapped_f(*args, **kwargs):
+        def wrapped(*args, **kwargs):
             try:
                 # Set local properties in child thread.
-                sc._jsc.sc().setLocalProperties(properties)
+                SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
                 return f(*args, **kwargs)
             finally:
-                thread_connection = sc._jvm._gateway_client.thread_connection.connection()
-                if thread_connection is not None:
-                    connections = sc._jvm._gateway_client.deque
-                    # Reuse the lock for Py4J in PySpark
-                    with SparkContext._lock:
-                        for i in range(len(connections)):
-                            if connections[i] is thread_connection:
-                                connections[i].close()
-                                del connections[i]
-                                break
-                        else:
-                            # Just in case the connection was not closed but removed from the
-                            # queue.
-                            thread_connection.close()
-        return wrapped_f
+                InheritableThread._clean_py4j_conn_for_current_thread()
+        return wrapped
     else:
         return f
 
@@ -354,21 +343,67 @@ class InheritableThread(threading.Thread):
 
     .. versionadded:: 3.1.0
 
-
     Notes
     -----
     This API is experimental.
     """
     def __init__(self, target, *args, **kwargs):
-        super(InheritableThread, self).__init__(
-            target=inheritable_thread_target(target), *args, **kwargs
-        )
+        from pyspark import SparkContext
+
+        if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+            def copy_local_properties(*a, **k):
+                # self._props is set before starting the thread to match the behavior with JVM.
+                assert hasattr(self, "_props")
+                SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
+                try:
+                    return target(*a, **k)
+                finally:
+                    InheritableThread._clean_py4j_conn_for_current_thread()
+
+            super(InheritableThread, self).__init__(
+                target=copy_local_properties, *args, **kwargs)
+        else:
+            super(InheritableThread, self).__init__(target=target, *args, **kwargs)
 
+    def start(self, *args, **kwargs):
+        from pyspark import SparkContext
 
-if __name__ == "__main__":
-    import doctest
+        if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
+            # Local property copy should happen in Thread.start to mimic JVM's behavior.
+            self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
+        return super(InheritableThread, self).start(*args, **kwargs)
+
+    @staticmethod
+    def _clean_py4j_conn_for_current_thread():
+        from pyspark import SparkContext
+
+        jvm = SparkContext._jvm
+        thread_connection = jvm._gateway_client.thread_connection.connection()
+        if thread_connection is not None:
+            connections = jvm._gateway_client.deque
+            # Reuse the lock for Py4J in PySpark
+            with SparkContext._lock:
+                for i in range(len(connections)):
+                    if connections[i] is thread_connection:
+                        connections[i].close()
+                        del connections[i]
+                        break
+                else:
+                    # Just in case the connection was not closed but removed from the
+                    # queue.
+                    thread_connection.close()
+
 
+if __name__ == "__main__":
     if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7):
-        (failure_count, test_count) = doctest.testmod()
+        import doctest
+        import pyspark.util
+        from pyspark.context import SparkContext
+
+        globs = pyspark.util.__dict__.copy()
+        globs['sc'] = SparkContext('local[4]', 'PythonTest')
+        (failure_count, test_count) = doctest.testmod(pyspark.util, globs=globs)
+        globs['sc'].stop()
+
        if failure_count:
            sys.exit(-1)
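For illustration only, not part of the patch: a minimal sketch of how the two utilities changed above are expected to behave once the pinned thread mode (PYSPARK_PIN_THREAD) defaults to "true". The local master, app name, and the property key "demo.key" are made-up values for the example.

# Illustrative sketch, not part of the patch. Assumes PYSPARK_PIN_THREAD is unset,
# so the new "true" default from this change applies.
import threading

from pyspark import SparkContext, InheritableThread, inheritable_thread_target

sc = SparkContext("local[2]", "inheritance-demo")   # made-up master/app name
sc.setLocalProperty("demo.key", "parent-value")     # made-up property key


def child():
    # Prints the local property as observed from inside the child thread.
    print(sc.getLocalProperty("demo.key"))


# InheritableThread copies the parent's local properties when start() is called,
# so the child observes "parent-value".
t1 = InheritableThread(target=child)
t1.start()
t1.join()

# inheritable_thread_target captures the properties at wrapping time instead;
# changes made to the parent's properties afterwards are not seen by the wrapped target.
wrapped = inheritable_thread_target(child)
sc.setLocalProperty("demo.key", "changed-later")

t2 = threading.Thread(target=wrapped)
t2.start()
t2.join()   # still prints "parent-value" under the pinned thread mode

sc.stop()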