diff --git a/python/pyspark/sql/tests/connect/test_session.py b/python/pyspark/sql/tests/connect/test_session.py index 0482f119d63b..17dd4cefd21c 100644 --- a/python/pyspark/sql/tests/connect/test_session.py +++ b/python/pyspark/sql/tests/connect/test_session.py @@ -19,6 +19,7 @@ import unittest from typing import Optional +from pyspark import InheritableThread, inheritable_thread_target from pyspark.sql.connect.client import ChannelBuilder from pyspark.sql.connect.session import SparkSession as RemoteSparkSession from pyspark.testing.connectutils import should_test_connect @@ -179,3 +180,47 @@ def func(itr): self.assertFalse( is_job_cancelled[i], "Thread {i}: Job in group B did not succeeded.".format(i=i) ) + + def test_inheritable_tags(self): + self.check_inheritable_tags( + create_thread=lambda target, session: InheritableThread(target, session=session) + ) + self.check_inheritable_tags( + create_thread=lambda target, session: threading.Thread( + target=inheritable_thread_target(session)(target) + ) + ) + + # Test decorator usage + @inheritable_thread_target(self.spark) + def func(target): + return target() + + self.check_inheritable_tags( + create_thread=lambda target, session: threading.Thread(target=func, args=(target,)) + ) + + def check_inheritable_tags(self, create_thread): + spark = self.spark + spark.addTag("a") + first = set() + second = set() + + def get_inner_local_prop(): + spark.addTag("c") + second.update(spark.getTags()) + + def get_outer_local_prop(): + spark.addTag("b") + first.update(spark.getTags()) + t2 = create_thread(target=get_inner_local_prop, session=spark) + t2.start() + t2.join() + + t1 = create_thread(target=get_outer_local_prop, session=spark) + t1.start() + t1.join() + + self.assertEqual(spark.getTags(), {"a"}) + self.assertEqual(first, {"a", "b"}) + self.assertEqual(second, {"a", "b", "c"}) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 5232c929e162..87f808549d1a 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -24,8 +24,9 @@ import sys import threading import traceback +import typing from types import TracebackType -from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple +from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple, Union from pyspark.errors import PySparkRuntimeError @@ -35,6 +36,9 @@ from py4j.java_gateway import JavaObject +if typing.TYPE_CHECKING: + from pyspark.sql import SparkSession + def print_exec(stream: TextIO) -> None: ei = sys.exc_info() @@ -279,50 +283,81 @@ def _parse_memory(s: str) -> int: return int(float(s[:-1]) * units[s[-1].lower()]) -def inheritable_thread_target(f: Callable) -> Callable: +def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = None) -> Callable: """ Return thread target wrapper which is recommended to be used in PySpark when the pinned thread mode is enabled. The wrapper function, before calling original thread target, it inherits the inheritable properties specific - to JVM thread such as ``InheritableThreadLocal``. - - Also, note that pinned thread mode does not close the connection from Python - to JVM when the thread is finished in the Python side. With this wrapper, Python - garbage-collects the Python thread instance and also closes the connection - which finishes JVM thread correctly. + to JVM thread such as ``InheritableThreadLocal``, or thread local such as tags + with Spark Connect. When the pinned thread mode is off, it return the original ``f``. .. versionadded:: 3.2.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Parameters ---------- - f : function - the original thread target. + f : function, or :class:`SparkSession` + the original thread target, or :class:`SparkSession` if Spark Connect is being used. + See the examples below. Notes ----- This API is experimental. - It is important to know that it captures the local properties when you decorate it - whereas :class:`InheritableThread` captures when the thread is started. + It is important to know that it captures the local properties or tags when you + decorate it whereas :class:`InheritableThread` captures when the thread is started. Therefore, it is encouraged to decorate it when you want to capture the local properties. - For example, the local properties from the current Spark context is captured - when you define a function here instead of the invocation: + For example, the local properties or tags from the current Spark context or Spark + session is captured when you define a function here instead of the invocation: >>> @inheritable_thread_target ... def target_func(): ... pass # your codes. - If you have any updates on local properties afterwards, it would not be reflected to - the Spark context in ``target_func()``. + If you have any updates on local properties or tags afterwards, it would not be + reflected to the Spark context in ``target_func()``. The example below mimics the behavior of JVM threads as close as possible: >>> Thread(target=inheritable_thread_target(target_func)).start() # doctest: +SKIP + + If you're using Spark Connect, you should explicitly provide Spark session as follows: + + >>> @inheritable_thread_target(session) # doctest: +SKIP + ... def target_func(): + ... pass # your codes. + + >>> Thread(target=inheritable_thread_target(session)(target_func)).start() # doctest: +SKIP """ + from pyspark.sql import is_remote + + # Spark Connect + if is_remote(): + session = f + assert session is not None, "Spark Connect session must be provided." + + def outer(ff: Callable) -> Callable: + if not hasattr(session.client.thread_local, "tags"): # type: ignore[union-attr] + session.client.thread_local.tags = set() # type: ignore[union-attr] + tags = set(session.client.thread_local.tags) # type: ignore[union-attr] + + @functools.wraps(ff) + def inner(*args: Any, **kwargs: Any) -> Any: + # Set tags in child thread. + session.client.thread_local.tags = tags # type: ignore[union-attr] + return ff(*args, **kwargs) + + return inner + + return outer + + # Non Spark Connect from pyspark import SparkContext if isinstance(SparkContext._gateway, ClientServer): @@ -333,35 +368,35 @@ def inheritable_thread_target(f: Callable) -> Callable: # copies when the function is wrapped. assert SparkContext._active_spark_context is not None properties = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone() + assert callable(f) @functools.wraps(f) def wrapped(*args: Any, **kwargs: Any) -> Any: # Set local properties in child thread. assert SparkContext._active_spark_context is not None SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties) - return f(*args, **kwargs) + return f(*args, **kwargs) # type: ignore[misc, operator] return wrapped else: - return f + return f # type: ignore[return-value] class InheritableThread(threading.Thread): """ - Thread that is recommended to be used in PySpark instead of :class:`threading.Thread` - when the pinned thread mode is enabled. The usage of this class is exactly same as - :class:`threading.Thread` but correctly inherits the inheritable properties specific - to JVM thread such as ``InheritableThreadLocal``. - - Also, note that pinned thread mode does not close the connection from Python - to JVM when the thread is finished in the Python side. With this class, Python - garbage-collects the Python thread instance and also closes the connection - which finishes JVM thread correctly. + Thread that is recommended to be used in PySpark when the pinned thread mode is + enabled. The wrapper function, before calling original thread target, it + inherits the inheritable properties specific to JVM thread such as + ``InheritableThreadLocal``, or thread local such as tags + with Spark Connect. When the pinned thread mode is off, this works as :class:`threading.Thread`. .. versionadded:: 3.1.0 + .. versionchanged:: 3.5.0 + Supports Spark Connect. + Notes ----- This API is experimental. @@ -369,35 +404,67 @@ class InheritableThread(threading.Thread): _props: JavaObject - def __init__(self, target: Callable, *args: Any, **kwargs: Any): - from pyspark import SparkContext + def __init__( + self, target: Callable, *args: Any, session: Optional["SparkSession"] = None, **kwargs: Any + ): + from pyspark.sql import is_remote + + # Spark Connect + if is_remote(): + assert session is not None, "Spark Connect must be provided." + self._session = session - if isinstance(SparkContext._gateway, ClientServer): - # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. def copy_local_properties(*a: Any, **k: Any) -> Any: - # self._props is set before starting the thread to match the behavior with JVM. - assert hasattr(self, "_props") - assert SparkContext._active_spark_context is not None - SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props) + # Set tags in child thread. + assert hasattr(self, "_tags") + session.client.thread_local.tags = self._tags # type: ignore[union-attr, has-type] return target(*a, **k) super(InheritableThread, self).__init__( target=copy_local_properties, *args, **kwargs # type: ignore[misc] ) else: - super(InheritableThread, self).__init__( - target=target, *args, **kwargs # type: ignore[misc] - ) + # Non Spark Connect + from pyspark import SparkContext + + if isinstance(SparkContext._gateway, ClientServer): + # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. + def copy_local_properties(*a: Any, **k: Any) -> Any: + # self._props is set before starting the thread to match the behavior with JVM. + assert hasattr(self, "_props") + assert SparkContext._active_spark_context is not None + SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props) + return target(*a, **k) + + super(InheritableThread, self).__init__( + target=copy_local_properties, *args, **kwargs # type: ignore[misc] + ) + else: + super(InheritableThread, self).__init__( + target=target, *args, **kwargs # type: ignore[misc] + ) def start(self) -> None: - from pyspark import SparkContext + from pyspark.sql import is_remote + + if is_remote(): + # Spark Connect + assert hasattr(self, "_session") + if not hasattr(self._session.client.thread_local, "tags"): + self._session.client.thread_local.tags = set() + self._tags = set(self._session.client.thread_local.tags) + else: + # Non Spark Connect + from pyspark import SparkContext - if isinstance(SparkContext._gateway, ClientServer): - # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. + if isinstance(SparkContext._gateway, ClientServer): + # Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on. - # Local property copy should happen in Thread.start to mimic JVM's behavior. - assert SparkContext._active_spark_context is not None - self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone() + # Local property copy should happen in Thread.start to mimic JVM's behavior. + assert SparkContext._active_spark_context is not None + self._props = ( + SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone() + ) return super(InheritableThread, self).start()