Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions python/pyspark/sql/tests/connect/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
from typing import Optional

from pyspark import InheritableThread, inheritable_thread_target
from pyspark.sql.connect.client import ChannelBuilder
from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
from pyspark.testing.connectutils import should_test_connect
Expand Down Expand Up @@ -179,3 +180,47 @@ def func(itr):
self.assertFalse(
is_job_cancelled[i], "Thread {i}: Job in group B did not succeeded.".format(i=i)
)

def test_inheritable_tags(self):
self.check_inheritable_tags(
create_thread=lambda target, session: InheritableThread(target, session=session)
)
self.check_inheritable_tags(
create_thread=lambda target, session: threading.Thread(
target=inheritable_thread_target(session)(target)
)
)

# Test decorator usage
@inheritable_thread_target(self.spark)
def func(target):
return target()

self.check_inheritable_tags(
create_thread=lambda target, session: threading.Thread(target=func, args=(target,))
)

def check_inheritable_tags(self, create_thread):
spark = self.spark
spark.addTag("a")
first = set()
second = set()

def get_inner_local_prop():
spark.addTag("c")
second.update(spark.getTags())

def get_outer_local_prop():
spark.addTag("b")
first.update(spark.getTags())
t2 = create_thread(target=get_inner_local_prop, session=spark)
t2.start()
t2.join()

t1 = create_thread(target=get_outer_local_prop, session=spark)
t1.start()
t1.join()

self.assertEqual(spark.getTags(), {"a"})
self.assertEqual(first, {"a", "b"})
self.assertEqual(second, {"a", "b", "c"})
155 changes: 111 additions & 44 deletions python/pyspark/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
import sys
import threading
import traceback
import typing
from types import TracebackType
from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple
from typing import Any, Callable, Iterator, List, Optional, TextIO, Tuple, Union

from pyspark.errors import PySparkRuntimeError

Expand All @@ -35,6 +36,9 @@

from py4j.java_gateway import JavaObject

if typing.TYPE_CHECKING:
from pyspark.sql import SparkSession


def print_exec(stream: TextIO) -> None:
ei = sys.exc_info()
Expand Down Expand Up @@ -279,50 +283,81 @@ def _parse_memory(s: str) -> int:
return int(float(s[:-1]) * units[s[-1].lower()])


def inheritable_thread_target(f: Callable) -> Callable:
def inheritable_thread_target(f: Optional[Union[Callable, "SparkSession"]] = None) -> Callable:
"""
Return thread target wrapper which is recommended to be used in PySpark when the
pinned thread mode is enabled. The wrapper function, before calling original
thread target, it inherits the inheritable properties specific
to JVM thread such as ``InheritableThreadLocal``.

Also, note that pinned thread mode does not close the connection from Python
to JVM when the thread is finished in the Python side. With this wrapper, Python
garbage-collects the Python thread instance and also closes the connection
which finishes JVM thread correctly.
to JVM thread such as ``InheritableThreadLocal``, or thread local such as tags
with Spark Connect.

When the pinned thread mode is off, it return the original ``f``.

.. versionadded:: 3.2.0

.. versionchanged:: 3.5.0
Supports Spark Connect.

Parameters
----------
f : function
the original thread target.
f : function, or :class:`SparkSession`
the original thread target, or :class:`SparkSession` if Spark Connect is being used.
See the examples below.

Notes
-----
This API is experimental.

It is important to know that it captures the local properties when you decorate it
whereas :class:`InheritableThread` captures when the thread is started.
It is important to know that it captures the local properties or tags when you
decorate it whereas :class:`InheritableThread` captures when the thread is started.
Therefore, it is encouraged to decorate it when you want to capture the local
properties.

For example, the local properties from the current Spark context is captured
when you define a function here instead of the invocation:
For example, the local properties or tags from the current Spark context or Spark
session is captured when you define a function here instead of the invocation:

>>> @inheritable_thread_target
... def target_func():
... pass # your codes.

If you have any updates on local properties afterwards, it would not be reflected to
the Spark context in ``target_func()``.
If you have any updates on local properties or tags afterwards, it would not be
reflected to the Spark context in ``target_func()``.

The example below mimics the behavior of JVM threads as close as possible:

>>> Thread(target=inheritable_thread_target(target_func)).start() # doctest: +SKIP

If you're using Spark Connect, you should explicitly provide Spark session as follows:

>>> @inheritable_thread_target(session) # doctest: +SKIP
... def target_func():
... pass # your codes.

>>> Thread(target=inheritable_thread_target(session)(target_func)).start() # doctest: +SKIP
"""
from pyspark.sql import is_remote

# Spark Connect
if is_remote():
session = f
assert session is not None, "Spark Connect session must be provided."

