Skip to content
59 changes: 58 additions & 1 deletion python/pyspark/sql/connect/client/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,9 @@ def __init__(
# cleanup ml cache if possible
atexit.register(self._cleanup_ml_cache)

self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = []
self.global_user_context_extensions_lock = threading.Lock()

@property
def _stub(self) -> grpc_lib.SparkConnectServiceStub:
if self.is_closed:
Expand Down Expand Up @@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]:
"""
return self._builder.token

def _update_request_with_user_context_extensions(
self,
req: Union[
pb2.AnalyzePlanRequest,
pb2.ConfigRequest,
pb2.ExecutePlanRequest,
pb2.FetchErrorDetailsRequest,
pb2.InterruptRequest,
],
) -> None:
with self.global_user_context_extensions_lock:
for _, extension in self.global_user_context_extensions:
req.user_context.extensions.append(extension)
if not hasattr(self.thread_local, "user_context_extensions"):
return
for _, extension in self.thread_local.user_context_extensions:
req.user_context.extensions.append(extension)

def _execute_plan_request_with_metadata(
self, operation_id: Optional[str] = None
) -> pb2.ExecutePlanRequest:
Expand Down Expand Up @@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata(
messageParameters={"arg_name": "operation_id", "origin": str(ve)},
)
req.operation_id = operation_id
self._update_request_with_user_context_extensions(req)
return req

def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
Expand All @@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
return req

def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult:
Expand Down Expand Up @@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest:
req.client_type = self._builder.userAgent
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
return req

def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]:
Expand Down Expand Up @@ -1807,6 +1831,7 @@ def _interrupt_request(
)
if self._user_id:
req.user_context.user_id = self._user_id
self._update_request_with_user_context_extensions(req)
return req

def interrupt_all(self) -> Optional[List[str]]:
Expand Down Expand Up @@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None:
messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag},
)

def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str:
if not hasattr(self.thread_local, "user_context_extensions"):
self.thread_local.user_context_extensions = list()
extension_id = "threadlocal_" + str(uuid.uuid4())
self.thread_local.user_context_extensions.append((extension_id, extension))
return extension_id

def add_global_user_context_extension(self, extension: any_pb2.Any) -> str:
extension_id = "global_" + str(uuid.uuid4())
with self.global_user_context_extensions_lock:
self.global_user_context_extensions.append((extension_id, extension))
return extension_id

def remove_user_context_extension(self, extension_id: str) -> None:
if extension_id.find("threadlocal_") == 0:
if not hasattr(self.thread_local, "user_context_extensions"):
return
self.thread_local.user_context_extensions = list(
filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions)
)
elif extension_id.find("global_") == 0:
with self.global_user_context_extensions_lock:
self.global_user_context_extensions = list(
filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions)
)

def clear_user_context_extensions(self) -> None:
if hasattr(self.thread_local, "user_context_extensions"):
self.thread_local.user_context_extensions = list()
with self.global_user_context_extensions_lock:
self.global_user_context_extensions = list()

def _handle_error(self, error: Exception) -> NoReturn:
"""
Handle errors that occur during RPC calls.
Expand Down Expand Up @@ -1945,7 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet
req.client_observed_server_side_session_id = self._server_session_id
if self._user_id:
req.user_context.user_id = self._user_id

self._update_request_with_user_context_extensions(req)
try:
return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata())
except grpc.RpcError:
Expand Down
104 changes: 104 additions & 0 deletions python/pyspark/sql/tests/connect/client/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
if should_test_connect:
import grpc
import google.protobuf.any_pb2 as any_pb2
import google.protobuf.wrappers_pb2 as wrappers_pb2
from google.rpc import status_pb2
from google.rpc.error_details_pb2 import ErrorInfo
import pandas as pd
Expand Down Expand Up @@ -136,9 +137,11 @@ class MockService:
def __init__(self, session_id: str):
self._session_id = session_id
self.req = None
self.client_user_context_extensions = []

def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):
self.req = req
# self.client_user_context_extensions = req.user_context.extensions
resp = proto.ExecutePlanResponse()
resp.session_id = self._session_id
resp.operation_id = req.operation_id
Expand All @@ -159,12 +162,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata):

def Interrupt(self, req: proto.InterruptRequest, metadata):
self.req = req
# self.client_user_context_extensions = req.user_context.extensions
resp = proto.InterruptResponse()
resp.session_id = self._session_id
return resp

def Config(self, req: proto.ConfigRequest, metadata):
self.req = req
# self.client_user_context_extensions = req.user_context.extensions
resp = proto.ConfigResponse()
resp.session_id = self._session_id
if req.operation.HasField("get"):
Expand All @@ -177,6 +182,15 @@ def Config(self, req: proto.ConfigRequest, metadata):
pair.value = req.operation.get_with_default.pairs[0].value or "true"
return resp

