5 changes: 3 additions & 2 deletions python/pyspark/context.py
@@ -1111,6 +1111,7 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
--------
>>> import threading
>>> from time import sleep
>>> from pyspark import InheritableThread
>>> result = "Not Set"
>>> lock = threading.Lock()
>>> def map_func(x):
@@ -1128,8 +1129,8 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
... sleep(5)
... sc.cancelJobGroup("job_to_cancel")
>>> suppress = lock.acquire()
>>> suppress = threading.Thread(target=start_job, args=(10,)).start()
>>> suppress = threading.Thread(target=stop_job).start()
>>> suppress = InheritableThread(target=start_job, args=(10,)).start()
>>> suppress = InheritableThread(target=stop_job).start()
>>> suppress = lock.acquire()
>>> print(result)
Cancelled
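
For illustration, a condensed sketch of the pattern the updated doctest exercises, assuming an active SparkContext ``sc``: a long-running job tagged with a job group is cancelled from a second thread, and both threads are created as ``InheritableThread`` so that job-group local properties and Py4J connections behave correctly under pinned thread mode.

    from time import sleep
    from pyspark import InheritableThread

    def start_job(n):
        try:
            sc.setJobGroup("job_to_cancel", "cancellable job group")
            sc.parallelize(range(n)).map(lambda x: sleep(100)).collect()
        except Exception:
            print("Cancelled")

    def stop_job():
        sleep(5)
        sc.cancelJobGroup("job_to_cancel")

    InheritableThread(target=start_job, args=(10,)).start()
    InheritableThread(target=stop_job).start()
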
4 changes: 2 additions & 2 deletions python/pyspark/ml/classification.py
@@ -23,7 +23,7 @@
from abc import ABCMeta, abstractmethod, abstractproperty
from multiprocessing.pool import ThreadPool

from pyspark import keyword_only, since, SparkContext
from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
from pyspark.ml import Estimator, Predictor, PredictionModel, Model
from pyspark.ml.param.shared import HasRawPredictionCol, HasProbabilityCol, HasThresholds, \
HasRegParam, HasMaxIter, HasFitIntercept, HasTol, HasStandardization, HasWeightCol, \
@@ -2921,7 +2921,7 @@ def trainSingleClass(index):

pool = ThreadPool(processes=min(self.getParallelism(), numClasses))

models = pool.map(trainSingleClass, range(numClasses))
models = pool.map(inheritable_thread_target(trainSingleClass), range(numClasses))

if handlePersistence:
multiclassLabeled.unpersist()
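
The ``OneVsRest`` change above wraps each per-class training task with ``inheritable_thread_target`` before submitting it to the ``ThreadPool``, so the local properties of the calling thread (for example a job group or scheduler pool) are re-applied inside the pool's worker threads when pinned thread mode is on. A minimal sketch of the same pattern, with the hypothetical ``run_task`` standing in for ``trainSingleClass`` and assuming an active SparkContext ``sc``:

    from multiprocessing.pool import ThreadPool
    from pyspark import inheritable_thread_target

    sc.setLocalProperty("spark.scheduler.pool", "training")  # set in the parent thread

    def run_task(index):
        # Runs in a pool worker thread, but sees the parent's local properties
        # because the function was wrapped while those properties were current.
        return index, sc.getLocalProperty("spark.scheduler.pool")

    pool = ThreadPool(processes=4)
    results = pool.map(inheritable_thread_target(run_task), range(8))
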
10 changes: 7 additions & 3 deletions python/pyspark/ml/tuning.py
@@ -24,7 +24,7 @@

import numpy as np

from pyspark import keyword_only, since, SparkContext
from pyspark import keyword_only, since, SparkContext, inheritable_thread_target
from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.common import inherit_doc, _py2java, _java2py
from pyspark.ml.evaluation import Evaluator
@@ -729,7 +729,9 @@ def _fit(self, dataset):
validation = datasets[i][1].cache()
train = datasets[i][0].cache()

tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
tasks = map(
inheritable_thread_target,
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam))
for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
metrics[j] += (metric / nFolds)
if collectSubModelsParam:
@@ -1261,7 +1263,9 @@ def _fit(self, dataset):
if collectSubModelsParam:
subModels = [None for i in range(numModels)]

tasks = _parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam)
tasks = map(
inheritable_thread_target,
_parallelFitTasks(est, train, eva, validation, epm, collectSubModelsParam))
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
metrics = [None] * numModels
for j, metric, subModel in pool.imap_unordered(lambda f: f(), tasks):
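
Both ``CrossValidator._fit`` and ``TrainValidationSplit._fit`` now apply ``inheritable_thread_target`` to every fit task returned by ``_parallelFitTasks`` before the tasks run on the ``ThreadPool``, mirroring the ``OneVsRest`` change. A rough sketch of the shape of this pattern, with an illustrative ``make_tasks`` in place of the real ``_parallelFitTasks``:

    from multiprocessing.pool import ThreadPool
    from pyspark import inheritable_thread_target

    def make_tasks(n):
        # Illustrative only: each task is a zero-argument callable, like the
        # callables _parallelFitTasks builds for each ParamMap.
        return [lambda i=i: (i, 0.0) for i in range(n)]

    tasks = map(inheritable_thread_target, make_tasks(6))
    pool = ThreadPool(processes=3)
    for index, metric in pool.imap_unordered(lambda f: f(), tasks):
        # Each wrapped task re-applies the captured local properties in the
        # worker thread before running (and cleans up its Py4J connection
        # afterwards in pinned thread mode).
        print(index, metric)
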
99 changes: 67 additions & 32 deletions python/pyspark/util.py
@@ -287,11 +287,13 @@ def inheritable_thread_target(f):
-----
This API is experimental.

It captures the local properties when you decorate it. Therefore, it is encouraged
to decorate it when you want to capture the local properties.
It is important to know that it captures the local properties when you decorate it
whereas :class:`InheritableThread` captures when the thread is started.
Therefore, it is encouraged to decorate it when you want to capture the local
properties.

For example, the local properties from the current Spark context is captured
when you define a function here:
when you define a function here instead of the invocation:

>>> @inheritable_thread_target
... def target_func():
@@ -305,35 +307,22 @@ def inheritable_thread_target(f):
>>> Thread(target=inheritable_thread_target(target_func)).start() # doctest: +SKIP
"""
from pyspark import SparkContext
if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
sc = SparkContext._active_spark_context

# Get local properties from main thread
properties = sc._jsc.sc().getLocalProperties().clone()
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
# NOTICE the internal difference vs `InheritableThread`. `InheritableThread`
# copies local properties when the thread starts but `inheritable_thread_target`
# copies when the function is wrapped.
properties = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()

@functools.wraps(f)
def wrapped_f(*args, **kwargs):
def wrapped(*args, **kwargs):
try:
# Set local properties in child thread.
sc._jsc.sc().setLocalProperties(properties)
SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
return f(*args, **kwargs)
finally:
thread_connection = sc._jvm._gateway_client.thread_connection.connection()
if thread_connection is not None:
connections = sc._jvm._gateway_client.deque
# Reuse the lock for Py4J in PySpark
with SparkContext._lock:
for i in range(len(connections)):
if connections[i] is thread_connection:
connections[i].close()
del connections[i]
break
else:
# Just in case the connection was not closed but removed from the
# queue.
thread_connection.close()
return wrapped_f
InheritableThread._clean_py4j_conn_for_current_thread()
return wrapped
else:
return f
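
As the revised docstring stresses, ``inheritable_thread_target`` clones the current thread's local properties at the moment the function is wrapped, not when the child thread starts. A small sketch of that behaviour, assuming pinned thread mode (PYSPARK_PIN_THREAD=true) and an active SparkContext ``sc``:

    import threading
    from pyspark import inheritable_thread_target

    sc.setLocalProperty("spark.scheduler.pool", "poolA")

    @inheritable_thread_target          # local properties ("poolA") are captured here
    def show_pool():
        print(sc.getLocalProperty("spark.scheduler.pool"))  # prints "poolA"

    sc.setLocalProperty("spark.scheduler.pool", "poolB")    # too late to be seen by show_pool

    threading.Thread(target=show_pool).start()
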

@@ -354,21 +343,67 @@ class InheritableThread(threading.Thread):

.. versionadded:: 3.1.0


Notes
-----
This API is experimental.
"""
def __init__(self, target, *args, **kwargs):
super(InheritableThread, self).__init__(
target=inheritable_thread_target(target), *args, **kwargs
)
from pyspark import SparkContext

if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
def copy_local_properties(*a, **k):
# self._props is set before starting the thread to match the behavior with JVM.
assert hasattr(self, "_props")
SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
try:
return target(*a, **k)
finally:
InheritableThread._clean_py4j_conn_for_current_thread()

super(InheritableThread, self).__init__(
target=copy_local_properties, *args, **kwargs)
else:
super(InheritableThread, self).__init__(target=target, *args, **kwargs)

def start(self, *args, **kwargs):
from pyspark import SparkContext

if __name__ == "__main__":
import doctest
if os.environ.get("PYSPARK_PIN_THREAD", "true").lower() == "true":
# Local property copy should happen in Thread.start to mimic JVM's behavior.
self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
return super(InheritableThread, self).start(*args, **kwargs)

@staticmethod
def _clean_py4j_conn_for_current_thread():
from pyspark import SparkContext

jvm = SparkContext._jvm
thread_connection = jvm._gateway_client.thread_connection.connection()
if thread_connection is not None:
connections = jvm._gateway_client.deque
# Reuse the lock for Py4J in PySpark
with SparkContext._lock:
for i in range(len(connections)):
if connections[i] is thread_connection:
connections[i].close()
del connections[i]
break
else:
# Just in case the connection was not closed but removed from the
# queue.
thread_connection.close()
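
``InheritableThread``, by contrast, defers the copy: ``start()`` clones the parent thread's local properties into ``self._props``, the thread body re-applies them before calling the target, and the per-thread Py4J connection is closed once the target returns. Anything set between constructing the thread and starting it is therefore still inherited. A small sketch under the same assumptions as above:

    from pyspark import InheritableThread

    def show_pool():
        print(sc.getLocalProperty("spark.scheduler.pool"))  # prints "poolB"

    t = InheritableThread(target=show_pool)
    sc.setLocalProperty("spark.scheduler.pool", "poolB")  # set after construction...
    t.start()                                             # ...but copied here, at start()
    t.join()
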


if __name__ == "__main__":
if "pypy" not in platform.python_implementation().lower() and sys.version_info[:2] >= (3, 7):
(failure_count, test_count) = doctest.testmod()
import doctest
import pyspark.util
from pyspark.context import SparkContext

globs = pyspark.util.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(pyspark.util, globs=globs)
globs['sc'].stop()

if failure_count:
sys.exit(-1)