def outer(ff: Callable) -> Callable:
if not hasattr(session.client.thread_local, "tags"): # type: ignore[union-attr]
session.client.thread_local.tags = set() # type: ignore[union-attr]
tags = set(session.client.thread_local.tags) # type: ignore[union-attr]

@functools.wraps(ff)
def inner(*args: Any, **kwargs: Any) -> Any:
# Set tags in child thread.
session.client.thread_local.tags = tags # type: ignore[union-attr]
return ff(*args, **kwargs)

return inner

return outer

# Non Spark Connect
from pyspark import SparkContext

if isinstance(SparkContext._gateway, ClientServer):
Expand All @@ -333,71 +368,103 @@ def inheritable_thread_target(f: Callable) -> Callable:
# copies when the function is wrapped.
assert SparkContext._active_spark_context is not None
properties = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
assert callable(f)

@functools.wraps(f)
def wrapped(*args: Any, **kwargs: Any) -> Any:
# Set local properties in child thread.
assert SparkContext._active_spark_context is not None
SparkContext._active_spark_context._jsc.sc().setLocalProperties(properties)
return f(*args, **kwargs)
return f(*args, **kwargs) # type: ignore[misc, operator]

return wrapped
else:
return f
return f # type: ignore[return-value]


class InheritableThread(threading.Thread):
"""
Thread that is recommended to be used in PySpark instead of :class:`threading.Thread`
when the pinned thread mode is enabled. The usage of this class is exactly same as
:class:`threading.Thread` but correctly inherits the inheritable properties specific
to JVM thread such as ``InheritableThreadLocal``.

Also, note that pinned thread mode does not close the connection from Python
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is already fixed in Py4J, py4j/py4j#471

to JVM when the thread is finished in the Python side. With this class, Python
garbage-collects the Python thread instance and also closes the connection
which finishes JVM thread correctly.
Thread that is recommended to be used in PySpark when the pinned thread mode is
enabled. The wrapper function, before calling original thread target, it
inherits the inheritable properties specific to JVM thread such as
``InheritableThreadLocal``, or thread local such as tags
with Spark Connect.

When the pinned thread mode is off, this works as :class:`threading.Thread`.

.. versionadded:: 3.1.0

.. versionchanged:: 3.5.0
Supports Spark Connect.

Notes
-----
This API is experimental.
"""

_props: JavaObject

def __init__(self, target: Callable, *args: Any, **kwargs: Any):
from pyspark import SparkContext
def __init__(
self, target: Callable, *args: Any, session: Optional["SparkSession"] = None, **kwargs: Any
):
from pyspark.sql import is_remote

# Spark Connect
if is_remote():
assert session is not None, "Spark Connect must be provided."
self._session = session

if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
def copy_local_properties(*a: Any, **k: Any) -> Any:
# self._props is set before starting the thread to match the behavior with JVM.
assert hasattr(self, "_props")
assert SparkContext._active_spark_context is not None
SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
# Set tags in child thread.
assert hasattr(self, "_tags")
session.client.thread_local.tags = self._tags # type: ignore[union-attr, has-type]
return target(*a, **k)

super(InheritableThread, self).__init__(
target=copy_local_properties, *args, **kwargs # type: ignore[misc]
)
else:
super(InheritableThread, self).__init__(
target=target, *args, **kwargs # type: ignore[misc]
)
# Non Spark Connect
from pyspark import SparkContext

if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
def copy_local_properties(*a: Any, **k: Any) -> Any:
# self._props is set before starting the thread to match the behavior with JVM.
assert hasattr(self, "_props")
assert SparkContext._active_spark_context is not None
SparkContext._active_spark_context._jsc.sc().setLocalProperties(self._props)
return target(*a, **k)

super(InheritableThread, self).__init__(
target=copy_local_properties, *args, **kwargs # type: ignore[misc]
)
else:
super(InheritableThread, self).__init__(
target=target, *args, **kwargs # type: ignore[misc]
)

def start(self) -> None:
from pyspark import SparkContext
from pyspark.sql import is_remote

if is_remote():
# Spark Connect
assert hasattr(self, "_session")
if not hasattr(self._session.client.thread_local, "tags"):
self._session.client.thread_local.tags = set()
self._tags = set(self._session.client.thread_local.tags)
else:
# Non Spark Connect
from pyspark import SparkContext

if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.
if isinstance(SparkContext._gateway, ClientServer):
# Here's when the pinned-thread mode (PYSPARK_PIN_THREAD) is on.

# Local property copy should happen in Thread.start to mimic JVM's behavior.
assert SparkContext._active_spark_context is not None
self._props = SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
# Local property copy should happen in Thread.start to mimic JVM's behavior.
assert SparkContext._active_spark_context is not None
self._props = (
SparkContext._active_spark_context._jsc.sc().getLocalProperties().clone()
)
return super(InheritableThread, self).start()


Expand Down