def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata):
self.req = req
# self.client_user_context_extensions = req.user_context.extensions
resp = proto.AnalyzePlanResponse()
resp.session_id = self._session_id
# Return a minimal response with a semantic hash
resp.semantic_hash.result = 12345
return resp

# The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster)
# and it blocks the test process exiting because it is registered as the atexit handler
# in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test.
Expand Down Expand Up @@ -229,6 +243,96 @@ def userId(self) -> Optional[str]:

self.assertEqual(client._user_id, "abc")

def test_user_context_extension(self):
client = SparkConnectClient("sc://foo/", use_reattachable_execute=False)
mock = MockService(client._session_id)
client._stub = mock

try:
exlocal = any_pb2.Any()
exlocal.Pack(wrappers_pb2.StringValue(value="abc"))
exlocal2 = any_pb2.Any()
exlocal2.Pack(wrappers_pb2.StringValue(value="def"))
exglobal = any_pb2.Any()
exglobal.Pack(wrappers_pb2.StringValue(value="ghi"))
exglobal2 = any_pb2.Any()
exglobal2.Pack(wrappers_pb2.StringValue(value="jkl"))

exlocal_id = client.add_threadlocal_user_context_extension(exlocal)
exglobal_id = client.add_global_user_context_extension(exglobal)

mock.client_user_context_extensions = []
command = proto.Command()
client.execute_command(command)
self.assertTrue(exlocal in mock.client_user_context_extensions)
self.assertTrue(exglobal in mock.client_user_context_extensions)
self.assertFalse(exlocal2 in mock.client_user_context_extensions)
self.assertFalse(exglobal2 in mock.client_user_context_extensions)

client.add_threadlocal_user_context_extension(exlocal2)

mock.client_user_context_extensions = []
plan = proto.Plan()
client.semantic_hash(plan) # use semantic_hash to test analyze
self.assertTrue(exlocal in mock.client_user_context_extensions)
self.assertTrue(exglobal in mock.client_user_context_extensions)
self.assertTrue(exlocal2 in mock.client_user_context_extensions)
self.assertFalse(exglobal2 in mock.client_user_context_extensions)

client.add_global_user_context_extension(exglobal2)

mock.client_user_context_extensions = []
client.interrupt_all()
self.assertTrue(exlocal in mock.client_user_context_extensions)
self.assertTrue(exglobal in mock.client_user_context_extensions)
self.assertTrue(exlocal2 in mock.client_user_context_extensions)
self.assertTrue(exglobal2 in mock.client_user_context_extensions)

client.remove_user_context_extension(exlocal_id)

mock.client_user_context_extensions = []
client.get_configs("foo", "bar")
self.assertFalse(exlocal in mock.client_user_context_extensions)
self.assertTrue(exglobal in mock.client_user_context_extensions)
self.assertTrue(exlocal2 in mock.client_user_context_extensions)
self.assertTrue(exglobal2 in mock.client_user_context_extensions)

client.remove_user_context_extension(exglobal_id)

mock.client_user_context_extensions = []
command = proto.Command()
client.execute_command(command)
self.assertFalse(exlocal in mock.client_user_context_extensions)
self.assertFalse(exglobal in mock.client_user_context_extensions)
self.assertTrue(exlocal2 in mock.client_user_context_extensions)
self.assertTrue(exglobal2 in mock.client_user_context_extensions)

client.clear_user_context_extensions()

mock.client_user_context_extensions = []
plan = proto.Plan()
client.semantic_hash(plan) # use semantic_hash to test analyze
self.assertFalse(exlocal in mock.client_user_context_extensions)
self.assertFalse(exglobal in mock.client_user_context_extensions)
self.assertFalse(exlocal2 in mock.client_user_context_extensions)
self.assertFalse(exglobal2 in mock.client_user_context_extensions)

mock.client_user_context_extensions = []
client.interrupt_all()
self.assertFalse(exlocal in mock.client_user_context_extensions)
self.assertFalse(exglobal in mock.client_user_context_extensions)
self.assertFalse(exlocal2 in mock.client_user_context_extensions)
self.assertFalse(exglobal2 in mock.client_user_context_extensions)

mock.client_user_context_extensions = []
client.get_configs("foo", "bar")
self.assertFalse(exlocal in mock.client_user_context_extensions)
self.assertFalse(exglobal in mock.client_user_context_extensions)
self.assertFalse(exlocal2 in mock.client_user_context_extensions)
self.assertFalse(exglobal2 in mock.client_user_context_extensions)
finally:
client.close()

def test_interrupt_all(self):
client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False)
mock = MockService(client._session_id)
Expand Down