diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2a2ac0e6b5399..48e07642e1574 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -727,6 +727,9 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) + self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = [] + self.global_user_context_extensions_lock = threading.Lock() + @property def _stub(self) -> grpc_lib.SparkConnectServiceStub: if self.is_closed: @@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]: """ return self._builder.token + def _update_request_with_user_context_extensions( + self, + req: Union[ + pb2.AnalyzePlanRequest, + pb2.ConfigRequest, + pb2.ExecutePlanRequest, + pb2.FetchErrorDetailsRequest, + pb2.InterruptRequest, + ], + ) -> None: + with self.global_user_context_extensions_lock: + for _, extension in self.global_user_context_extensions: + req.user_context.extensions.append(extension) + if not hasattr(self.thread_local, "user_context_extensions"): + return + for _, extension in self.thread_local.user_context_extensions: + req.user_context.extensions.append(extension) + def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: @@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata( messageParameters={"arg_name": "operation_id", "origin": str(ve)}, ) req.operation_id = operation_id + self._update_request_with_user_context_extensions(req) return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: @@ -1807,6 +1831,7 @@ def _interrupt_request( ) if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def interrupt_all(self) -> Optional[List[str]]: @@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None: messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag}, ) + def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str: + if not hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + extension_id = "threadlocal_" + str(uuid.uuid4()) + self.thread_local.user_context_extensions.append((extension_id, extension)) + return extension_id + + def add_global_user_context_extension(self, extension: any_pb2.Any) -> str: + extension_id = "global_" + str(uuid.uuid4()) + with self.global_user_context_extensions_lock: + self.global_user_context_extensions.append((extension_id, extension)) + return extension_id + + def remove_user_context_extension(self, extension_id: str) -> None: + if extension_id.find("threadlocal_") == 0: + if not hasattr(self.thread_local, "user_context_extensions"): + return + self.thread_local.user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions) + ) + elif extension_id.find("global_") == 0: + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions) + ) + + def clear_user_context_extensions(self) -> None: + if hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list() + def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. @@ -1945,7 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.client_observed_server_side_session_id = self._server_session_id if self._user_id: req.user_context.user_id = self._user_id - + self._update_request_with_user_context_extensions(req) try: return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) except grpc.RpcError: diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index c189f996cbe43..d842d045bc053 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -26,6 +26,7 @@ if should_test_connect: import grpc import google.protobuf.any_pb2 as any_pb2 + import google.protobuf.wrappers_pb2 as wrappers_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd @@ -136,9 +137,11 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None + self.client_user_context_extensions = [] def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req + # self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -159,12 +162,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req + # self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req + # self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -177,6 +182,15 @@ def Config(self, req: proto.ConfigRequest, metadata): pair.value = req.operation.get_with_default.pairs[0].value or "true" return resp + def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): + self.req = req + # self.client_user_context_extensions = req.user_context.extensions + resp = proto.AnalyzePlanResponse() + resp.session_id = self._session_id + # Return a minimal response with a semantic hash + resp.semantic_hash.result = 12345 + return resp + # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster) # and it blocks the test process exiting because it is registered as the atexit handler # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test. @@ -229,6 +243,96 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") + def test_user_context_extension(self): + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + try: + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + finally: + client.close() + def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id)