From e515b85afdd2aadc371612f8e0ca0fff02e58768 Mon Sep 17 00:00:00 2001
From: Yuheng Chang
Date: Wed, 12 Nov 2025 19:13:07 -0800
Subject: [PATCH 01/18] done

---
 .../resources/error/error-conditions.json | 6 +
 .../add_pipeline_analysis_context.py | 48 ++++
 .../pyspark/pipelines/block_connect_access.py | 28 ++-
 python/pyspark/pipelines/cli.py | 17 +-
 .../spark_connect_graph_element_registry.py | 9 +-
 python/pyspark/sql/connect/client/core.py | 59 ++++-
 .../sql/connect/proto/pipelines_pb2.py | 8 +-
 .../sql/connect/proto/pipelines_pb2.pyi | 16 ++
 .../sql/tests/connect/client/test_client.py | 91 +++++++
 .../protobuf/spark/connect/pipelines.proto | 2 +
 .../connect/pipelines/PipelinesHandler.scala | 47 +++-
 .../connect/planner/SparkConnectPlanner.scala | 38 ++-
 .../pipelines/PythonPipelineSuite.scala | 224 +++++++++++++++++-
 13 files changed, 564 insertions(+), 29 deletions(-)
 create mode 100644 python/pyspark/pipelines/add_pipeline_analysis_context.py

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 8f18b6b40f419..da3e946d8a421 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -6261,6 +6261,12 @@
     },
     "sqlState" : "0A000"
   },
+  "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND": {
+    "message" : [
+      "'<command>' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline."
+    ],
+    "sqlState" : "0A000"
+  },
   "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" : {
     "message" : [
       "The char/varchar type can't be used in the table schema.",
diff --git a/python/pyspark/pipelines/add_pipeline_analysis_context.py b/python/pyspark/pipelines/add_pipeline_analysis_context.py
new file mode 100644
index 0000000000000..00b251660e2a7
--- /dev/null
+++ b/python/pyspark/pipelines/add_pipeline_analysis_context.py
@@ -0,0 +1,48 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, Optional
+from pyspark.sql import SparkSession
+
+from typing import Any, cast
+
+
+@contextmanager
+def add_pipeline_analysis_context(
+    spark: SparkSession, dataflow_graph_id: str, flow_name_opt: Optional[str]
+) -> Generator[None, None, None]:
+    """
+    Context manager that adds a PipelineAnalysisContext extension to the user context,
+    used for pipeline-specific analysis.
+ """ + _extension_id = None + _client = cast(Any, spark).client + try: + import pyspark.sql.connect.proto as pb2 + from google.protobuf import any_pb2 + + _analysis_context = pb2.PipelineAnalysisContext(dataflow_graph_id=dataflow_graph_id) + if flow_name_opt is not None: + _analysis_context.flow_name = flow_name_opt + + _extension = any_pb2.Any() + _extension.Pack(_analysis_context) + + _extension_id = _client.add_threadlocal_user_context_extension(_extension) + yield + finally: + _client.remove_user_context_extension(_extension_id) diff --git a/python/pyspark/pipelines/block_connect_access.py b/python/pyspark/pipelines/block_connect_access.py index c5dacbbc2c5cb..28cfbb051d811 100644 --- a/python/pyspark/pipelines/block_connect_access.py +++ b/python/pyspark/pipelines/block_connect_access.py @@ -24,6 +24,22 @@ BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"] +def _is_sql_command_request(request: object) -> bool: + """Check if the request is spark.sql() command (ExecutePlanRequest with a sql_command).""" + try: + if not hasattr(request, "plan"): + return False + + plan = request.plan + + if not plan.HasField("command"): + return False + + return plan.command.HasField("sql_command") + except Exception: + return False + + @contextmanager def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]: """ @@ -41,7 +57,17 @@ def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable: if name not in BLOCKED_RPC_NAMES: return original_getattr(self, name) - def blocked_method(*args: object, **kwargs: object) -> NoReturn: + # Get the original method first + original_method = original_getattr(self, name) + + def blocked_method(*args: object, **kwargs: object): + # allowlist spark.sql() command (ExecutePlan with sql_command) + if name == "ExecutePlan" and len(args) > 0: + request = args[0] + if _is_sql_command_request(request): + return original_method(*args, **kwargs) + + # Block all other ExecutePlan and AnalyzePlan calls raise PySparkException( errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", messageParameters={}, diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index ca198f1c3aff3..6994d072a50e2 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -49,6 +49,8 @@ handle_pipeline_events, ) +from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context + PIPELINE_SPEC_FILE_NAMES = ["pipeline.yaml", "pipeline.yml"] @@ -216,7 +218,11 @@ def validate_str_dict(d: Mapping[str, str], field_name: str) -> Mapping[str, str def register_definitions( - spec_path: Path, registry: GraphElementRegistry, spec: PipelineSpec + spec_path: Path, + registry: GraphElementRegistry, + spec: PipelineSpec, + spark: SparkSession, + dataflow_graph_id: str, ) -> None: """Register the graph element definitions in the pipeline spec with the given registry. - Looks for Python files matching the glob patterns in the spec and imports them. 
@@ -245,8 +251,11 @@ def register_definitions( assert ( module_spec.loader is not None ), f"Module spec has no loader for {file}" - with block_session_mutations(): - module_spec.loader.exec_module(module) + with add_pipeline_analysis_context( + spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name_opt=None + ): + with block_session_mutations(): + module_spec.loader.exec_module(module) elif file.suffix == ".sql": log_with_curr_timestamp(f"Registering SQL file {file}...") with file.open("r") as f: @@ -324,7 +333,7 @@ def run( log_with_curr_timestamp("Registering graph elements...") registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) - register_definitions(spec_path, registry, spec) + register_definitions(spec_path, registry, spec, spark, dataflow_graph_id) log_with_curr_timestamp("Starting run...") result_iter = start_run( diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index e8a8561c3e749..79d54bb3bf776 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -35,6 +35,7 @@ from pyspark.sql.types import StructType from typing import Any, cast import pyspark.sql.connect.proto as pb2 +from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context class SparkConnectGraphElementRegistry(GraphElementRegistry): @@ -43,6 +44,7 @@ class SparkConnectGraphElementRegistry(GraphElementRegistry): def __init__(self, spark: SparkSession, dataflow_graph_id: str) -> None: # Cast because mypy seems to think `spark`` is a function, not an object. Likely related to # SPARK-47544. + self._spark = spark self._client = cast(Any, spark).client self._dataflow_graph_id = dataflow_graph_id @@ -110,8 +112,11 @@ def register_output(self, output: Output) -> None: self._client.execute_command(command) def register_flow(self, flow: Flow) -> None: - with block_spark_connect_execution_and_analysis(): - df = flow.func() + with add_pipeline_analysis_context( + spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name_opt=flow.name + ): + with block_spark_connect_execution_and_analysis(): + df = flow.func() relation = cast(ConnectDataFrame, df)._plan.plan(self._client) relation_flow_details = pb2.PipelineCommand.DefineFlow.WriteRelationFlowDetails( diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 2a2ac0e6b5399..3d706e3187b00 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -727,6 +727,9 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) + self.global_user_context_extensions = [] + self.global_user_context_extensions_lock = threading.Lock() + @property def _stub(self) -> grpc_lib.SparkConnectServiceStub: if self.is_closed: @@ -1277,6 +1280,24 @@ def token(self) -> Optional[str]: """ return self._builder.token + def _update_request_with_user_context_extensions( + self, + req: Union[ + pb2.AnalyzePlanRequest, + pb2.ConfigRequest, + pb2.ExecutePlanRequest, + pb2.FetchErrorDetailsRequest, + pb2.InterruptRequest, + ], + ) -> None: + with self.global_user_context_extensions_lock: + for _, extension in self.global_user_context_extensions: + req.user_context.extensions.append(extension) + if not hasattr(self.thread_local, "user_context_extensions"): + return + for _, extension in self.thread_local.user_context_extensions: + 
req.user_context.extensions.append(extension) + def _execute_plan_request_with_metadata( self, operation_id: Optional[str] = None ) -> pb2.ExecutePlanRequest: @@ -1307,6 +1328,7 @@ def _execute_plan_request_with_metadata( messageParameters={"arg_name": "operation_id", "origin": str(ve)}, ) req.operation_id = operation_id + self._update_request_with_user_context_extensions(req) return req def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: @@ -1317,6 +1339,7 @@ def _analyze_plan_request_with_metadata(self) -> pb2.AnalyzePlanRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def _analyze(self, method: str, **kwargs: Any) -> AnalyzeResult: @@ -1731,6 +1754,7 @@ def _config_request_with_metadata(self) -> pb2.ConfigRequest: req.client_type = self._builder.userAgent if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def get_configs(self, *keys: str) -> Tuple[Optional[str], ...]: @@ -1807,6 +1831,7 @@ def _interrupt_request( ) if self._user_id: req.user_context.user_id = self._user_id + self._update_request_with_user_context_extensions(req) return req def interrupt_all(self) -> Optional[List[str]]: @@ -1905,6 +1930,38 @@ def _throw_if_invalid_tag(self, tag: str) -> None: messageParameters={"arg_name": "Spark Connect tag", "arg_value": tag}, ) + def add_threadlocal_user_context_extension(self, extension: any_pb2.Any) -> str: + if not hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + extension_id = "threadlocal_" + str(uuid.uuid4()) + self.thread_local.user_context_extensions.append((extension_id, extension)) + return extension_id + + def add_global_user_context_extension(self, extension: any_pb2.Any) -> str: + extension_id = "global_" + str(uuid.uuid4()) + with self.global_user_context_extensions_lock: + self.global_user_context_extensions.append((extension_id, extension)) + return extension_id + + def remove_user_context_extension(self, extension_id: str) -> None: + if extension_id.find("threadlocal_") == 0: + if not hasattr(self.thread_local, "user_context_extensions"): + return + self.thread_local.user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.thread_local.user_context_extensions) + ) + elif extension_id.find("global_") == 0: + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list( + filter(lambda ex: ex[0] != extension_id, self.global_user_context_extensions) + ) + + def clear_user_context_extensions(self) -> None: + if hasattr(self.thread_local, "user_context_extensions"): + self.thread_local.user_context_extensions = list() + with self.global_user_context_extensions_lock: + self.global_user_context_extensions = list() + def _handle_error(self, error: Exception) -> NoReturn: """ Handle errors that occur during RPC calls. 
@@ -1945,7 +2002,7 @@ def _fetch_enriched_error(self, info: "ErrorInfo") -> Optional[pb2.FetchErrorDet req.client_observed_server_side_session_id = self._server_session_id if self._user_id: req.user_context.user_id = self._user_id - + self._update_request_with_user_context_extensions(req) try: return self._stub.FetchErrorDetails(req, metadata=self._builder.metadata()) except grpc.RpcError: diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.py b/python/pyspark/sql/connect/proto/pipelines_pb2.py index 0eb77c84b5b57..7a30def861d29 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.py +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.py @@ -42,7 +42,7 @@ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xed"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutputH\x00R\x0c\x64\x65\x66ineOutput\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x92\n\n\x0c\x44\x65\x66ineOutput\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12$\n\x0boutput_name\x18\x02 \x01(\tH\x02R\noutputName\x88\x01\x01\x12?\n\x0boutput_type\x18\x03 \x01(\x0e\x32\x19.spark.connect.OutputTypeH\x03R\noutputType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12X\n\x14source_code_location\x18\x05 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12_\n\rtable_details\x18\x06 
\x01(\x0b\x32\x38.spark.connect.PipelineCommand.DefineOutput.TableDetailsH\x00R\x0ctableDetails\x12\\\n\x0csink_details\x18\x07 \x01(\x0b\x32\x37.spark.connect.PipelineCommand.DefineOutput.SinkDetailsH\x00R\x0bsinkDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xc0\x03\n\x0cTableDetails\x12x\n\x10table_properties\x18\x01 \x03(\x0b\x32M.spark.connect.PipelineCommand.DefineOutput.TableDetails.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x02 \x03(\tR\rpartitionCols\x12\x1b\n\x06\x66ormat\x18\x03 \x01(\tH\x01R\x06\x66ormat\x88\x01\x01\x12\x43\n\x10schema_data_type\x18\x04 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x05 \x01(\tH\x00R\x0cschemaString\x12-\n\x12\x63lustering_columns\x18\x06 \x03(\tR\x11\x63lusteringColumns\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\t\n\x07_format\x1a\xd1\x01\n\x0bSinkDetails\x12^\n\x07options\x18\x01 \x03(\x0b\x32\x44.spark.connect.PipelineCommand.DefineOutput.SinkDetails.OptionsEntryR\x07options\x12\x1b\n\x06\x66ormat\x18\x02 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0e\n\x0c_output_nameB\x0e\n\x0c_output_typeB\n\n\x08_commentB\x17\n\x15_source_code_location\x1a\xff\x06\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x03R\x11targetDatasetName\x88\x01\x01\x12Q\n\x08sql_conf\x18\x04 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x05 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x06 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12x\n\x15relation_flow_details\x18\x07 \x01(\x0b\x32\x42.spark.connect.PipelineCommand.DefineFlow.WriteRelationFlowDetailsH\x00R\x13relationFlowDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x17\n\x04once\x18\x08 \x01(\x08H\x06R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x61\n\x18WriteRelationFlowDetails\x12\x38\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x08relation\x88\x01\x01\x42\x0b\n\t_relation\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0c\n\n_client_idB\x17\n\x15_source_code_locationB\x07\n\x05_once\x1a\xc2\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x12\x1d\n\x07storage\x18\x06 
\x01(\tH\x03R\x07storage\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dryB\n\n\x08_storage\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 \x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf0\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12k\n\x14\x64\x65\x66ine_output_result\x18\x02 \x01(\x0b\x32\x37.spark.connect.PipelineCommandResult.DefineOutputResultH\x00R\x12\x64\x65\x66ineOutputResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x85\x01\n\x12\x44\x65\x66ineOutputResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"\xf1\x01\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x03 \x01(\tH\x02R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x0c\n\n_file_nameB\x0e\n\x0c_line_numberB\x12\n\x10_definition_path"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames"\xd7\x01\n\x17PipelineAnalysisContext\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x02 \x01(\tH\x01R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 
\x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x14\n\x12_dataflow_graph_idB\x12\n\x10_definition_path*i\n\nOutputType\x12\x1b\n\x17OUTPUT_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x12\x08\n\x04SINK\x10\x04\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' + b'\n\x1dspark/connect/pipelines.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xed"\n\x0fPipelineCommand\x12h\n\x15\x63reate_dataflow_graph\x18\x01 \x01(\x0b\x32\x32.spark.connect.PipelineCommand.CreateDataflowGraphH\x00R\x13\x63reateDataflowGraph\x12R\n\rdefine_output\x18\x02 \x01(\x0b\x32+.spark.connect.PipelineCommand.DefineOutputH\x00R\x0c\x64\x65\x66ineOutput\x12L\n\x0b\x64\x65\x66ine_flow\x18\x03 \x01(\x0b\x32).spark.connect.PipelineCommand.DefineFlowH\x00R\ndefineFlow\x12\x62\n\x13\x64rop_dataflow_graph\x18\x04 \x01(\x0b\x32\x30.spark.connect.PipelineCommand.DropDataflowGraphH\x00R\x11\x64ropDataflowGraph\x12\x46\n\tstart_run\x18\x05 \x01(\x0b\x32\'.spark.connect.PipelineCommand.StartRunH\x00R\x08startRun\x12r\n\x19\x64\x65\x66ine_sql_graph_elements\x18\x06 \x01(\x0b\x32\x35.spark.connect.PipelineCommand.DefineSqlGraphElementsH\x00R\x16\x64\x65\x66ineSqlGraphElements\x12\xa1\x01\n*get_query_function_execution_signal_stream\x18\x07 \x01(\x0b\x32\x44.spark.connect.PipelineCommand.GetQueryFunctionExecutionSignalStreamH\x00R%getQueryFunctionExecutionSignalStream\x12\x88\x01\n!define_flow_query_function_result\x18\x08 \x01(\x0b\x32<.spark.connect.PipelineCommand.DefineFlowQueryFunctionResultH\x00R\x1d\x64\x65\x66ineFlowQueryFunctionResult\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xb4\x02\n\x13\x43reateDataflowGraph\x12,\n\x0f\x64\x65\x66\x61ult_catalog\x18\x01 \x01(\tH\x00R\x0e\x64\x65\x66\x61ultCatalog\x88\x01\x01\x12.\n\x10\x64\x65\x66\x61ult_database\x18\x02 \x01(\tH\x01R\x0f\x64\x65\x66\x61ultDatabase\x88\x01\x01\x12Z\n\x08sql_conf\x18\x05 \x03(\x0b\x32?.spark.connect.PipelineCommand.CreateDataflowGraph.SqlConfEntryR\x07sqlConf\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x12\n\x10_default_catalogB\x13\n\x11_default_database\x1aZ\n\x11\x44ropDataflowGraph\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x92\n\n\x0c\x44\x65\x66ineOutput\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12$\n\x0boutput_name\x18\x02 \x01(\tH\x02R\noutputName\x88\x01\x01\x12?\n\x0boutput_type\x18\x03 \x01(\x0e\x32\x19.spark.connect.OutputTypeH\x03R\noutputType\x88\x01\x01\x12\x1d\n\x07\x63omment\x18\x04 \x01(\tH\x04R\x07\x63omment\x88\x01\x01\x12X\n\x14source_code_location\x18\x05 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12_\n\rtable_details\x18\x06 \x01(\x0b\x32\x38.spark.connect.PipelineCommand.DefineOutput.TableDetailsH\x00R\x0ctableDetails\x12\\\n\x0csink_details\x18\x07 \x01(\x0b\x32\x37.spark.connect.PipelineCommand.DefineOutput.SinkDetailsH\x00R\x0bsinkDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x1a\xc0\x03\n\x0cTableDetails\x12x\n\x10table_properties\x18\x01 
\x03(\x0b\x32M.spark.connect.PipelineCommand.DefineOutput.TableDetails.TablePropertiesEntryR\x0ftableProperties\x12%\n\x0epartition_cols\x18\x02 \x03(\tR\rpartitionCols\x12\x1b\n\x06\x66ormat\x18\x03 \x01(\tH\x01R\x06\x66ormat\x88\x01\x01\x12\x43\n\x10schema_data_type\x18\x04 \x01(\x0b\x32\x17.spark.connect.DataTypeH\x00R\x0eschemaDataType\x12%\n\rschema_string\x18\x05 \x01(\tH\x00R\x0cschemaString\x12-\n\x12\x63lustering_columns\x18\x06 \x03(\tR\x11\x63lusteringColumns\x1a\x42\n\x14TablePropertiesEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\x08\n\x06schemaB\t\n\x07_format\x1a\xd1\x01\n\x0bSinkDetails\x12^\n\x07options\x18\x01 \x03(\x0b\x32\x44.spark.connect.PipelineCommand.DefineOutput.SinkDetails.OptionsEntryR\x07options\x12\x1b\n\x06\x66ormat\x18\x02 \x01(\tH\x00R\x06\x66ormat\x88\x01\x01\x1a:\n\x0cOptionsEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x42\t\n\x07_formatB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0e\n\x0c_output_nameB\x0e\n\x0c_output_typeB\n\n\x08_commentB\x17\n\x15_source_code_location\x1a\xff\x06\n\nDefineFlow\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tflow_name\x18\x02 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\x13target_dataset_name\x18\x03 \x01(\tH\x03R\x11targetDatasetName\x88\x01\x01\x12Q\n\x08sql_conf\x18\x04 \x03(\x0b\x32\x36.spark.connect.PipelineCommand.DefineFlow.SqlConfEntryR\x07sqlConf\x12 \n\tclient_id\x18\x05 \x01(\tH\x04R\x08\x63lientId\x88\x01\x01\x12X\n\x14source_code_location\x18\x06 \x01(\x0b\x32!.spark.connect.SourceCodeLocationH\x05R\x12sourceCodeLocation\x88\x01\x01\x12x\n\x15relation_flow_details\x18\x07 \x01(\x0b\x32\x42.spark.connect.PipelineCommand.DefineFlow.WriteRelationFlowDetailsH\x00R\x13relationFlowDetails\x12\x35\n\textension\x18\xe7\x07 \x01(\x0b\x32\x14.google.protobuf.AnyH\x00R\textension\x12\x17\n\x04once\x18\x08 \x01(\x08H\x06R\x04once\x88\x01\x01\x1a:\n\x0cSqlConfEntry\x12\x10\n\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n\x05value\x18\x02 \x01(\tR\x05value:\x02\x38\x01\x1a\x61\n\x18WriteRelationFlowDetails\x12\x38\n\x08relation\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x08relation\x88\x01\x01\x42\x0b\n\t_relation\x1a:\n\x08Response\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x42\x0c\n\n_flow_nameB\t\n\x07\x64\x65tailsB\x14\n\x12_dataflow_graph_idB\x0c\n\n_flow_nameB\x16\n\x14_target_dataset_nameB\x0c\n\n_client_idB\x17\n\x15_source_code_locationB\x07\n\x05_once\x1a\xc2\x02\n\x08StartRun\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x34\n\x16\x66ull_refresh_selection\x18\x02 \x03(\tR\x14\x66ullRefreshSelection\x12-\n\x10\x66ull_refresh_all\x18\x03 \x01(\x08H\x01R\x0e\x66ullRefreshAll\x88\x01\x01\x12+\n\x11refresh_selection\x18\x04 \x03(\tR\x10refreshSelection\x12\x15\n\x03\x64ry\x18\x05 \x01(\x08H\x02R\x03\x64ry\x88\x01\x01\x12\x1d\n\x07storage\x18\x06 \x01(\tH\x03R\x07storage\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x13\n\x11_full_refresh_allB\x06\n\x04_dryB\n\n\x08_storage\x1a\xc7\x01\n\x16\x44\x65\x66ineSqlGraphElements\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\'\n\rsql_file_path\x18\x02 \x01(\tH\x01R\x0bsqlFilePath\x88\x01\x01\x12\x1e\n\x08sql_text\x18\x03 
\x01(\tH\x02R\x07sqlText\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x10\n\x0e_sql_file_pathB\x0b\n\t_sql_text\x1a\x9e\x01\n%GetQueryFunctionExecutionSignalStream\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12 \n\tclient_id\x18\x02 \x01(\tH\x01R\x08\x63lientId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_idB\x0c\n\n_client_id\x1a\xdd\x01\n\x1d\x44\x65\x66ineFlowQueryFunctionResult\x12 \n\tflow_name\x18\x01 \x01(\tH\x00R\x08\x66lowName\x88\x01\x01\x12/\n\x11\x64\x61taflow_graph_id\x18\x02 \x01(\tH\x01R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12\x38\n\x08relation\x18\x03 \x01(\x0b\x32\x17.spark.connect.RelationH\x02R\x08relation\x88\x01\x01\x42\x0c\n\n_flow_nameB\x14\n\x12_dataflow_graph_idB\x0b\n\t_relationB\x0e\n\x0c\x63ommand_type"\xf0\x05\n\x15PipelineCommandResult\x12\x81\x01\n\x1c\x63reate_dataflow_graph_result\x18\x01 \x01(\x0b\x32>.spark.connect.PipelineCommandResult.CreateDataflowGraphResultH\x00R\x19\x63reateDataflowGraphResult\x12k\n\x14\x64\x65\x66ine_output_result\x18\x02 \x01(\x0b\x32\x37.spark.connect.PipelineCommandResult.DefineOutputResultH\x00R\x12\x64\x65\x66ineOutputResult\x12\x65\n\x12\x64\x65\x66ine_flow_result\x18\x03 \x01(\x0b\x32\x35.spark.connect.PipelineCommandResult.DefineFlowResultH\x00R\x10\x64\x65\x66ineFlowResult\x1a\x62\n\x19\x43reateDataflowGraphResult\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x42\x14\n\x12_dataflow_graph_id\x1a\x85\x01\n\x12\x44\x65\x66ineOutputResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifier\x1a\x83\x01\n\x10\x44\x65\x66ineFlowResult\x12W\n\x13resolved_identifier\x18\x01 \x01(\x0b\x32!.spark.connect.ResolvedIdentifierH\x00R\x12resolvedIdentifier\x88\x01\x01\x42\x16\n\x14_resolved_identifierB\r\n\x0bresult_type"I\n\x13PipelineEventResult\x12\x32\n\x05\x65vent\x18\x01 \x01(\x0b\x32\x1c.spark.connect.PipelineEventR\x05\x65vent"t\n\rPipelineEvent\x12\x38\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.TimestampR\ttimestamp\x12\x1d\n\x07message\x18\x02 \x01(\tH\x00R\x07message\x88\x01\x01\x42\n\n\x08_message"\xf1\x01\n\x12SourceCodeLocation\x12 \n\tfile_name\x18\x01 \x01(\tH\x00R\x08\x66ileName\x88\x01\x01\x12$\n\x0bline_number\x18\x02 \x01(\x05H\x01R\nlineNumber\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x03 \x01(\tH\x02R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x0c\n\n_file_nameB\x0e\n\x0c_line_numberB\x12\n\x10_definition_path"E\n$PipelineQueryFunctionExecutionSignal\x12\x1d\n\nflow_names\x18\x01 \x03(\tR\tflowNames"\x87\x02\n\x17PipelineAnalysisContext\x12/\n\x11\x64\x61taflow_graph_id\x18\x01 \x01(\tH\x00R\x0f\x64\x61taflowGraphId\x88\x01\x01\x12,\n\x0f\x64\x65\x66inition_path\x18\x02 \x01(\tH\x01R\x0e\x64\x65\x66initionPath\x88\x01\x01\x12 \n\tflow_name\x18\x03 \x01(\tH\x02R\x08\x66lowName\x88\x01\x01\x12\x33\n\textension\x18\xe7\x07 \x03(\x0b\x32\x14.google.protobuf.AnyR\textensionB\x14\n\x12_dataflow_graph_idB\x12\n\x10_definition_pathB\x0c\n\n_flow_name*i\n\nOutputType\x12\x1b\n\x17OUTPUT_TYPE_UNSPECIFIED\x10\x00\x12\x15\n\x11MATERIALIZED_VIEW\x10\x01\x12\t\n\x05TABLE\x10\x02\x12\x12\n\x0eTEMPORARY_VIEW\x10\x03\x12\x08\n\x04SINK\x10\x04\x42\x36\n\x1eorg.apache.spark.connect.protoP\x01Z\x12internal/generatedb\x06proto3' ) _globals = globals() @@ -69,8 +69,8 @@ ]._serialized_options = b"8\001" 
_globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._loaded_options = None _globals["_PIPELINECOMMAND_DEFINEFLOW_SQLCONFENTRY"]._serialized_options = b"8\001" - _globals["_OUTPUTTYPE"]._serialized_start = 6139 - _globals["_OUTPUTTYPE"]._serialized_end = 6244 + _globals["_OUTPUTTYPE"]._serialized_start = 6187 + _globals["_OUTPUTTYPE"]._serialized_end = 6292 _globals["_PIPELINECOMMAND"]._serialized_start = 195 _globals["_PIPELINECOMMAND"]._serialized_end = 4656 _globals["_PIPELINECOMMAND_CREATEDATAFLOWGRAPH"]._serialized_start = 1129 @@ -126,5 +126,5 @@ _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_start = 5850 _globals["_PIPELINEQUERYFUNCTIONEXECUTIONSIGNAL"]._serialized_end = 5919 _globals["_PIPELINEANALYSISCONTEXT"]._serialized_start = 5922 - _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6137 + _globals["_PIPELINEANALYSISCONTEXT"]._serialized_end = 6185 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi index e0768a1f6baeb..39a1e29ae7dde 100644 --- a/python/pyspark/sql/connect/proto/pipelines_pb2.pyi +++ b/python/pyspark/sql/connect/proto/pipelines_pb2.pyi @@ -1499,11 +1499,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message): DATAFLOW_GRAPH_ID_FIELD_NUMBER: builtins.int DEFINITION_PATH_FIELD_NUMBER: builtins.int + FLOW_NAME_FIELD_NUMBER: builtins.int EXTENSION_FIELD_NUMBER: builtins.int dataflow_graph_id: builtins.str """Unique identifier of the dataflow graph associated with this pipeline.""" definition_path: builtins.str """The path of the top-level pipeline file determined at runtime during pipeline initialization.""" + flow_name: builtins.str + """The name of the Flow involved in this analysis""" @property def extension( self, @@ -1516,6 +1519,7 @@ class PipelineAnalysisContext(google.protobuf.message.Message): *, dataflow_graph_id: builtins.str | None = ..., definition_path: builtins.str | None = ..., + flow_name: builtins.str | None = ..., extension: collections.abc.Iterable[google.protobuf.any_pb2.Any] | None = ..., ) -> None: ... def HasField( @@ -1525,10 +1529,14 @@ class PipelineAnalysisContext(google.protobuf.message.Message): b"_dataflow_graph_id", "_definition_path", b"_definition_path", + "_flow_name", + b"_flow_name", "dataflow_graph_id", b"dataflow_graph_id", "definition_path", b"definition_path", + "flow_name", + b"flow_name", ], ) -> builtins.bool: ... def ClearField( @@ -1538,12 +1546,16 @@ class PipelineAnalysisContext(google.protobuf.message.Message): b"_dataflow_graph_id", "_definition_path", b"_definition_path", + "_flow_name", + b"_flow_name", "dataflow_graph_id", b"dataflow_graph_id", "definition_path", b"definition_path", "extension", b"extension", + "flow_name", + b"flow_name", ], ) -> None: ... @typing.overload @@ -1554,5 +1566,9 @@ class PipelineAnalysisContext(google.protobuf.message.Message): def WhichOneof( self, oneof_group: typing_extensions.Literal["_definition_path", b"_definition_path"] ) -> typing_extensions.Literal["definition_path"] | None: ... + @typing.overload + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_flow_name", b"_flow_name"] + ) -> typing_extensions.Literal["flow_name"] | None: ... 
global___PipelineAnalysisContext = PipelineAnalysisContext diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index c189f996cbe43..2006db03b95a2 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -136,9 +136,11 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None + self.client_user_context_extensions = [] def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -159,12 +161,14 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req + self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -229,6 +233,93 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") + def test_user_context_extension(self): + client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) + mock = MockService(client._session_id) + client._stub = mock + + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in 
mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) diff --git a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto index a92e24fda9154..0874c2d10ec5c 100644 --- a/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto +++ b/sql/connect/common/src/main/protobuf/spark/connect/pipelines.proto @@ -299,6 +299,8 @@ message PipelineAnalysisContext { optional string dataflow_graph_id = 1; // The path of the top-level pipeline file determined at runtime during pipeline initialization. optional string definition_path = 2; + // The name of the Flow involved in this analysis + optional string flow_name = 3; // Reserved field for protocol extensions. 
 repeated google.protobuf.Any extension = 999;
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
index 1a3b0d2231c62..356fd4f473172 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.connect.pipelines
 
+import scala.collection.Seq
 import scala.jdk.CollectionConverters._
 import scala.util.Using
 
@@ -27,9 +28,10 @@ import org.apache.spark.connect.proto.{ExecutePlanResponse, PipelineCommandResul
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{Command, CreateNamespace, CreateTable, CreateTableAsSelect, CreateView, DescribeRelation, DropView, InsertIntoStatement, LogicalPlan, RenameTable, ShowColumns, ShowCreateTable, ShowFunctions, ShowTableProperties, ShowTables, ShowViews}
 import org.apache.spark.sql.connect.common.DataTypeProtoConverter
 import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.execution.command.{ShowCatalogsCommand, ShowNamespacesCommand}
 import org.apache.spark.sql.pipelines.Language.Python
 import org.apache.spark.sql.pipelines.common.RunState.{CANCELED, FAILED}
 import org.apache.spark.sql.pipelines.graph.{AllTables, FlowAnalysis, GraphIdentifierManager, GraphRegistrationContext, IdentifierHelper, NoTables, PipelineUpdateContextImpl, QueryContext, QueryOrigin, QueryOriginType, Sink, SinkImpl, SomeTables, SqlGraphRegistrationContext, Table, TableFilter, TemporaryView, UnresolvedFlow}
@@ -129,6 +131,46 @@ private[connect] object PipelinesHandler extends Logging {
     }
   }
 
+  /**
+   * Block unsupported SQL commands that are not explicitly allowlisted.
+   */
+  def blockUnsupportedSqlCommand(queryPlan: LogicalPlan): Unit = {
+    val supportedCommand = Set(
+      classOf[DescribeRelation],
+      classOf[ShowTables],
+      classOf[ShowTableProperties],
+      classOf[ShowNamespacesCommand],
+      classOf[ShowColumns],
+      classOf[ShowFunctions],
+      classOf[ShowViews],
+      classOf[ShowCatalogsCommand],
+      classOf[ShowCreateTable])
+    val isSqlCommandExplicitlyAllowlisted = {
+      supportedCommand.exists(c => queryPlan.getClass.getName.equals(c.getName))
+    }
+    val isUnsupportedSqlPlan = if (isSqlCommandExplicitlyAllowlisted) {
+      false
+    } else {
+      // If the SQL command is not explicitly allowlisted, check whether it belongs to
+      // one of the commands that pipelines explicitly disallow.
+      // If not, the SQL command is supported.
+ queryPlan.isInstanceOf[Command] || + queryPlan.isInstanceOf[CreateTableAsSelect] || + queryPlan.isInstanceOf[CreateTable] || + queryPlan.isInstanceOf[CreateView] || + queryPlan.isInstanceOf[InsertIntoStatement] || + queryPlan.isInstanceOf[RenameTable] || + queryPlan.isInstanceOf[CreateNamespace] || + queryPlan.isInstanceOf[DropView] + } + // scalastyle:on + if (isUnsupportedSqlPlan) { + throw new AnalysisException( + "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND", + Map("command" -> queryPlan.getClass.getSimpleName)) + } + } + private def createDataflowGraph( cmd: proto.PipelineCommand.CreateDataflowGraph, sessionHolder: SessionHolder): String = { @@ -148,6 +190,9 @@ private[connect] object PipelinesHandler extends Logging { val defaultSqlConf = cmd.getSqlConfMap.asScala.toMap + sessionHolder.session.catalog.setCurrentCatalog(defaultCatalog) + sessionHolder.session.catalog.setCurrentDatabase(defaultDatabase) + sessionHolder.dataflowGraphRegistry.createDataflowGraph( defaultCatalog = defaultCatalog, defaultDatabase = defaultDatabase, diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 8bc33c41b3a30..644784fa3db6c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -21,11 +21,12 @@ import java.util.{HashMap, Properties, UUID} import scala.collection.mutable import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag import scala.util.Try import scala.util.control.NonFatal import com.google.common.collect.Lists -import com.google.protobuf.{Any => ProtoAny, ByteString} +import com.google.protobuf.{Any => ProtoAny, ByteString, Message} import io.grpc.{Context, Status, StatusRuntimeException} import io.grpc.stub.StreamObserver @@ -33,7 +34,7 @@ import org.apache.spark.{SparkClassNotFoundException, SparkEnv, SparkException, import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction} import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult} +import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, PipelineAnalysisContext, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult} import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult import org.apache.spark.connect.proto.Parse.ParseFormat import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance @@ -2941,10 +2942,28 @@ class SparkConnectPlanner( .build()) } + private def getExtensionList[T <: Message: ClassTag]( + extensions: mutable.Buffer[ProtoAny]): Seq[T] = { + val cls = implicitly[ClassTag[T]].runtimeClass + .asInstanceOf[Class[_ <: Message]] + extensions.collect { + case any if any.is(cls) => any.unpack(cls).asInstanceOf[T] + 
}.toSeq + } + private def handleSqlCommand( command: SqlCommand, responseObserver: StreamObserver[ExecutePlanResponse]): Unit = { val tracker = executeHolder.eventsManager.createQueryPlanningTracker() + val userContextExtensions = executeHolder.request.getUserContext.getExtensionsList.asScala + val pipelineAnalysisContextList = { + getExtensionList[PipelineAnalysisContext](userContextExtensions) + } + val hasPipelineAnalysisContext = pipelineAnalysisContextList.nonEmpty + val insidePipelineFlowFunction = pipelineAnalysisContextList.exists(_.hasFlowName) + // To avoid explicit handling of the result on the client, we build the expected input + // of the relation on the server. The client has to simply forward the result. + val result = SqlCommandResult.newBuilder() val relation = if (command.hasInput) { command.getInput @@ -2964,6 +2983,18 @@ class SparkConnectPlanner( .build() } + // Block unsupported SQL commands if the request comes from Spark Declarative Pipelines. + if (hasPipelineAnalysisContext) { + PipelinesHandler.blockUnsupportedSqlCommand(queryPlan = transformRelation(relation)) + } + + // If the spark.sql() is called inside a pipeline flow function, we don't need to execute + // the SQL command and defer the actual analysis and execution to the flow function. + if (insidePipelineFlowFunction) { + result.setRelation(relation) + return + } + val df = relation.getRelTypeCase match { case proto.Relation.RelTypeCase.SQL => executeSQL(relation.getSql, tracker) @@ -2982,9 +3013,6 @@ class SparkConnectPlanner( case _ => Seq.empty } - // To avoid explicit handling of the result on the client, we build the expected input - // of the relation on the server. The client has to simply forward the result. - val result = SqlCommandResult.newBuilder() // Only filled when isCommand val metrics = ExecutePlanResponse.Metrics.newBuilder() if (isCommand || isSqlScript) { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 1a72d112aa2ef..4999a62c26d0d 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -50,7 +50,7 @@ class PythonPipelineSuite def buildGraph(pythonText: String): DataflowGraph = { assume(PythonTestDepsChecker.isConnectDepsAvailable) - val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") + val indentedPythonText = pythonText.linesIterator.map(" " + _).mkString("\n") // create a unique identifier to allow identifying the session and dataflow graph val customSessionIdentifier = UUID.randomUUID().toString val pythonCode = @@ -64,6 +64,9 @@ class PythonPipelineSuite |from pyspark.pipelines.graph_element_registry import ( | graph_element_registration_context, |) + |from pyspark.pipelines.add_pipeline_analysis_context import ( + | add_pipeline_analysis_context + |) | |spark = SparkSession.builder \\ | .remote("sc://localhost:$serverPort") \\ @@ -79,7 +82,10 @@ class PythonPipelineSuite |) | |registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) - |with graph_element_registration_context(registry): + |with add_pipeline_analysis_context( + | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name_opt=None + |): + | with graph_element_registration_context(registry): |$indentedPythonText |""".stripMargin @@ -334,21 +340,35 
@@ class PythonPipelineSuite |@dp.table |def b(): | return spark.readStream.table("src") + | + |@dp.materialized_view + |def c(): + | return spark.sql("SELECT * FROM src") + | + |@dp.table + |def d(): + | return spark.sql("SELECT * FROM STREAM src") |""".stripMargin).resolve().validate() assert( graph.table.keySet == Set( graphIdentifier("src"), graphIdentifier("a"), - graphIdentifier("b"))) - Seq("a", "b").foreach { flowName => + graphIdentifier("b"), + graphIdentifier("c"), + graphIdentifier("d"))) + Seq("a", "b", "c").foreach { flowName => // dependency is properly tracked assert(graph.resolvedFlow(graphIdentifier(flowName)).inputs == Set(graphIdentifier("src"))) } val (streamingFlows, batchFlows) = graph.resolvedFlows.partition(_.df.isStreaming) - assert(batchFlows.map(_.identifier) == Seq(graphIdentifier("src"), graphIdentifier("a"))) - assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("b"))) + assert( + batchFlows.map(_.identifier) == Seq( + graphIdentifier("src"), + graphIdentifier("a"), + graphIdentifier("c"))) + assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("b"), graphIdentifier("d"))) } test("referencing external datasets") { @@ -365,18 +385,32 @@ class PythonPipelineSuite |@dp.table |def c(): | return spark.readStream.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def d(): + | return spark.sql("SELECT * FROM spark_catalog.default.src") + | + |@dp.table + |def e(): + | return spark.sql("SELECT * FROM STREAM spark_catalog.default.src") |""".stripMargin).resolve().validate() assert( graph.tables.map(_.identifier).toSet == Set( graphIdentifier("a"), graphIdentifier("b"), - graphIdentifier("c"))) + graphIdentifier("c"), + graphIdentifier("d"), + graphIdentifier("e"))) // dependency is not tracked assert(graph.resolvedFlows.forall(_.inputs.isEmpty)) val (streamingFlows, batchFlows) = graph.resolvedFlows.partition(_.df.isStreaming) - assert(batchFlows.map(_.identifier).toSet == Set(graphIdentifier("a"), graphIdentifier("b"))) - assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("c"))) + assert( + batchFlows.map(_.identifier).toSet == Set( + graphIdentifier("a"), + graphIdentifier("b"), + graphIdentifier("d"))) + assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("c"), graphIdentifier("e"))) } test("referencing internal datasets failed") { @@ -392,9 +426,17 @@ class PythonPipelineSuite |@dp.table |def c(): | return spark.readStream.table("src") + | + |@dp.materialized_view + |def d(): + | return spark.sql("SELECT * FROM src") + | + |@dp.table + |def e(): + | return spark.sql("SELECT * FROM STREAM src") |""".stripMargin).resolve() - assert(graph.resolutionFailedFlows.size == 3) + assert(graph.resolutionFailedFlows.size == 5) graph.resolutionFailedFlows.foreach { flow => assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]")) assert(flow.failure.head.getMessage.contains("`src`")) @@ -414,12 +456,94 @@ class PythonPipelineSuite |@dp.materialized_view |def c(): | return spark.readStream.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def d(): + | return spark.sql("SELECT * FROM spark_catalog.default.src") + | + |@dp.table + |def e(): + | return spark.sql("SELECT * FROM STREAM spark_catalog.default.src") |""".stripMargin).resolve() + assert(graph.resolutionFailedFlows.size == 5) graph.resolutionFailedFlows.foreach { flow => - assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND] The table or view")) + 
assert(flow.failure.head.getMessage.contains("[TABLE_OR_VIEW_NOT_FOUND]")) + assert(flow.failure.head.getMessage.contains("`spark_catalog`.`default`.`src`")) } } + test("reading external datasets outside query function works") { + sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") + val graph = buildGraph(s""" + |spark_sql_df = spark.sql("SELECT * FROM spark_catalog.default.src") + |read_table_df = spark.read.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def mv_from_spark_sql_df(): + | return spark_sql_df + | + |@dp.materialized_view + |def mv_from_read_table_df(): + | return read_table_df + |""".stripMargin).resolve().validate() + + assert( + graph.resolvedFlows.map(_.identifier).toSet == Set( + graphIdentifier("mv_from_spark_sql_df"), + graphIdentifier("mv_from_read_table_df"))) + assert(graph.resolvedFlows.forall(_.inputs.isEmpty)) + assert(graph.resolvedFlows.forall(!_.df.isStreaming)) + } + + test( + "reading internal datasets outside query function that don't trigger " + + "eager analysis or execution") { + val graph = buildGraph(""" + |@dp.materialized_view + |def src(): + | return spark.range(5) + | + |read_table_df = spark.read.table("src") + | + |@dp.materialized_view + |def mv_from_read_table_df(): + | return read_table_df + | + |""".stripMargin).resolve().validate() + assert( + graph.resolvedFlows.map(_.identifier).toSet == Set( + graphIdentifier("mv_from_read_table_df"), + graphIdentifier("src"))) + assert(graph.resolvedFlows.forall(!_.df.isStreaming)) + assert( + graph + .resolvedFlow(graphIdentifier("mv_from_read_table_df")) + .inputs + .contains(graphIdentifier("src"))) + } + + gridTest( + "reading internal datasets outside query function that trigger " + + "eager analysis or execution will fail")( + Seq("""spark.sql("SELECT * FROM src")""", """spark.read.table("src").collect()""")) { + command => + val ex = intercept[RuntimeException] { + buildGraph(s""" + |@dp.materialized_view + |def src(): + | return spark.range(5) + | + |spark_sql_df = $command + | + |@dp.materialized_view + |def mv_from_spark_sql_df(): + | return spark_sql_df + |""".stripMargin) + } + assert(ex.getMessage.contains("TABLE_OR_VIEW_NOT_FOUND")) + assert(ex.getMessage.contains("`src`")) + } + test("create dataset with the same name will fail") { assume(PythonTestDepsChecker.isConnectDepsAvailable) val ex = intercept[AnalysisException] { @@ -902,4 +1026,82 @@ class PythonPipelineSuite s"Table should have no transforms, but got: ${stTransforms.mkString(", ")}") } } + + // List of unsupported SQL commands that should result in a failure. 
+ private val unsupportedSqlCommandList: Seq[String] = Seq( + "SET CATALOG some_catalog", + "USE SCHEMA some_schema", + "SET `test_conf` = `true`", + "CREATE TABLE some_table (id INT)", + "CREATE VIEW some_view AS SELECT * FROM some_table", + "INSERT INTO some_table VALUES (1)", + "ALTER TABLE some_table RENAME TO some_new_table", + "CREATE NAMESPACE some_namespace", + "DROP VIEW some_view", + "CREATE MATERIALIZED VIEW some_view AS SELECT * FROM some_table", + "CREATE STREAMING TABLE some_table AS SELECT * FROM some_table") + + gridTest("Unsupported SQL command outside query function should result in a failure")( + unsupportedSqlCommandList) { unsupportedSqlCommand => + val ex = intercept[RuntimeException] { + buildGraph(s""" + |spark.sql("$unsupportedSqlCommand") + | + |@dp.materialized_view() + |def mv(): + | return spark.range(5) + |""".stripMargin) + } + assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND")) + } + + gridTest("Unsupported SQL command inside query function should result in a failure")( + unsupportedSqlCommandList) { unsupportedSqlCommand => + val ex = intercept[RuntimeException] { + buildGraph(s""" + |@dp.materialized_view() + |def mv(): + | spark.sql("$unsupportedSqlCommand") + | return spark.range(5) + |""".stripMargin) + } + assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND")) + } + + // List of supported SQL commands that should work. + val supportedSqlCommandList: Seq[String] = Seq( + "DESCRIBE TABLE spark_catalog.default.src", + "SHOW TABLES", + "SHOW TBLPROPERTIES spark_catalog.default.src", + "SHOW NAMESPACES", + "SHOW COLUMNS FROM spark_catalog.default.src", + "SHOW FUNCTIONS", + "SHOW VIEWS", + "SHOW CATALOGS", + "SHOW CREATE TABLE spark_catalog.default.src", + "SELECT * FROM RANGE(5)", + "SELECT * FROM spark_catalog.default.src") + + gridTest("Supported SQL command outside query function should work")(supportedSqlCommandList) { + supportedSqlCommand => + sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") + buildGraph(s""" + |spark.sql("$supportedSqlCommand") + | + |@dp.materialized_view() + |def mv(): + | return spark.range(5) + |""".stripMargin) + } + + gridTest("Supported SQL command inside query function should work")(supportedSqlCommandList) { + supportedSqlCommand => + sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") + buildGraph(s""" + |@dp.materialized_view() + |def mv(): + | spark.sql("$supportedSqlCommand") + | return spark.range(5) + |""".stripMargin) + } } From c6b88d0fa0f8e89f9aa3371cd21a8c6410efbebb Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Wed, 12 Nov 2025 20:18:26 -0800 Subject: [PATCH 02/18] add python unit tests --- .../test_add_pipeline_analysis_context.py | 72 +++++++++++++++++++ 1 file changed, 72 insertions(+) create mode 100644 python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py new file mode 100644 index 0000000000000..8601bfbb34413 --- /dev/null +++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import unittest + +from pyspark.errors import PySparkException +from pyspark.testing.connectutils import ReusedConnectTestCase +from pyspark.testing.connectutils import ( + ReusedConnectTestCase, + should_test_connect, + connect_requirement_message, +) + +if should_test_connect: + from pyspark.pipelines.add_pipeline_analysis_context import add_pipeline_analysis_context + + +@unittest.skipIf(not should_test_connect, connect_requirement_message) +class AddPipelineAnalysisContextTests(ReusedConnectTestCase): + def test_add_pipeline_analysis_context_with_flow_name(self): + with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", "test_flow_name"): + import pyspark.sql.connect.proto as pb2 + + thread_local_extensions = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions), 1) + # Extension is stored as (id, extension), unpack the extension + _extension_id, extension = thread_local_extensions[0] + context = pb2.PipelineAnalysisContext() + extension.Unpack(context) + self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id") + self.assertEqual(context.flow_name, "test_flow_name") + thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after), 0) + + def test_add_pipeline_analysis_context_without_flow_name(self): + with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id", None): + import pyspark.sql.connect.proto as pb2 + + thread_local_extensions = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions), 1) + # Extension is stored as (id, extension), unpack the extension + _extension_id, extension = thread_local_extensions[0] + context = pb2.PipelineAnalysisContext() + extension.Unpack(context) + self.assertEqual(context.dataflow_graph_id, "test_dataflow_graph_id") + # Empty string means no flow name + self.assertEqual(context.flow_name, "") + thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after), 0) + + +if __name__ == "__main__": + try: + import xmlrunner # type: ignore + + testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2) + except ImportError: + testRunner = None + unittest.main(testRunner=testRunner, verbosity=2) From 064199102f1535c4e4a6375e7174e7e57a25481e Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Wed, 12 Nov 2025 20:59:29 -0800 Subject: [PATCH 03/18] add more tests --- .../test_add_pipeline_analysis_context.py | 32 +++++++++++++++++++ ...SparkDeclarativePipelinesServerSuite.scala | 26 +++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py index 8601bfbb34413..902059e954ab4 100644 --- 
a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py @@ -61,6 +61,38 @@ def test_add_pipeline_analysis_context_without_flow_name(self): thread_local_extensions_after = self.spark.client.thread_local.user_context_extensions self.assertEqual(len(thread_local_extensions_after), 0) + def test_nested_add_pipeline_analysis_context(self): + import pyspark.sql.connect.proto as pb2 + + with add_pipeline_analysis_context( + self.spark, "test_dataflow_graph_id_1", flow_name_opt=None + ): + with add_pipeline_analysis_context( + self.spark, "test_dataflow_graph_id_2", flow_name_opt="test_flow_name" + ): + thread_local_extensions = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions), 2) + # Extension is stored as (id, extension), unpack the extensions + _, extension_1 = thread_local_extensions[0] + context_1 = pb2.PipelineAnalysisContext() + extension_1.Unpack(context_1) + self.assertEqual(context_1.dataflow_graph_id, "test_dataflow_graph_id_1") + self.assertEqual(context_1.flow_name, "") + _, extension_2 = thread_local_extensions[1] + context_2 = pb2.PipelineAnalysisContext() + extension_2.Unpack(context_2) + self.assertEqual(context_2.dataflow_graph_id, "test_dataflow_graph_id_2") + self.assertEqual(context_2.flow_name, "test_flow_name") + thread_local_extensions_after_1 = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after_1), 1) + _, extension_3 = thread_local_extensions_after_1[0] + context_3 = pb2.PipelineAnalysisContext() + extension_3.Unpack(context_3) + self.assertEqual(context_3.dataflow_graph_id, "test_dataflow_graph_id_1") + self.assertEqual(context_3.flow_name, "") + thread_local_extensions_after_2 = self.spark.client.thread_local.user_context_extensions + self.assertEqual(len(thread_local_extensions_after_2), 0) + if __name__ == "__main__": try: diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index ab60462e87351..6d7e294fcaaab 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -51,6 +51,32 @@ class SparkDeclarativePipelinesServerSuite } } + test( + "create dataflow graph set session catalog and database to pipeline " + + "default catalog and database") { + withRawBlockingStub { implicit stub => + // Use default spark_catalog and create a test database + sql("CREATE DATABASE IF NOT EXISTS test_db") + try { + val graphId = sendPlan( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase("test_db") + .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + val definition = + getDefaultSessionHolder.dataflowGraphRegistry.getDataflowGraphOrThrow(graphId) + assert(definition.defaultCatalog == "spark_catalog") + assert(definition.defaultDatabase == "test_db") + assert(getDefaultSessionHolder.session.catalog.currentCatalog() == "spark_catalog") + assert(getDefaultSessionHolder.session.catalog.currentDatabase == "test_db") + } finally { + sql("DROP DATABASE IF EXISTS test_db") + } + } 
+ } + test("Define a flow for a graph that does not exist") { val ex = intercept[Exception] { withRawBlockingStub { implicit stub => From a5a3fbb3f5d70795f103e75a8301bb89d694f37f Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Wed, 12 Nov 2025 21:57:42 -0800 Subject: [PATCH 04/18] sandy --- .../add_pipeline_analysis_context.py | 24 ++++---- .../pyspark/pipelines/block_connect_access.py | 60 ++++++++++--------- python/pyspark/pipelines/cli.py | 2 +- .../spark_connect_graph_element_registry.py | 2 +- .../test_add_pipeline_analysis_context.py | 6 +- .../connect/pipelines/PipelinesHandler.scala | 22 ++++--- .../pipelines/PythonPipelineSuite.scala | 2 +- 7 files changed, 61 insertions(+), 57 deletions(-) diff --git a/python/pyspark/pipelines/add_pipeline_analysis_context.py b/python/pyspark/pipelines/add_pipeline_analysis_context.py index 00b251660e2a7..f67353926c3fc 100644 --- a/python/pyspark/pipelines/add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/add_pipeline_analysis_context.py @@ -23,26 +23,26 @@ @contextmanager def add_pipeline_analysis_context( - spark: SparkSession, dataflow_graph_id: str, flow_name_opt: Optional[str] + spark: SparkSession, dataflow_graph_id: str, flow_name: Optional[str] ) -> Generator[None, None, None]: """ Context manager that add PipelineAnalysisContext extension to the user context used for pipeline specific analysis. """ - _extension_id = None - _client = cast(Any, spark).client + extension_id = None + # Cast because mypy seems to think `spark` is a function, not an object. + # Likely related to SPARK-47544. + client = cast(Any, spark).client try: import pyspark.sql.connect.proto as pb2 from google.protobuf import any_pb2 - _analysis_context = pb2.PipelineAnalysisContext(dataflow_graph_id=dataflow_graph_id) - if flow_name_opt is not None: - _analysis_context.flow_name = flow_name_opt - - _extension = any_pb2.Any() - _extension.Pack(_analysis_context) - - _extension_id = _client.add_threadlocal_user_context_extension(_extension) + _analysis_context = pb2.PipelineAnalysisContext( + dataflow_graph_id=dataflow_graph_id, flow_name=flow_name + ) + extension = any_pb2.Any() + extension.Pack(_analysis_context) + extension_id = client.add_threadlocal_user_context_extension(extension) yield finally: - _client.remove_user_context_extension(_extension_id) + client.remove_user_context_extension(extension_id) diff --git a/python/pyspark/pipelines/block_connect_access.py b/python/pyspark/pipelines/block_connect_access.py index 28cfbb051d811..3c72932d0e2ec 100644 --- a/python/pyspark/pipelines/block_connect_access.py +++ b/python/pyspark/pipelines/block_connect_access.py @@ -24,20 +24,25 @@ BLOCKED_RPC_NAMES = ["AnalyzePlan", "ExecutePlan"] -def _is_sql_command_request(request: object) -> bool: - """Check if the request is spark.sql() command (ExecutePlanRequest with a sql_command).""" - try: - if not hasattr(request, "plan"): - return False - - plan = request.plan +def _is_sql_command_request(rpc_name: str, args: tuple) -> bool: + """ + Check if the RPC call is a spark.sql() command (ExecutePlan with sql_command). 
- if not plan.HasField("command"): - return False + :param rpc_name: Name of the RPC being called + :param args: Arguments passed to the RPC + :return: True if this is an ExecutePlan request with a sql_command + """ + if rpc_name != "ExecutePlan" or len(args) == 0: + return False - return plan.command.HasField("sql_command") - except Exception: + request = args[0] + if not hasattr(request, "plan"): return False + plan = request.plan + if not plan.HasField("command"): + return False + command = plan.command + return command.HasField("sql_command") @contextmanager @@ -54,26 +59,23 @@ def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]: # Define a new __getattribute__ method that blocks RPC calls def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable: - if name not in BLOCKED_RPC_NAMES: - return original_getattr(self, name) - - # Get the original method first original_method = original_getattr(self, name) - def blocked_method(*args: object, **kwargs: object): - # allowlist spark.sql() command (ExecutePlan with sql_command) - if name == "ExecutePlan" and len(args) > 0: - request = args[0] - if _is_sql_command_request(request): - return original_method(*args, **kwargs) - - # Block all other ExecutePlan and AnalyzePlan calls - raise PySparkException( - errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", - messageParameters={}, - ) - - return blocked_method + def intercepted_method(*args: object, **kwargs: object): + # Allow all RPCs that are not AnalyzePlan or ExecutePlan + if name not in BLOCKED_RPC_NAMES: + return original_method(*args, **kwargs) + # Allow spark.sql() commands (ExecutePlan with sql_command) + elif _is_sql_command_request(name, args): + return original_method(*args, **kwargs) + # Block all other AnalyzePlan and ExecutePlan calls + else: + raise PySparkException( + errorClass="ATTEMPT_ANALYSIS_IN_PIPELINE_QUERY_FUNCTION", + messageParameters={}, + ) + + return intercepted_method try: # Apply our custom __getattribute__ method diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py index 6994d072a50e2..3ba0bb58fe946 100644 --- a/python/pyspark/pipelines/cli.py +++ b/python/pyspark/pipelines/cli.py @@ -252,7 +252,7 @@ def register_definitions( module_spec.loader is not None ), f"Module spec has no loader for {file}" with add_pipeline_analysis_context( - spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name_opt=None + spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None ): with block_session_mutations(): module_spec.loader.exec_module(module) diff --git a/python/pyspark/pipelines/spark_connect_graph_element_registry.py b/python/pyspark/pipelines/spark_connect_graph_element_registry.py index 79d54bb3bf776..b8d297fced3fb 100644 --- a/python/pyspark/pipelines/spark_connect_graph_element_registry.py +++ b/python/pyspark/pipelines/spark_connect_graph_element_registry.py @@ -113,7 +113,7 @@ def register_output(self, output: Output) -> None: def register_flow(self, flow: Flow) -> None: with add_pipeline_analysis_context( - spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name_opt=flow.name + spark=self._spark, dataflow_graph_id=self._dataflow_graph_id, flow_name=flow.name ): with block_spark_connect_execution_and_analysis(): df = flow.func() diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py index 902059e954ab4..028b95779a4d2 100644 --- 
a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py @@ -64,11 +64,9 @@ def test_add_pipeline_analysis_context_without_flow_name(self): def test_nested_add_pipeline_analysis_context(self): import pyspark.sql.connect.proto as pb2 - with add_pipeline_analysis_context( - self.spark, "test_dataflow_graph_id_1", flow_name_opt=None - ): + with add_pipeline_analysis_context(self.spark, "test_dataflow_graph_id_1", flow_name=None): with add_pipeline_analysis_context( - self.spark, "test_dataflow_graph_id_2", flow_name_opt="test_flow_name" + self.spark, "test_dataflow_graph_id_2", flow_name="test_flow_name" ): thread_local_extensions = self.spark.client.thread_local.user_context_extensions self.assertEqual(len(thread_local_extensions), 2) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala index 356fd4f473172..4c60e0f70ff4c 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/pipelines/PipelinesHandler.scala @@ -132,10 +132,17 @@ private[connect] object PipelinesHandler extends Logging { } /** - * Block unsupported SQL commands that are not explicitly allowlisted. + * Block SQL commands that have side effects or modify data. + * + * Pipeline definitions should be declarative and side-effect free. This prevents users from + * inadvertently modifying catalogs, creating tables, or performing other stateful operations + * outside the pipeline API boundary during pipeline registration or analysis. + * + * This is a best-effort approach: we block known problematic commands while allowing a curated + * set of read-only operations (e.g., SHOW, DESCRIBE). */ def blockUnsupportedSqlCommand(queryPlan: LogicalPlan): Unit = { - val supportedCommand = Set( + val allowlistedCommands = Set( classOf[DescribeRelation], classOf[ShowTables], classOf[ShowTableProperties], @@ -145,16 +152,14 @@ private[connect] object PipelinesHandler extends Logging { classOf[ShowViews], classOf[ShowCatalogsCommand], classOf[ShowCreateTable]) - val isSqlCommandExplicitlyAllowlisted = { - supportedCommand.exists(c => queryPlan.getClass.getName.equals(c.getName)) - } + val isSqlCommandExplicitlyAllowlisted = allowlistedCommands.exists(_.isInstance(queryPlan)) val isUnsupportedSqlPlan = if (isSqlCommandExplicitlyAllowlisted) { false } else { - // If the SQL command is not explicitly allowlisted, check whether it belongs to - // one of commands pipeline explicitly disallow. - // If not, the SQL command is supported. + // Disable all [[Command]] except the ones that are explicitly allowlisted + // in "allowlistedCommands". queryPlan.isInstanceOf[Command] || + // Following commands are not subclasses of [[Command]] but have side effects. 
queryPlan.isInstanceOf[CreateTableAsSelect] || queryPlan.isInstanceOf[CreateTable] || queryPlan.isInstanceOf[CreateView] || @@ -163,7 +168,6 @@ private[connect] object PipelinesHandler extends Logging { queryPlan.isInstanceOf[CreateNamespace] || queryPlan.isInstanceOf[DropView] } - // scalastyle:on if (isUnsupportedSqlPlan) { throw new AnalysisException( "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND", diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 4999a62c26d0d..9142ee19e9f88 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -83,7 +83,7 @@ class PythonPipelineSuite | |registry = SparkConnectGraphElementRegistry(spark, dataflow_graph_id) |with add_pipeline_analysis_context( - | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name_opt=None + | spark=spark, dataflow_graph_id=dataflow_graph_id, flow_name=None |): | with graph_element_registration_context(registry): |$indentedPythonText From 506dcddea42566158cd86003682a94bbace8ea4c Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Wed, 12 Nov 2025 22:08:37 -0800 Subject: [PATCH 05/18] test failure --- .../pyspark/pipelines/add_pipeline_analysis_context.py | 4 ++-- .../sql/connect/pipelines/PythonPipelineSuite.scala | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/python/pyspark/pipelines/add_pipeline_analysis_context.py b/python/pyspark/pipelines/add_pipeline_analysis_context.py index f67353926c3fc..6ecabdf43b072 100644 --- a/python/pyspark/pipelines/add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/add_pipeline_analysis_context.py @@ -37,11 +37,11 @@ def add_pipeline_analysis_context( import pyspark.sql.connect.proto as pb2 from google.protobuf import any_pb2 - _analysis_context = pb2.PipelineAnalysisContext( + analysis_context = pb2.PipelineAnalysisContext( dataflow_graph_id=dataflow_graph_id, flow_name=flow_name ) extension = any_pb2.Any() - extension.Pack(_analysis_context) + extension.Pack(analysis_context) extension_id = client.add_threadlocal_user_context_extension(extension) yield finally: diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 9142ee19e9f88..7274e214bee92 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -149,7 +149,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(28), + line = Option(34), objectName = Option("spark_catalog.default.table1"), objectType = Option(QueryOriginType.Flow.toString))), errorChecker = ex => @@ -201,7 +201,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(34), + line = Option(40), objectName = Option("spark_catalog.default.mv2"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -215,7 +215,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(38), + line = 
Option(44), objectName = Option("spark_catalog.default.mv"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -233,7 +233,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(28), + line = Option(34), objectName = Option("spark_catalog.default.table1"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) @@ -247,7 +247,7 @@ class PythonPipelineSuite QueryOrigin( language = Option(Python()), filePath = Option(""), - line = Option(43), + line = Option(49), objectName = Option("spark_catalog.default.standalone_flow1"), objectType = Option(QueryOriginType.Flow.toString))), expectedEventLevel = EventLevel.INFO) From 4a74125dc178df0a66e622180fe3d2923dcdfbbb Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 11:46:49 -0800 Subject: [PATCH 06/18] fix tests --- .../resources/error/error-conditions.json | 12 +-- .../pyspark/pipelines/block_connect_access.py | 4 +- .../test_add_pipeline_analysis_context.py | 2 - python/pyspark/pipelines/tests/test_cli.py | 15 +++- python/pyspark/sql/connect/client/core.py | 2 +- .../sql/tests/connect/client/test_client.py | 1 + ...SparkDeclarativePipelinesServerSuite.scala | 74 ++++++++++--------- 7 files changed, 60 insertions(+), 50 deletions(-) diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index da3e946d8a421..1dc63b349616a 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -6261,12 +6261,6 @@ }, "sqlState" : "0A000" }, - "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND": { - "message" : [ - "'' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline." - ], - "sqlState" : "0A000" - }, "UNSUPPORTED_CHAR_OR_VARCHAR_AS_STRING" : { "message" : [ "The char/varchar type can't be used in the table schema.", @@ -6930,6 +6924,12 @@ ], "sqlState" : "0A000" }, + "UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND" : { + "message" : [ + "'' is not supported in spark.sql(\"...\") API in Spark Declarative Pipeline." + ], + "sqlState" : "0A000" + }, "UNSUPPORTED_SAVE_MODE" : { "message" : [ "The save mode is not supported for:" diff --git a/python/pyspark/pipelines/block_connect_access.py b/python/pyspark/pipelines/block_connect_access.py index 3c72932d0e2ec..696d0e39b005d 100644 --- a/python/pyspark/pipelines/block_connect_access.py +++ b/python/pyspark/pipelines/block_connect_access.py @@ -15,7 +15,7 @@ # limitations under the License. 
# from contextlib import contextmanager -from typing import Callable, Generator, NoReturn +from typing import Any, Callable, Generator from pyspark.errors import PySparkException from pyspark.sql.connect.proto.base_pb2_grpc import SparkConnectServiceStub @@ -61,7 +61,7 @@ def block_spark_connect_execution_and_analysis() -> Generator[None, None, None]: def blocked_getattr(self: SparkConnectServiceStub, name: str) -> Callable: original_method = original_getattr(self, name) - def intercepted_method(*args: object, **kwargs: object): + def intercepted_method(*args: object, **kwargs: object) -> Any: # Allow all RPCs that are not AnalyzePlan or ExecutePlan if name not in BLOCKED_RPC_NAMES: return original_method(*args, **kwargs) diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py index 028b95779a4d2..57c5da22d4601 100644 --- a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py @@ -16,8 +16,6 @@ # import unittest -from pyspark.errors import PySparkException -from pyspark.testing.connectutils import ReusedConnectTestCase from pyspark.testing.connectutils import ( ReusedConnectTestCase, should_test_connect, diff --git a/python/pyspark/pipelines/tests/test_cli.py b/python/pyspark/pipelines/tests/test_cli.py index e8445e63d439d..ff3022fa29663 100644 --- a/python/pyspark/pipelines/tests/test_cli.py +++ b/python/pyspark/pipelines/tests/test_cli.py @@ -22,6 +22,7 @@ from pyspark.errors import PySparkException from pyspark.testing.connectutils import ( + ReusedConnectTestCase, should_test_connect, connect_requirement_message, ) @@ -45,7 +46,7 @@ not should_test_connect or not have_yaml, connect_requirement_message or yaml_requirement_message, ) -class CLIUtilityTests(unittest.TestCase): +class CLIUtilityTests(ReusedConnectTestCase): def test_load_pipeline_spec(self): with tempfile.NamedTemporaryFile(mode="w") as tmpfile: tmpfile.write( @@ -294,7 +295,9 @@ def mv2(): ) registry = LocalGraphElementRegistry() - register_definitions(outer_dir / "pipeline.yaml", registry, spec) + register_definitions( + outer_dir / "pipeline.yaml", registry, spec, self.spark, "test_graph_id" + ) self.assertEqual(len(registry.outputs), 1) self.assertEqual(registry.outputs[0].name, "mv1") @@ -315,7 +318,9 @@ def test_register_definitions_file_raises_error(self): registry = LocalGraphElementRegistry() with self.assertRaises(RuntimeError) as context: - register_definitions(outer_dir / "pipeline.yml", registry, spec) + register_definitions( + outer_dir / "pipeline.yml", registry, spec, self.spark, "test_graph_id" + ) self.assertIn("This is a test exception", str(context.exception)) def test_register_definitions_unsupported_file_extension_matches_glob(self): @@ -334,7 +339,7 @@ def test_register_definitions_unsupported_file_extension_matches_glob(self): registry = LocalGraphElementRegistry() with self.assertRaises(PySparkException) as context: - register_definitions(outer_dir, registry, spec) + register_definitions(outer_dir, registry, spec, self.spark, "test_graph_id") self.assertEqual( context.exception.getCondition(), "PIPELINE_UNSUPPORTED_DEFINITIONS_FILE_EXTENSION" ) @@ -382,6 +387,8 @@ def test_python_import_current_directory(self): configuration={}, libraries=[LibrariesGlob(include="defs.py")], ), + self.spark, + "test_graph_id", ) def test_full_refresh_all_conflicts_with_full_refresh(self): diff --git 
a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py index 3d706e3187b00..48e07642e1574 100644 --- a/python/pyspark/sql/connect/client/core.py +++ b/python/pyspark/sql/connect/client/core.py @@ -727,7 +727,7 @@ def __init__( # cleanup ml cache if possible atexit.register(self._cleanup_ml_cache) - self.global_user_context_extensions = [] + self.global_user_context_extensions: List[Tuple[str, any_pb2.Any]] = [] self.global_user_context_extensions_lock = threading.Lock() @property diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 2006db03b95a2..e2a572c8158e4 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -26,6 +26,7 @@ if should_test_connect: import grpc import google.protobuf.any_pb2 as any_pb2 + import google.protobuf.wrappers_pb2 as wrappers_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index 6d7e294fcaaab..b150934bbc107 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -541,8 +541,7 @@ class SparkDeclarativePipelinesServerSuite name: String, datasetType: OutputType, datasetName: String, - defaultCatalog: String = "", - defaultDatabase: String = "", + defaultDatabase: String, expectedResolvedDatasetName: String, expectedResolvedCatalog: String, expectedResolvedNamespace: Seq[String]) @@ -552,6 +551,7 @@ class SparkDeclarativePipelinesServerSuite name = "TEMPORARY_VIEW", datasetType = OutputType.TEMPORARY_VIEW, datasetName = "tv", + defaultDatabase = "default", expectedResolvedDatasetName = "tv", expectedResolvedCatalog = "", expectedResolvedNamespace = Seq.empty), @@ -559,6 +559,7 @@ class SparkDeclarativePipelinesServerSuite name = "TABLE", datasetType = OutputType.TABLE, datasetName = "`tb`", + defaultDatabase = "default", expectedResolvedDatasetName = "tb", expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("default")), @@ -566,6 +567,7 @@ class SparkDeclarativePipelinesServerSuite name = "MV", datasetType = OutputType.MATERIALIZED_VIEW, datasetName = "mv", + defaultDatabase = "default", expectedResolvedDatasetName = "mv", expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("default"))).map(tc => tc.name -> tc).toMap @@ -575,7 +577,6 @@ class SparkDeclarativePipelinesServerSuite name = "TEMPORARY_VIEW", datasetType = OutputType.TEMPORARY_VIEW, datasetName = "tv", - defaultCatalog = "custom_catalog", defaultDatabase = "custom_db", expectedResolvedDatasetName = "tv", expectedResolvedCatalog = "", @@ -584,19 +585,17 @@ class SparkDeclarativePipelinesServerSuite name = "TABLE", datasetType = OutputType.TABLE, datasetName = "`tb`", - defaultCatalog = "`my_catalog`", defaultDatabase = "`my_db`", expectedResolvedDatasetName = "tb", - expectedResolvedCatalog = "`my_catalog`", + expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("`my_db`")), DefineOutputTestCase( name = "MV", datasetType = OutputType.MATERIALIZED_VIEW, 
datasetName = "mv", - defaultCatalog = "another_catalog", defaultDatabase = "another_db", expectedResolvedDatasetName = "mv", - expectedResolvedCatalog = "another_catalog", + expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("another_db"))) .map(tc => tc.name -> tc) .toMap @@ -630,40 +629,45 @@ class SparkDeclarativePipelinesServerSuite } } - namedGridTest("DefineOutput returns resolved data name for custom catalog/schema")( + namedGridTest("DefineOutput returns resolved data name for custom schema")( defineDatasetCustomTests) { testCase => withRawBlockingStub { implicit stub => - // Build and send the CreateDataflowGraph command with custom catalog/db - val graphId = sendPlan( - buildCreateDataflowGraphPlan( - proto.PipelineCommand.CreateDataflowGraph - .newBuilder() - .setDefaultCatalog(testCase.defaultCatalog) - .setDefaultDatabase(testCase.defaultDatabase) - .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + sql(s"CREATE DATABASE IF NOT EXISTS spark_catalog.${testCase.defaultDatabase}") + try { + // Build and send the CreateDataflowGraph command with custom catalog/db + val graphId = sendPlan( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase(testCase.defaultDatabase) + .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId - assert(graphId.nonEmpty) + assert(graphId.nonEmpty) - // Build DefineOutput with the created graphId and dataset info - val defineDataset = DefineOutput - .newBuilder() - .setDataflowGraphId(graphId) - .setOutputName(testCase.datasetName) - .setOutputType(testCase.datasetType) - val pipelineCmd = PipelineCommand - .newBuilder() - .setDefineOutput(defineDataset) - .build() + // Build DefineOutput with the created graphId and dataset info + val defineDataset = DefineOutput + .newBuilder() + .setDataflowGraphId(graphId) + .setOutputName(testCase.datasetName) + .setOutputType(testCase.datasetType) + val pipelineCmd = PipelineCommand + .newBuilder() + .setDefineOutput(defineDataset) + .build() - val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult - assert(res !== PipelineCommandResult.getDefaultInstance) - assert(res.hasDefineOutputResult) - val graphResult = res.getDefineOutputResult - val identifier = graphResult.getResolvedIdentifier + val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult + assert(res !== PipelineCommandResult.getDefaultInstance) + assert(res.hasDefineOutputResult) + val graphResult = res.getDefineOutputResult + val identifier = graphResult.getResolvedIdentifier - assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) - assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) - assert(identifier.getTableName == testCase.expectedResolvedDatasetName) + assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) + assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) + assert(identifier.getTableName == testCase.expectedResolvedDatasetName) + } finally { + sql(s"DROP DATABASE IF EXISTS spark_catalog.${testCase.defaultDatabase}") + } } } From 56d5ece17edf766f26cc66bd0274fa9cce1d76f0 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 12:20:39 -0800 Subject: [PATCH 07/18] fix test_client --- python/pyspark/sql/tests/connect/client/test_client.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git 
a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index e2a572c8158e4..189553bee75ef 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -182,6 +182,15 @@ def Config(self, req: proto.ConfigRequest, metadata): pair.value = req.operation.get_with_default.pairs[0].value or "true" return resp + def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): + self.req = req + self.client_user_context_extensions = req.user_context.extensions + resp = proto.AnalyzePlanResponse() + resp.session_id = self._session_id + # Return a minimal response with a semantic hash + resp.semantic_hash.result = 12345 + return resp + # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster) # and it blocks the test process exiting because it is registered as the atexit handler # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test. From de385722058e63afa6c0d89a59be3878ebbcc229 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 15:03:44 -0800 Subject: [PATCH 08/18] fix tests --- ...SparkDeclarativePipelinesServerSuite.scala | 120 +++++++++--------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala index b150934bbc107..c9551646385c2 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/SparkDeclarativePipelinesServerSuite.scala @@ -675,7 +675,6 @@ class SparkDeclarativePipelinesServerSuite name: String, datasetType: OutputType, flowName: String, - defaultCatalog: String, defaultDatabase: String, expectedResolvedFlowName: String, expectedResolvedCatalog: String, @@ -686,7 +685,6 @@ class SparkDeclarativePipelinesServerSuite name = "MV", datasetType = OutputType.MATERIALIZED_VIEW, flowName = "`mv`", - defaultCatalog = "`spark_catalog`", defaultDatabase = "`default`", expectedResolvedFlowName = "mv", expectedResolvedCatalog = "spark_catalog", @@ -695,7 +693,6 @@ class SparkDeclarativePipelinesServerSuite name = "TV", datasetType = OutputType.TEMPORARY_VIEW, flowName = "tv", - defaultCatalog = "spark_catalog", defaultDatabase = "default", expectedResolvedFlowName = "tv", expectedResolvedCatalog = "", @@ -706,16 +703,14 @@ class SparkDeclarativePipelinesServerSuite name = "MV custom", datasetType = OutputType.MATERIALIZED_VIEW, flowName = "mv", - defaultCatalog = "custom_catalog", defaultDatabase = "custom_db", expectedResolvedFlowName = "mv", - expectedResolvedCatalog = "custom_catalog", + expectedResolvedCatalog = "spark_catalog", expectedResolvedNamespace = Seq("custom_db")), DefineFlowTestCase( name = "TV custom", datasetType = OutputType.TEMPORARY_VIEW, flowName = "tv", - defaultCatalog = "custom_catalog", defaultDatabase = "custom_db", expectedResolvedFlowName = "tv", expectedResolvedCatalog = "", @@ -786,68 +781,73 @@ class SparkDeclarativePipelinesServerSuite namedGridTest("DefineFlow returns resolved data name for custom catalog/schema")( defineFlowCustomTests) { testCase => withRawBlockingStub { implicit stub => - val graphId = sendPlan( - buildCreateDataflowGraphPlan( - 
proto.PipelineCommand.CreateDataflowGraph + sql(s"CREATE DATABASE IF NOT EXISTS spark_catalog.${testCase.defaultDatabase}") + try { + val graphId = sendPlan( + buildCreateDataflowGraphPlan( + proto.PipelineCommand.CreateDataflowGraph + .newBuilder() + .setDefaultCatalog("spark_catalog") + .setDefaultDatabase(testCase.defaultDatabase) + .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId + assert(graphId.nonEmpty) + + // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first + if (testCase.datasetType == OutputType.TEMPORARY_VIEW) { + val defineDataset = DefineOutput .newBuilder() - .setDefaultCatalog(testCase.defaultCatalog) - .setDefaultDatabase(testCase.defaultDatabase) - .build())).getPipelineCommandResult.getCreateDataflowGraphResult.getDataflowGraphId - assert(graphId.nonEmpty) + .setDataflowGraphId(graphId) + .setOutputName(testCase.flowName) + .setOutputType(OutputType.TEMPORARY_VIEW) - // If the dataset type is TEMPORARY_VIEW, define the dataset explicitly first - if (testCase.datasetType == OutputType.TEMPORARY_VIEW) { - val defineDataset = DefineOutput + val defineDatasetCmd = PipelineCommand + .newBuilder() + .setDefineOutput(defineDataset) + .build() + + val datasetRes = + sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult + assert(datasetRes.hasDefineOutputResult) + } + + val defineFlow = DefineFlow .newBuilder() .setDataflowGraphId(graphId) - .setOutputName(testCase.flowName) - .setOutputType(OutputType.TEMPORARY_VIEW) - - val defineDatasetCmd = PipelineCommand + .setFlowName(testCase.flowName) + .setTargetDatasetName(testCase.flowName) + .setRelationFlowDetails( + DefineFlow.WriteRelationFlowDetails + .newBuilder() + .setRelation( + Relation + .newBuilder() + .setUnresolvedTableValuedFunction( + UnresolvedTableValuedFunction + .newBuilder() + .setFunctionName("range") + .addArguments(Expression + .newBuilder() + .setLiteral(Expression.Literal.newBuilder().setInteger(5).build()) + .build()) + .build()) + .build()) + .build()) + .build() + val pipelineCmd = PipelineCommand .newBuilder() - .setDefineOutput(defineDataset) + .setDefineFlow(defineFlow) .build() + val res = sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult + assert(res.hasDefineFlowResult) + val graphResult = res.getDefineFlowResult + val identifier = graphResult.getResolvedIdentifier - val datasetRes = - sendPlan(buildPlanFromPipelineCommand(defineDatasetCmd)).getPipelineCommandResult - assert(datasetRes.hasDefineOutputResult) + assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) + assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) + assert(identifier.getTableName == testCase.expectedResolvedFlowName) + } finally { + sql(s"DROP DATABASE IF EXISTS spark_catalog.${testCase.defaultDatabase}") } - - val defineFlow = DefineFlow - .newBuilder() - .setDataflowGraphId(graphId) - .setFlowName(testCase.flowName) - .setTargetDatasetName(testCase.flowName) - .setRelationFlowDetails( - DefineFlow.WriteRelationFlowDetails - .newBuilder() - .setRelation( - Relation - .newBuilder() - .setUnresolvedTableValuedFunction( - UnresolvedTableValuedFunction - .newBuilder() - .setFunctionName("range") - .addArguments(Expression - .newBuilder() - .setLiteral(Expression.Literal.newBuilder().setInteger(5).build()) - .build()) - .build()) - .build()) - .build()) - .build() - val pipelineCmd = PipelineCommand - .newBuilder() - .setDefineFlow(defineFlow) - .build() - val res = 
sendPlan(buildPlanFromPipelineCommand(pipelineCmd)).getPipelineCommandResult - assert(res.hasDefineFlowResult) - val graphResult = res.getDefineFlowResult - val identifier = graphResult.getResolvedIdentifier - - assert(identifier.getCatalogName == testCase.expectedResolvedCatalog) - assert(identifier.getNamespaceList.asScala == testCase.expectedResolvedNamespace) - assert(identifier.getTableName == testCase.expectedResolvedFlowName) } } } From 6a7d66f7ddd64d404903baa3a6a36d8980e42942 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 16:55:03 -0800 Subject: [PATCH 09/18] Done --- .github/workflows/build_and_test.yml | 2 +- .github/workflows/maven_test.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 19b71026a61c2..50f6ca2cd35c5 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -362,7 +362,7 @@ jobs: - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') || contains(matrix.modules, 'yarn') run: | - python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' + python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'googleapis-common-protos==1.71.0' 'zstandard==0.25.0' python3.11 -m pip list # Run the tests. - name: Run tests diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 7a47d47620a70..67d6777f993ba 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -175,7 +175,7 @@ jobs: - name: Install Python packages (Python 3.11) if: contains(matrix.modules, 'resource-managers#yarn') || (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect') run: | - python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' + python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'googleapis-common-protos==1.71.0' 'zstandard==0.25.0' python3.11 -m pip list # Run the tests using script command. # BSD's script command doesn't support -c option, and the usage is different from Linux's one. 
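
For reference, the client-side pieces introduced above (the PipelineAnalysisContext message, the thread-local user context extension API on SparkConnectClient, and the add_pipeline_analysis_context context manager) compose roughly as follows. This is a minimal sketch, not part of the patches: it assumes a live Spark Connect session bound to `spark`, and the graph id "example-graph-id" and flow name "example_flow" are illustrative placeholders.

# Minimal sketch (assumes a live Spark Connect session `spark`; the dataflow graph id
# and flow name below are illustrative placeholders, not values from these patches).
import pyspark.sql.connect.proto as pb2
from google.protobuf import any_pb2

# Pack a PipelineAnalysisContext into a protobuf Any, the same shape the
# add_pipeline_analysis_context context manager builds.
analysis_context = pb2.PipelineAnalysisContext(
    dataflow_graph_id="example-graph-id", flow_name="example_flow"
)
extension = any_pb2.Any()
extension.Pack(analysis_context)

# Attach it to the client's thread-local user context; requests sent from this thread
# (AnalyzePlan, ExecutePlan, config, interrupt) carry it until it is removed.
extension_id = spark.client.add_threadlocal_user_context_extension(extension)
try:
    spark.sql("SELECT 1").collect()
finally:
    spark.client.remove_user_context_extension(extension_id)

# The receiver unpacks the Any back into a typed message, as the unit tests above do.
unpacked = pb2.PipelineAnalysisContext()
extension.Unpack(unpacked)
assert unpacked.dataflow_graph_id == "example-graph-id"
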
From 5c214d74320d9caf37490c2df8fbf32642ca3655 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 18:44:28 -0800 Subject: [PATCH 10/18] fix EndToEndAPISuite --- .../apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala index 55b8a315df570..f674b45bb072d 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/EndToEndAPISuite.scala @@ -162,7 +162,7 @@ class EndToEndAPISuite extends PipelineTest with APITest with SparkConnectServer |name: test-pipeline |${spec.catalog.map(catalog => s"""catalog: "$catalog"""").getOrElse("")} |${spec.database.map(database => s"""database: "$database"""").getOrElse("")} - |storage: "${projectDir.resolve("storage").toAbsolutePath}" + |storage: "file://${projectDir.resolve("storage").toAbsolutePath}" |configuration: | "spark.remote": "sc://localhost:$serverPort" |libraries: From 433b5377bbd51af73d062548ddb37b7bb8c051a9 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 19:32:03 -0800 Subject: [PATCH 11/18] fix test_client resource leak --- .../sql/tests/connect/client/test_client.py | 165 +++++++++--------- 1 file changed, 84 insertions(+), 81 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 189553bee75ef..bf7108c94b090 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -248,87 +248,90 @@ def test_user_context_extension(self): mock = MockService(client._session_id) client._stub = mock - exlocal = any_pb2.Any() - exlocal.Pack(wrappers_pb2.StringValue(value="abc")) - exlocal2 = any_pb2.Any() - exlocal2.Pack(wrappers_pb2.StringValue(value="def")) - exglobal = any_pb2.Any() - exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) - exglobal2 = any_pb2.Any() - exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) - - exlocal_id = client.add_threadlocal_user_context_extension(exlocal) - exglobal_id = client.add_global_user_context_extension(exglobal) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_threadlocal_user_context_extension(exlocal2) - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_global_user_context_extension(exglobal2) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in 
mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exlocal_id) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exglobal_id) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.clear_user_context_extensions() - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) + try: + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal in 
mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + finally: + client.close() def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) From 887ee0ae90e52b8b92a3ff4bb524f0ed249ca8f8 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 20:04:04 -0800 Subject: [PATCH 12/18] fix test_init_cli --- python/pyspark/pipelines/tests/test_init_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/pipelines/tests/test_init_cli.py b/python/pyspark/pipelines/tests/test_init_cli.py index 43c553eddc387..e51bab6a4a691 100644 --- a/python/pyspark/pipelines/tests/test_init_cli.py +++ b/python/pyspark/pipelines/tests/test_init_cli.py @@ -60,7 +60,7 @@ def test_init(self): self.assertTrue((Path.cwd() / "pipeline-storage").exists()) registry = LocalGraphElementRegistry() - register_definitions(spec_path, registry, spec) + register_definitions(spec_path, registry, spec, self.spark, "test_graph_id") self.assertEqual(len(registry.outputs), 1) self.assertEqual(registry.outputs[0].name, "example_python_materialized_view") self.assertEqual(len(registry.flows), 1) From e483cb04f3efa558dd81808934321ea656bbb4a2 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 20:31:15 -0800 Subject: [PATCH 13/18] fix all resource leaks in test_client.py - add client.close() to all tests --- 
.../sql/tests/connect/client/test_client.py | 104 +++++++++++------- 1 file changed, 67 insertions(+), 37 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index bf7108c94b090..eca5c9b6866b8 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -204,32 +204,44 @@ def test_user_agent_passthrough(self): mock = MockService(client._session_id) client._stub = mock - command = proto.Command() - client.execute_command(command) + try: + command = proto.Command() + client.execute_command(command) - self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") - self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") + finally: + client.close() def test_user_agent_default(self): client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock - command = proto.Command() - client.execute_command(command) + try: + command = proto.Command() + client.execute_command(command) - self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") - self.assertRegex( - mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" - ) + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertRegex( + mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" + ) + finally: + client.close() def test_properties(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) - self.assertEqual(client.token, "bar") - self.assertEqual(client.host, "foo") + try: + self.assertEqual(client.token, "bar") + self.assertEqual(client.host, "foo") + finally: + client.close() client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) - self.assertIsNone(client.token) + try: + self.assertIsNone(client.token) + finally: + client.close() def test_channel_builder(self): class CustomChannelBuilder(DefaultChannelBuilder): @@ -241,7 +253,10 @@ def userId(self) -> Optional[str]: CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False ) - self.assertEqual(client._user_id, "abc") + try: + self.assertEqual(client._user_id, "abc") + finally: + client.close() def test_user_context_extension(self): client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) @@ -338,8 +353,11 @@ def test_interrupt_all(self): mock = MockService(client._session_id) client._stub = mock - client.interrupt_all() - self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") + try: + client.interrupt_all() + self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") + finally: + client.close() def test_is_closed(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) @@ -352,7 +370,10 @@ def test_channel_builder_with_session(self): dummy = str(uuid.uuid4()) chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}") client = SparkConnectClient(chan) - self.assertEqual(client._session_id, chan.session_id) + try: + self.assertEqual(client._session_id, chan.session_id) + finally: + client.close() def test_session_hook(self): inits = 0 @@ -390,11 +411,14 @@ def test_custom_operation_id(self): client = 
SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock - req = client._execute_plan_request_with_metadata( - operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b" - ) - for resp in client._stub.ExecutePlan(req, metadata=None): - assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b" + try: + req = client._execute_plan_request_with_metadata( + operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b" + ) + for resp in client._stub.ExecutePlan(req, metadata=None): + assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b" + finally: + client.close() @unittest.skipIf(not should_test_connect, connect_requirement_message) @@ -576,13 +600,16 @@ def test_server_unreachable(self): client = SparkConnectClient( "sc://foo", use_reattachable_execute=False, retry_policy=dict(max_retries=0) ) - with self.assertRaises(SparkConnectGrpcException) as cm: - command = proto.Command() - client.execute_command(command) - err = cm.exception - self.assertEqual(err.getGrpcStatusCode(), grpc.StatusCode.UNAVAILABLE) - self.assertEqual(err.getErrorClass(), None) - self.assertEqual(err.getSqlState(), None) + try: + with self.assertRaises(SparkConnectGrpcException) as cm: + command = proto.Command() + client.execute_command(command) + err = cm.exception + self.assertEqual(err.getGrpcStatusCode(), grpc.StatusCode.UNAVAILABLE) + self.assertEqual(err.getErrorClass(), None) + self.assertEqual(err.getSqlState(), None) + finally: + client.close() def test_error_codes(self): msg = "Something went wrong on the server" @@ -645,14 +672,17 @@ def raise_with_sql_state(): client = SparkConnectClient( "sc://foo", use_reattachable_execute=False, retry_policy=dict(max_retries=0) ) - client._stub = self._stub_with([response_function]) - with self.assertRaises(SparkConnectGrpcException) as cm: - command = proto.Command() - client.execute_command(command) - err = cm.exception - self.assertEqual(err.getGrpcStatusCode(), expected_status_code) - self.assertEqual(err.getErrorClass(), expected_error_class) - self.assertEqual(err.getSqlState(), expected_sql_state) + try: + client._stub = self._stub_with([response_function]) + with self.assertRaises(SparkConnectGrpcException) as cm: + command = proto.Command() + client.execute_command(command) + err = cm.exception + self.assertEqual(err.getGrpcStatusCode(), expected_status_code) + self.assertEqual(err.getErrorClass(), expected_error_class) + self.assertEqual(err.getSqlState(), expected_sql_state) + finally: + client.close() if __name__ == "__main__": From 8da65ce04cd39401317543b82236cc7499101f93 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 21:47:15 -0800 Subject: [PATCH 14/18] Revert "fix all resource leaks in test_client.py - add client.close() to all tests" This reverts commit e483cb04f3efa558dd81808934321ea656bbb4a2. 
--- .../sql/tests/connect/client/test_client.py | 104 +++++++----------- 1 file changed, 37 insertions(+), 67 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index eca5c9b6866b8..bf7108c94b090 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -204,44 +204,32 @@ def test_user_agent_passthrough(self): mock = MockService(client._session_id) client._stub = mock - try: - command = proto.Command() - client.execute_command(command) + command = proto.Command() + client.execute_command(command) - self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") - self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") - finally: - client.close() + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertRegex(mock.req.client_type, r"^bar spark/[^ ]+ os/[^ ]+ python/[^ ]+$") def test_user_agent_default(self): client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock - try: - command = proto.Command() - client.execute_command(command) + command = proto.Command() + client.execute_command(command) - self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") - self.assertRegex( - mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" - ) - finally: - client.close() + self.assertIsNotNone(mock.req, "ExecutePlan API was not called when expected") + self.assertRegex( + mock.req.client_type, r"^_SPARK_CONNECT_PYTHON spark/[^ ]+ os/[^ ]+ python/[^ ]+$" + ) def test_properties(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) - try: - self.assertEqual(client.token, "bar") - self.assertEqual(client.host, "foo") - finally: - client.close() + self.assertEqual(client.token, "bar") + self.assertEqual(client.host, "foo") client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) - try: - self.assertIsNone(client.token) - finally: - client.close() + self.assertIsNone(client.token) def test_channel_builder(self): class CustomChannelBuilder(DefaultChannelBuilder): @@ -253,10 +241,7 @@ def userId(self) -> Optional[str]: CustomChannelBuilder("sc://foo/"), use_reattachable_execute=False ) - try: - self.assertEqual(client._user_id, "abc") - finally: - client.close() + self.assertEqual(client._user_id, "abc") def test_user_context_extension(self): client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) @@ -353,11 +338,8 @@ def test_interrupt_all(self): mock = MockService(client._session_id) client._stub = mock - try: - client.interrupt_all() - self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") - finally: - client.close() + client.interrupt_all() + self.assertIsNotNone(mock.req, "Interrupt API was not called when expected") def test_is_closed(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) @@ -370,10 +352,7 @@ def test_channel_builder_with_session(self): dummy = str(uuid.uuid4()) chan = DefaultChannelBuilder(f"sc://foo/;session_id={dummy}") client = SparkConnectClient(chan) - try: - self.assertEqual(client._session_id, chan.session_id) - finally: - client.close() + self.assertEqual(client._session_id, chan.session_id) def test_session_hook(self): inits = 0 @@ -411,14 +390,11 @@ def test_custom_operation_id(self): client = 
SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) client._stub = mock - try: - req = client._execute_plan_request_with_metadata( - operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b" - ) - for resp in client._stub.ExecutePlan(req, metadata=None): - assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b" - finally: - client.close() + req = client._execute_plan_request_with_metadata( + operation_id="10a4c38e-7e87-40ee-9d6f-60ff0751e63b" + ) + for resp in client._stub.ExecutePlan(req, metadata=None): + assert resp.operation_id == "10a4c38e-7e87-40ee-9d6f-60ff0751e63b" @unittest.skipIf(not should_test_connect, connect_requirement_message) @@ -600,16 +576,13 @@ def test_server_unreachable(self): client = SparkConnectClient( "sc://foo", use_reattachable_execute=False, retry_policy=dict(max_retries=0) ) - try: - with self.assertRaises(SparkConnectGrpcException) as cm: - command = proto.Command() - client.execute_command(command) - err = cm.exception - self.assertEqual(err.getGrpcStatusCode(), grpc.StatusCode.UNAVAILABLE) - self.assertEqual(err.getErrorClass(), None) - self.assertEqual(err.getSqlState(), None) - finally: - client.close() + with self.assertRaises(SparkConnectGrpcException) as cm: + command = proto.Command() + client.execute_command(command) + err = cm.exception + self.assertEqual(err.getGrpcStatusCode(), grpc.StatusCode.UNAVAILABLE) + self.assertEqual(err.getErrorClass(), None) + self.assertEqual(err.getSqlState(), None) def test_error_codes(self): msg = "Something went wrong on the server" @@ -672,17 +645,14 @@ def raise_with_sql_state(): client = SparkConnectClient( "sc://foo", use_reattachable_execute=False, retry_policy=dict(max_retries=0) ) - try: - client._stub = self._stub_with([response_function]) - with self.assertRaises(SparkConnectGrpcException) as cm: - command = proto.Command() - client.execute_command(command) - err = cm.exception - self.assertEqual(err.getGrpcStatusCode(), expected_status_code) - self.assertEqual(err.getErrorClass(), expected_error_class) - self.assertEqual(err.getSqlState(), expected_sql_state) - finally: - client.close() + client._stub = self._stub_with([response_function]) + with self.assertRaises(SparkConnectGrpcException) as cm: + command = proto.Command() + client.execute_command(command) + err = cm.exception + self.assertEqual(err.getGrpcStatusCode(), expected_status_code) + self.assertEqual(err.getErrorClass(), expected_error_class) + self.assertEqual(err.getSqlState(), expected_sql_state) if __name__ == "__main__": From fe6f1cdb1acfcafdaac9ba0fc05db394c3cc6cab Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 21:47:20 -0800 Subject: [PATCH 15/18] Revert "fix test_client resource leak" This reverts commit 433b5377bbd51af73d062548ddb37b7bb8c051a9. 
--- .../sql/tests/connect/client/test_client.py | 165 +++++++++--------- 1 file changed, 81 insertions(+), 84 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index bf7108c94b090..189553bee75ef 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -248,90 +248,87 @@ def test_user_context_extension(self): mock = MockService(client._session_id) client._stub = mock - try: - exlocal = any_pb2.Any() - exlocal.Pack(wrappers_pb2.StringValue(value="abc")) - exlocal2 = any_pb2.Any() - exlocal2.Pack(wrappers_pb2.StringValue(value="def")) - exglobal = any_pb2.Any() - exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) - exglobal2 = any_pb2.Any() - exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) - - exlocal_id = client.add_threadlocal_user_context_extension(exlocal) - exglobal_id = client.add_global_user_context_extension(exglobal) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_threadlocal_user_context_extension(exlocal2) - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_global_user_context_extension(exglobal2) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exlocal_id) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exglobal_id) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.clear_user_context_extensions() - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - 
mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - finally: - client.close() + exlocal = any_pb2.Any() + exlocal.Pack(wrappers_pb2.StringValue(value="abc")) + exlocal2 = any_pb2.Any() + exlocal2.Pack(wrappers_pb2.StringValue(value="def")) + exglobal = any_pb2.Any() + exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) + exglobal2 = any_pb2.Any() + exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) + + exlocal_id = client.add_threadlocal_user_context_extension(exlocal) + exglobal_id = client.add_global_user_context_extension(exglobal) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_threadlocal_user_context_extension(exlocal2) + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + client.add_global_user_context_extension(exglobal2) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertTrue(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exlocal_id) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertTrue(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.remove_user_context_extension(exglobal_id) + + mock.client_user_context_extensions = [] + command = proto.Command() + client.execute_command(command) + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertTrue(exlocal2 in mock.client_user_context_extensions) + self.assertTrue(exglobal2 in mock.client_user_context_extensions) + + client.clear_user_context_extensions() + + mock.client_user_context_extensions = [] + plan = proto.Plan() + client.semantic_hash(plan) # use semantic_hash to test analyze + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in 
mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.interrupt_all() + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) + + mock.client_user_context_extensions = [] + client.get_configs("foo", "bar") + self.assertFalse(exlocal in mock.client_user_context_extensions) + self.assertFalse(exglobal in mock.client_user_context_extensions) + self.assertFalse(exlocal2 in mock.client_user_context_extensions) + self.assertFalse(exglobal2 in mock.client_user_context_extensions) def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) From 435ff6c545478d729b019aeafa7907c6b7cbfae4 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Thu, 13 Nov 2025 22:14:49 -0800 Subject: [PATCH 16/18] fix PythonPipelineSuite --- .github/workflows/build_and_test.yml | 2 +- .github/workflows/maven_test.yml | 2 +- .../spark/sql/connect/pipelines/PythonPipelineSuite.scala | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 50f6ca2cd35c5..fe71c761f5f8a 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -362,7 +362,7 @@ jobs: - name: Install Python packages (Python 3.11) if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect') || contains(matrix.modules, 'yarn') run: | - python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'googleapis-common-protos==1.71.0' 'zstandard==0.25.0' + python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'lxml==4.9.4' 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'zstandard==0.25.0' python3.11 -m pip list # Run the tests. - name: Run tests diff --git a/.github/workflows/maven_test.yml b/.github/workflows/maven_test.yml index 67d6777f993ba..27ce14afd6df4 100644 --- a/.github/workflows/maven_test.yml +++ b/.github/workflows/maven_test.yml @@ -175,7 +175,7 @@ jobs: - name: Install Python packages (Python 3.11) if: contains(matrix.modules, 'resource-managers#yarn') || (contains(matrix.modules, 'sql#core')) || contains(matrix.modules, 'connect') run: | - python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'googleapis-common-protos==1.71.0' 'zstandard==0.25.0' + python3.11 -m pip install 'numpy>=1.22' pyarrow pandas pyyaml scipy unittest-xml-reporting 'grpcio==1.76.0' 'grpcio-status==1.76.0' 'protobuf==6.33.0' 'zstandard==0.25.0' python3.11 -m pip list # Run the tests using script command. # BSD's script command doesn't support -c option, and the usage is different from Linux's one. 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index 7274e214bee92..c6c4bb0936076 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -410,7 +410,8 @@ class PythonPipelineSuite graphIdentifier("a"), graphIdentifier("b"), graphIdentifier("d"))) - assert(streamingFlows.map(_.identifier) == Seq(graphIdentifier("c"), graphIdentifier("e"))) + assert( + streamingFlows.map(_.identifier).toSet == Set(graphIdentifier("c"), graphIdentifier("e"))) } test("referencing internal datasets failed") { From 314d69edfe0832b22106cb9b6312b3e0de139f89 Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Fri, 14 Nov 2025 11:22:02 -0800 Subject: [PATCH 17/18] revert test_client changes --- .../sql/tests/connect/client/test_client.py | 101 ------------------ 1 file changed, 101 deletions(-) diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py index 189553bee75ef..c189f996cbe43 100644 --- a/python/pyspark/sql/tests/connect/client/test_client.py +++ b/python/pyspark/sql/tests/connect/client/test_client.py @@ -26,7 +26,6 @@ if should_test_connect: import grpc import google.protobuf.any_pb2 as any_pb2 - import google.protobuf.wrappers_pb2 as wrappers_pb2 from google.rpc import status_pb2 from google.rpc.error_details_pb2 import ErrorInfo import pandas as pd @@ -137,11 +136,9 @@ class MockService: def __init__(self, session_id: str): self._session_id = session_id self.req = None - self.client_user_context_extensions = [] def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions resp = proto.ExecutePlanResponse() resp.session_id = self._session_id resp.operation_id = req.operation_id @@ -162,14 +159,12 @@ def ExecutePlan(self, req: proto.ExecutePlanRequest, metadata): def Interrupt(self, req: proto.InterruptRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions resp = proto.InterruptResponse() resp.session_id = self._session_id return resp def Config(self, req: proto.ConfigRequest, metadata): self.req = req - self.client_user_context_extensions = req.user_context.extensions resp = proto.ConfigResponse() resp.session_id = self._session_id if req.operation.HasField("get"): @@ -182,15 +177,6 @@ def Config(self, req: proto.ConfigRequest, metadata): pair.value = req.operation.get_with_default.pairs[0].value or "true" return resp - def AnalyzePlan(self, req: proto.AnalyzePlanRequest, metadata): - self.req = req - self.client_user_context_extensions = req.user_context.extensions - resp = proto.AnalyzePlanResponse() - resp.session_id = self._session_id - # Return a minimal response with a semantic hash - resp.semantic_hash.result = 12345 - return resp - # The _cleanup_ml_cache invocation will hang in this test (no valid spark cluster) # and it blocks the test process exiting because it is registered as the atexit handler # in `SparkConnectClient` constructor. To bypass the issue, patch the method in the test. 
@@ -243,93 +229,6 @@ def userId(self) -> Optional[str]: self.assertEqual(client._user_id, "abc") - def test_user_context_extension(self): - client = SparkConnectClient("sc://foo/", use_reattachable_execute=False) - mock = MockService(client._session_id) - client._stub = mock - - exlocal = any_pb2.Any() - exlocal.Pack(wrappers_pb2.StringValue(value="abc")) - exlocal2 = any_pb2.Any() - exlocal2.Pack(wrappers_pb2.StringValue(value="def")) - exglobal = any_pb2.Any() - exglobal.Pack(wrappers_pb2.StringValue(value="ghi")) - exglobal2 = any_pb2.Any() - exglobal2.Pack(wrappers_pb2.StringValue(value="jkl")) - - exlocal_id = client.add_threadlocal_user_context_extension(exlocal) - exglobal_id = client.add_global_user_context_extension(exglobal) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_threadlocal_user_context_extension(exlocal2) - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - client.add_global_user_context_extension(exglobal2) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertTrue(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exlocal_id) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertTrue(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.remove_user_context_extension(exglobal_id) - - mock.client_user_context_extensions = [] - command = proto.Command() - client.execute_command(command) - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertTrue(exlocal2 in mock.client_user_context_extensions) - self.assertTrue(exglobal2 in mock.client_user_context_extensions) - - client.clear_user_context_extensions() - - mock.client_user_context_extensions = [] - plan = proto.Plan() - client.semantic_hash(plan) # use semantic_hash to test analyze - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.interrupt_all() - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - 
self.assertFalse(exglobal2 in mock.client_user_context_extensions) - - mock.client_user_context_extensions = [] - client.get_configs("foo", "bar") - self.assertFalse(exlocal in mock.client_user_context_extensions) - self.assertFalse(exglobal in mock.client_user_context_extensions) - self.assertFalse(exlocal2 in mock.client_user_context_extensions) - self.assertFalse(exglobal2 in mock.client_user_context_extensions) - def test_interrupt_all(self): client = SparkConnectClient("sc://foo/;token=bar", use_reattachable_execute=False) mock = MockService(client._session_id) From b67b486e4828bbf296c420a15539c844575e381e Mon Sep 17 00:00:00 2001 From: Yuheng Chang Date: Fri, 14 Nov 2025 11:40:15 -0800 Subject: [PATCH 18/18] fix indent --- .../pipelines/PythonPipelineSuite.scala | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala index c6c4bb0936076..1850241f07026 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/pipelines/PythonPipelineSuite.scala @@ -476,17 +476,17 @@ class PythonPipelineSuite test("reading external datasets outside query function works") { sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") val graph = buildGraph(s""" - |spark_sql_df = spark.sql("SELECT * FROM spark_catalog.default.src") - |read_table_df = spark.read.table("spark_catalog.default.src") - | - |@dp.materialized_view - |def mv_from_spark_sql_df(): - | return spark_sql_df - | - |@dp.materialized_view - |def mv_from_read_table_df(): - | return read_table_df - |""".stripMargin).resolve().validate() + |spark_sql_df = spark.sql("SELECT * FROM spark_catalog.default.src") + |read_table_df = spark.read.table("spark_catalog.default.src") + | + |@dp.materialized_view + |def mv_from_spark_sql_df(): + | return spark_sql_df + | + |@dp.materialized_view + |def mv_from_read_table_df(): + | return read_table_df + |""".stripMargin).resolve().validate() assert( graph.resolvedFlows.map(_.identifier).toSet == Set( @@ -1046,12 +1046,12 @@ class PythonPipelineSuite unsupportedSqlCommandList) { unsupportedSqlCommand => val ex = intercept[RuntimeException] { buildGraph(s""" - |spark.sql("$unsupportedSqlCommand") - | - |@dp.materialized_view() - |def mv(): - | return spark.range(5) - |""".stripMargin) + |spark.sql("$unsupportedSqlCommand") + | + |@dp.materialized_view() + |def mv(): + | return spark.range(5) + |""".stripMargin) } assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND")) } @@ -1060,11 +1060,11 @@ class PythonPipelineSuite unsupportedSqlCommandList) { unsupportedSqlCommand => val ex = intercept[RuntimeException] { buildGraph(s""" - |@dp.materialized_view() - |def mv(): - | spark.sql("$unsupportedSqlCommand") - | return spark.range(5) - |""".stripMargin) + |@dp.materialized_view() + |def mv(): + | spark.sql("$unsupportedSqlCommand") + | return spark.range(5) + |""".stripMargin) } assert(ex.getMessage.contains("UNSUPPORTED_PIPELINE_SPARK_SQL_COMMAND")) } @@ -1087,22 +1087,22 @@ class PythonPipelineSuite supportedSqlCommand => sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") buildGraph(s""" - |spark.sql("$supportedSqlCommand") - | - |@dp.materialized_view() - |def mv(): - | return spark.range(5) - 
|""".stripMargin) + |spark.sql("$supportedSqlCommand") + | + |@dp.materialized_view() + |def mv(): + | return spark.range(5) + |""".stripMargin) } gridTest("Supported SQL command inside query function should work")(supportedSqlCommandList) { supportedSqlCommand => sql("CREATE TABLE spark_catalog.default.src AS SELECT * FROM RANGE(5)") buildGraph(s""" - |@dp.materialized_view() - |def mv(): - | spark.sql("$supportedSqlCommand") - | return spark.range(5) - |""".stripMargin) + |@dp.materialized_view() + |def mv(): + | spark.sql("$supportedSqlCommand") + | return spark.range(5) + |""".stripMargin) } }