From 35631f4fc90fcf715396aa030728f42b4083786f Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Wed, 12 Nov 2025 10:54:00 -0800 Subject: [PATCH 01/11] port --- python/pyspark/sql/connect/client/core.py | 70 +++++++++++++- python/pyspark/sql/connect/session.py | 58 ++++++++++++ .../sql/tests/connect/client/test_client.py | 94 +++++++++++++++++++ 3 files changed, 221 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index d0d191dbd7fde..f90f90b08e372 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -147,6 +147,9 @@ def __init__( for key, value in channelOptions: self.setChannelOption(key, value) + self.global_user_context_extensions = [] # EDGE + self.global_user_context_extensions_lock = threading.Lock() # EDGE + def get(self, key: str) -> Any: """ Parameters @@ -225,6 +228,26 @@ def token(self) -> Optional[str]: ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN") ) + # BEGIN-EDGE + def _update_request_with_user_context_extensions( + self, + req: Union[ + pb2.AnalyzePlanRequest, + pb2.ConfigRequest, + pb2.ExecutePlanRequest, + pb2.FetchErrorDetailsRequest, + pb2.InterruptRequest, + ], + ) -> None: + with self.global_user_context_extensions_lock: + for _, extension in self.global_user_context_extensions: + req.user_context.extensions.append(extension) + if not hasattr(self.thread_local, "user_context_extensions"): + return + for _, extension in self.thread_local.user_context_extensions: + req.user_context.extensions.append(extension) + # END-EDGE + def metadata(self) -> Iterable[Tuple[str, str]]: """ Builds the GRPC specific metadata list to be injected into the request. All @@ -1270,6 +1293,9 @@ def _execute_plan_request_with_metadata( messageParameters={"arg_name": "operation_id", "origin": str(ve)}, ) req.operation_id = operation_id + # BEGIN-EDGE + self._update_request_with_user_context_extensions(req) + # END-EDGE return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1280,6 +1306,9 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + # BEGIN-EDGE + self._update_request_with_user_context_extensions(req) + # END-EDGE return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1694,6 +1723,9 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + # BEGIN-EDGE + self._update_request_with_user_context_extensions(req) + # END-EDGE return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: @@ -1770,6 +1802,7 @@ def _interrupt_request( ) if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def interrupt_all(self) -> Optional[List[str]]: @@ -1868,6 +1901,39 @@ def _throw_if_invalid_tag(self, tag: str) -> None: messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag}, ) + # BEGIN-EDGE + def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str: + if not hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + extension_id = "threadlocal_" + str(uuid.uuid4()) + self.thread_local.user_context_extensions.append((extension_id, extension)) + return extension_id + + def 
add_global_user_context_extension(self, extension: any_pb2.Any) -> str: + extension_id = "global_" + str(uuid.uuid4()) + with self.global_user_context_extensions_lock: + self.global_user_context_extensions.append((extension_id, extension)) + return extension_id + + def remove_user_context_extension(self, extension_id: str) -> None: + if extension_id.find("threadlocal_") == 0: + if not hasattr(self.thread_local, "user_context_extensions"): + return + self.thread_local.user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions) + ) + elif extension_id.find("global_") == 0: + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions) + ) + + def clear_user_context_extensions(self) -> None: + if hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list() + def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. @@ -1908,7 +1974,9 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.client_observed_server_side_session_id = self._server_session_id if self._user_id: req.user_context.user_id = self._user_id - + # BEGIN-EDGE + self._update_request_with_user_context_extensions(req) + # END-EDGE try: return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) except grpc.RpcError: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index f759137fac1d2..ee5a51888424b 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -42,6 +42,7 @@ ClassVar, ) +import google.protobuf.any_pb2 as any_pb2 # EDGE import numpy as np import pandas as pd import pyarrow as pa @@ -894,6 +895,63 @@ def clearTags(self) -> None: clearTags.__doc__ = PySparkSession.clearTags.__doc__ + # BEGIN-EDGE + def addThreadlocalUserContextExtension(self, extension: any_pb2.Any) -> str: + """ + Add a user context extension to the current session in the current thread. + It will be sent in the UserContext of every request sent from the current thread, until + it is removed with removeUserContextExtension using the returned id. + + Parameters + ---------- + extension: any_pb2.Any + Protobuf Any message to add as the extension to UserContext. + + Returns + ------- + str + Id that can be used with removeUserContextExtension to remove the extension. + """ + return self.client.add_threadlocal_user_context_extension(extension) + + def addGlobalUserContextExtension(self, extension: any_pb2.Any) -> str: + """ + Add a user context extension to the current session, globally. + It will be sent in the UserContext of every request, until it is removed with + removeUserContextExtension using the returned id. It will precede any threadlocal extension. + + Parameters + ---------- + extension: any_pb2.Any + Protobuf Any message to add as the extension to UserContext. + + Returns + ------- + str + Id that can be used with removeUserContextExtension to remove the extension. + """ + return self.client.add_global_user_context_extension(extension) + + def removeUserContextExtension(self, extension_id: str) -> None: + """ + Remove a user context extension previously added by addUserContextExtension. 
+ + Parameters + ---------- + extension_id: str + id returned by addUserContextExtension. + """ + self.client.remove_user_context_extension(extension_id) + + def clearUserContextExtensions(self) -> None: + """ + Clear all user context extensions previously added by addGlobalUserContextExtension and + addThreadlocalUserContextExtension + """ + self.client.clear_user_context_extensions() + + # END-EDGE + def stop(self) -> None: """ Release the current session and close the GRPC connection to the Spark Connect server. diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index c189f996cbe43..4bf9bf59537fe 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -136,9 +136,11 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None + self.client_user_context_extensions = [] # EDGE def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions # EDGE resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -159,12 +161,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions # EDGE resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions # EDGE resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -229,6 +233,96 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") + # BEGIN-EDGE + def test_user_context_extension(self): + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal 
in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + # END-EDGE + def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) From c9a07a8e639b08c7854911c453f49e09ddaa7cb9 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Wed, 12 Nov 2025 10:57:56 -0800 Subject: [PATCH 02/11] nit --- python/pyspark/sql/connect/client/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index f90f90b08e372..4a6ed4c8ba039 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -1934,6 +1934,8 @@ def clear_user_context_extensions(self) -> None: with self.global_user_context_extensions_lock: self.global_user_context_extensions = list() + # END-EDGE + def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. 
From e51b655fbd6e89379b3341acb15575d328d3278d Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Wed, 12 Nov 2025 11:45:37 -0800 Subject: [PATCH 03/11] move fun to client --- python/pyspark/sql/connect/client/core.py | 41 ++++++++++++----------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 4a6ed4c8ba039..5ddf11dc64f8b 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -228,26 +228,6 @@ def token(self) -> Optional[str]: ChannelBuilder.PARAM_TOKEN, os.environ.get("SPARK_CONNECT_AUTHENTICATE_TOKEN") ) - # BEGIN-EDGE - def _update_request_with_user_context_extensions( - self, - req: Union[ - pb2.AnalyzePlanRequest, - pb2.ConfigRequest, - pb2.ExecutePlanRequest, - pb2.FetchErrorDetailsRequest, - pb2.InterruptRequest, - ], - ) -> None: - with self.global_user_context_extensions_lock: - for _, extension in self.global_user_context_extensions: - req.user_context.extensions.append(extension) - if not hasattr(self.thread_local, "user_context_extensions"): - return - for _, extension in self.thread_local.user_context_extensions: - req.user_context.extensions.append(extension) - # END-EDGE - def metadata(self) -> Iterable[Tuple[str, str]]: """ Builds the GRPC specific metadata list to be injected into the request. All @@ -1263,6 +1243,27 @@ def token(self) -> Optional[str]: """ return self._builder.token + # BEGIN-EDGE + def _update_request_with_user_context_extensions( + self, + req: Union[ + pb2.AnalyzePlanRequest, + pb2.ConfigRequest, + pb2.ExecutePlanRequest, + pb2.FetchErrorDetailsRequest, + pb2.InterruptRequest, + ], + ) -> None: + with self.global_user_context_extensions_lock: + for _, extension in self.global_user_context_extensions: + req.user_context.extensions.append(extension) + if not hasattr(self.thread_local, "user_context_extensions"): + return + for _, extension in self.thread_local.user_context_extensions: + req.user_context.extensions.append(extension) + + # END-EDGE + def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: From 30abf519c59f5440dccbbb75159f2028eb78c55a Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Wed, 12 Nov 2025 11:52:30 -0800 Subject: [PATCH 04/11] move fun to client --- python/pyspark/sql/connect/client/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 5ddf11dc64f8b..c72a478d43871 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -147,9 +147,6 @@ def __init__( for key, value in channelOptions: self.setChannelOption(key, value) - self.global_user_context_extensions = [] # EDGE - self.global_user_context_extensions_lock = threading.Lock() # EDGE - def get(self, key: str) -> Any: """ Parameters @@ -712,6 +709,9 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) + self.global_user_context_extensions = [] # EDGE + self.global_user_context_extensions_lock = threading.Lock() # EDGE + @property def _stub(self) -> grpc_lib.SparkConnectServiceStub: if self.is_closed: From 56903b8be4247fe3221d11cad4e1eb35d6ac9734 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Wed, 12 Nov 2025 13:24:30 -0800 Subject: [PATCH 05/11] remove unnecessary keyword --- python/pyspark/sql/connect/client/core.py | 18 ++---------------- python/pyspark/sql/connect/session.py 
| 5 +---- .../sql/tests/connect/client/test_client.py | 11 ++++------- 3 files changed, 7 insertions(+), 27 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index c72a478d43871..58c2bd7bd8d01 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -709,8 +709,8 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) - self.global_user_context_extensions = [] # EDGE - self.global_user_context_extensions_lock = threading.Lock() # EDGE + self.global_user_context_extensions = [] + self.global_user_context_extensions_lock = threading.Lock() @property def _stub(self) -> grpc_lib.SparkConnectServiceStub: @@ -1243,7 +1243,6 @@ def token(self) -> Optional[str]: """ return self._builder.token - # BEGIN-EDGE def _update_request_with_user_context_extensions( self, req: Union[ @@ -1262,8 +1261,6 @@ def _update_request_with_user_context_extensions( for _, extension in self.thread_local.user_context_extensions: req.user_context.extensions.append(extension) - # END-EDGE - def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: @@ -1294,9 +1291,7 @@ def _execute_plan_request_with_metadata( messageParameters={"arg_name": "operation_id", "origin": str(ve)}, ) req.operation_id = operation_id - # BEGIN-EDGE self._update_request_with_user_context_extensions(req) - # END-EDGE return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1307,9 +1302,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id - # BEGIN-EDGE self._update_request_with_user_context_extensions(req) - # END-EDGE return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1724,9 +1717,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id - # BEGIN-EDGE self._update_request_with_user_context_extensions(req) - # END-EDGE return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: @@ -1902,7 +1893,6 @@ def _throw_if_invalid_tag(self, tag: str) -> None: messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag}, ) - # BEGIN-EDGE def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str: if not hasattr(self.thread_local, "user_context_extensions"): self.thread_local.user_context_extensions = list() @@ -1935,8 +1925,6 @@ def clear_user_context_extensions(self) -> None: with self.global_user_context_extensions_lock: self.global_user_context_extensions = list() - # END-EDGE - def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. 
@@ -1977,9 +1965,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.client_observed_server_side_session_id = self._server_session_id if self._user_id: req.user_context.user_id = self._user_id - # BEGIN-EDGE self._update_request_with_user_context_extensions(req) - # END-EDGE try: return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) except grpc.RpcError: diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index ee5a51888424b..6be7c2a7e1ec8 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -42,7 +42,7 @@ ClassVar, ) -import google.protobuf.any_pb2 as any_pb2 # EDGE +import google.protobuf.any_pb2 as any_pb2 import numpy as np import pandas as pd import pyarrow as pa @@ -895,7 +895,6 @@ def clearTags(self) -> None: clearTags.__doc__ = PySparkSession.clearTags.__doc__ - # BEGIN-EDGE def addThreadlocalUserContextExtension(self, extension: any_pb2.Any) -> str: """ Add a user context extension to the current session in the current thread. @@ -950,8 +949,6 @@ def clearUserContextExtensions(self) -> None: """ self.client.clear_user_context_extensions() - # END-EDGE - def stop(self) -> None: """ Release the current session and close the GRPC connection to the Spark Connect server. diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 4bf9bf59537fe..2006db03b95a2 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -136,11 +136,11 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None - self.client_user_context_extensions = [] # EDGE + self.client_user_context_extensions = [] def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions # EDGE + self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -161,14 +161,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions # EDGE + self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions # EDGE + self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -233,7 +233,6 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") - # BEGIN-EDGE def test_user_context_extension(self): client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) mock = MockService(client._session_id) @@ -321,8 +320,6 @@ def test_user_context_extension(self): self.assertFalse(exlocal2 in mock.client_user_context_extensions) self.assertFalse(exglobal2 in mock.client_user_context_extensions) - # END-EDGE - def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) From c615e24437ad2df07dc25704ca9b756f06b42f2a Mon Sep 17 00:00:00 
2001 From: Jessie Luo Date: Wed, 12 Nov 2025 14:22:29 -0800 Subject: [PATCH 06/11] nit: update wording --- python/pyspark/sql/connect/session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 6be7c2a7e1ec8..4a59b9fc98cc6 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -933,12 +933,12 @@ def addGlobalUserContextExtension(self, extension: any_pb2.Any) -> str: def removeUserContextExtension(self, extension_id: str) -> None: """ - Remove a user context extension previously added by addUserContextExtension. + Remove a user context extension previously added by addThreadlocalUserContextExtension. Parameters ---------- extension_id: str - id returned by addUserContextExtension. + id returned by addThreadlocalUserContextExtension. """ self.client.remove_user_context_extension(extension_id) From 32c28f213aa0bae4c095314d1918c6bd2212b940 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 09:37:21 -0800 Subject: [PATCH 07/11] remove extra api --- python/pyspark/sql/connect/session.py | 55 --------------------------- 1 file changed, 55 deletions(-) diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py index 4a59b9fc98cc6..f759137fac1d2 100644 --- a/python/pyspark/sql/connect/session.py +++ b/python/pyspark/sql/connect/session.py @@ -42,7 +42,6 @@ ClassVar, ) -import google.protobuf.any_pb2 as any_pb2 import numpy as np import pandas as pd import pyarrow as pa @@ -895,60 +894,6 @@ def clearTags(self) -> None: clearTags.__doc__ = PySparkSession.clearTags.__doc__ - def addThreadlocalUserContextExtension(self, extension: any_pb2.Any) -> str: - """ - Add a user context extension to the current session in the current thread. - It will be sent in the UserContext of every request sent from the current thread, until - it is removed with removeUserContextExtension using the returned id. - - Parameters - ---------- - extension: any_pb2.Any - Protobuf Any message to add as the extension to UserContext. - - Returns - ------- - str - Id that can be used with removeUserContextExtension to remove the extension. - """ - return self.client.add_threadlocal_user_context_extension(extension) - - def addGlobalUserContextExtension(self, extension: any_pb2.Any) -> str: - """ - Add a user context extension to the current session, globally. - It will be sent in the UserContext of every request, until it is removed with - removeUserContextExtension using the returned id. It will precede any threadlocal extension. - - Parameters - ---------- - extension: any_pb2.Any - Protobuf Any message to add as the extension to UserContext. - - Returns - ------- - str - Id that can be used with removeUserContextExtension to remove the extension. - """ - return self.client.add_global_user_context_extension(extension) - - def removeUserContextExtension(self, extension_id: str) -> None: - """ - Remove a user context extension previously added by addThreadlocalUserContextExtension. - - Parameters - ---------- - extension_id: str - id returned by addThreadlocalUserContextExtension. 
- """ - self.client.remove_user_context_extension(extension_id) - - def clearUserContextExtensions(self) -> None: - """ - Clear all user context extensions previously added by addGlobalUserContextExtension and - addThreadlocalUserContextExtension - """ - self.client.clear_user_context_extensions() - def stop(self) -> None: """ Release the current session and close the GRPC connection to the Spark Connect server. From bd8b1ed5fd03fd78bf54046820cdfd73c065479f Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 10:42:57 -0800 Subject: [PATCH 08/11] nit --- python/pyspark/sql/tests/connect/client/test_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 2006db03b95a2..91accf17e721e 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -20,6 +20,8 @@ from collections.abc import Generator from typing import Optional, Any, Union +import google.protobuf.wrappers_pb2 as wrappers_pb2 + from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import eventually From 203e56a8adee8634bb94cd6a4e0ca38762e5bf5e Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 13:05:25 -0800 Subject: [PATCH 09/11] fix failed checks --- python/pyspark/sql/connect/client/core.py | 2 +- .../pyspark/sql/tests/connect/client/test_client.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 3d706e3187b00..48e07642e1574 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -727,7 +727,7 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) - self.global_user_context_extensions = [] + self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = [] self.global_user_context_extensions_lock = threading.Lock() @property diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 91accf17e721e..189553bee75ef 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -20,14 +20,13 @@ from collections.abc import Generator from typing import Optional, Any, Union -import google.protobuf.wrappers_pb2 as wrappers_pb2 - from pyspark.testing.connectutils import should_test_connect, connect_requirement_message from pyspark.testing.utils import eventually if should_test_connect: import grpc import google.protobuf.any_pb2 as any_pb2 + import google.protobuf.wrappers_pb2 as wrappers_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd @@ -183,6 +182,15 @@ def Config(self, req: proto.ConfigRequest, metadata): pair.value = req.operation.get_with_default.pairs[0].value or "true" return resp + def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): + self.req = req + self.client_user_context_extensions = req.user_context.extensions + resp = proto.AnalyzePlanResponse() + resp.session_id = self._session_id + # Return a minimal response with a semantic hash + resp.semantic_hash.result = 12345 + return resp + # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster) # and it blocks the test process exiting because it is 
registered as the atexit handler # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test. From 60d99493b200d3e36239f9976bd10cdc32768430 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Thu, 13 Nov 2025 19:37:22 -0800 Subject: [PATCH 10/11] fix memory leak --- .../sql/tests/connect/client/test_client.py | 165 +++++++++--------- 1 file changed, 84 insertions(+), 81 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 189553bee75ef..bf7108c94b090 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -248,87 +248,90 @@ def test_user_context_extension(self): mock = MockService(client._session_id) client._stub = mock - exlocal = any_pb2.Any() - exlocal.Pack(wrappers_pb2.StringValue(value="abc")) - exlocal2 = any_pb2.Any() - exlocal2.Pack(wrappers_pb2.StringValue(value="def")) - exglobal = any_pb2.Any() - exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) - exglobal2 = any_pb2.Any() - exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) - - exlocal_id = client.add_threadlocal_user_context_extension(exlocal) - exglobal_id = client.add_global_user_context_extension(exglobal) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_threadlocal_user_context_extension(exlocal2) - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_global_user_context_extension(exglobal2) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exlocal_id) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exglobal_id) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.clear_user_context_extensions() - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertFalse(exlocal 
in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) + try: + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + 
mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + finally: + client.close() def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) From 0d0384bd29d1f3577dda055caa01c4518763a257 Mon Sep 17 00:00:00 2001 From: Jessie Luo Date: Fri, 14 Nov 2025 10:42:05 -0800 Subject: [PATCH 11/11] try rm self.client_user_context_extensions --- python/pyspark/sql/tests/connect/client/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index bf7108c94b090..d842d045bc053 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -141,7 +141,7 @@ def __init__(self, session_id: str): def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions + # self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -162,14 +162,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions + # self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions + # self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -184,7 +184,7 @@ def Config(self, req: proto.ConfigRequest, metadata): def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions + # self.client_user_context_extensions = req.user_context.extensions resp = proto.AnalyzePlanResponse() resp.session_id = self._session_id # Return a minimal response with a semantic hash