diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 1be61b7ce8fda..042cc9638cd87 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -1517,6 +1517,7 @@ def __hash__(self):
     source_file_regexes=["python/pyspark/pipelines"],
     python_test_goals=[
         "pyspark.pipelines.tests.test_block_connect_access",
+        "pyspark.pipelines.tests.test_block_session_mutations",
         "pyspark.pipelines.tests.test_cli",
         "pyspark.pipelines.tests.test_decorators",
         "pyspark.pipelines.tests.test_graph_element_registry",
diff --git a/python/pyspark/errors/error-conditions.json b/python/pyspark/errors/error-conditions.json
index 2a638bc7ec36b..46c1220dd9663 100644
--- a/python/pyspark/errors/error-conditions.json
+++ b/python/pyspark/errors/error-conditions.json
@@ -1002,6 +1002,73 @@
       "Cannot start a remote Spark session because there is a regular Spark session already running."
     ]
   },
+  "SESSION_MUTATION_IN_DECLARATIVE_PIPELINE": {
+    "message": [
+      "Session mutation is not allowed in declarative pipelines."
+    ],
+    "sub_class": {
+      "SET_RUNTIME_CONF": {
+        "message": [
+          "Instead set configuration via the pipeline spec or use the 'spark_conf' argument in various decorators."
+        ]
+      },
+      "SET_CURRENT_CATALOG": {
+        "message": [
+          "Instead set catalog via the pipeline spec or the 'name' argument on the dataset decorators."
+        ]
+      },
+      "SET_CURRENT_DATABASE": {
+        "message": [
+          "Instead set database via the pipeline spec or the 'name' argument on the dataset decorators."
+        ]
+      },
+      "DROP_TEMP_VIEW": {
+        "message": [
+          "Instead remove the temporary view definition directly."
+        ]
+      },
+      "DROP_GLOBAL_TEMP_VIEW": {
+        "message": [
+          "Instead remove the temporary view definition directly."
+        ]
+      },
+      "CREATE_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "CREATE_OR_REPLACE_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "CREATE_GLOBAL_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW": {
+        "message": [
+          "Instead use the @temporary_view decorator to define temporary views."
+        ]
+      },
+      "REGISTER_UDF": {
+        "message": [
+          ""
+        ]
+      },
+      "REGISTER_JAVA_UDF": {
+        "message": [
+          ""
+        ]
+      },
+      "REGISTER_JAVA_UDAF": {
+        "message": [
+          ""
+        ]
+      }
+    }
+  },
   "SESSION_NEED_CONN_STR_OR_BUILDER": {
     "message": [
       "Needs either connection string or channelBuilder (mutually exclusive) to create a new SparkSession."
diff --git a/python/pyspark/pipelines/block_session_mutations.py b/python/pyspark/pipelines/block_session_mutations.py
new file mode 100644
index 0000000000000..df63d2023a4ba
--- /dev/null
+++ b/python/pyspark/pipelines/block_session_mutations.py
@@ -0,0 +1,135 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from contextlib import contextmanager
+from typing import Generator, NoReturn, List, Callable
+
+from pyspark.errors import PySparkException
+from pyspark.sql.connect.catalog import Catalog
+from pyspark.sql.connect.conf import RuntimeConf
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.connect.udf import UDFRegistration
+
+# pyspark methods that should be blocked from executing in python pipeline definition files
+ERROR_CLASS = "SESSION_MUTATION_IN_DECLARATIVE_PIPELINE"
+BLOCKED_METHODS: List = [
+    {
+        "class": RuntimeConf,
+        "method": "set",
+        "error_sub_class": "SET_RUNTIME_CONF",
+    },
+    {
+        "class": Catalog,
+        "method": "setCurrentCatalog",
+        "error_sub_class": "SET_CURRENT_CATALOG",
+    },
+    {
+        "class": Catalog,
+        "method": "setCurrentDatabase",
+        "error_sub_class": "SET_CURRENT_DATABASE",
+    },
+    {
+        "class": Catalog,
+        "method": "dropTempView",
+        "error_sub_class": "DROP_TEMP_VIEW",
+    },
+    {
+        "class": Catalog,
+        "method": "dropGlobalTempView",
+        "error_sub_class": "DROP_GLOBAL_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createTempView",
+        "error_sub_class": "CREATE_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createOrReplaceTempView",
+        "error_sub_class": "CREATE_OR_REPLACE_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createGlobalTempView",
+        "error_sub_class": "CREATE_GLOBAL_TEMP_VIEW",
+    },
+    {
+        "class": DataFrame,
+        "method": "createOrReplaceGlobalTempView",
+        "error_sub_class": "CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW",
+    },
+    {
+        "class": UDFRegistration,
+        "method": "register",
+        "error_sub_class": "REGISTER_UDF",
+    },
+    {
+        "class": UDFRegistration,
+        "method": "registerJavaFunction",
+        "error_sub_class": "REGISTER_JAVA_UDF",
+    },
+    {
+        "class": UDFRegistration,
+        "method": "registerJavaUDAF",
+        "error_sub_class": "REGISTER_JAVA_UDAF",
+    },
+]
+
+
+def _create_blocked_method(error_method_name: str, error_sub_class: str) -> Callable:
+    def blocked_method(*args: object, **kwargs: object) -> NoReturn:
+        raise PySparkException(
+            errorClass=f"{ERROR_CLASS}.{error_sub_class}",
+            messageParameters={
+                "method": error_method_name,
+            },
+        )
+
+    return blocked_method
+
+
+@contextmanager
+def block_session_mutations() -> Generator[None, None, None]:
+    """
+    Context manager that blocks imperative constructs found in a pipeline python definition file
+    See BLOCKED_METHODS above for a list
+    """
+    # Store original methods
+    original_methods = {}
+    for method_info in BLOCKED_METHODS:
+        cls = method_info["class"]
+        method_name = method_info["method"]
+        original_methods[(cls, method_name)] = getattr(cls, method_name)
+
+    try:
+        # Replace methods with blocked versions
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            error_method_name = f"'{cls.__name__}.{method_name}'"
+            blocked_method = _create_blocked_method(
+                error_method_name, method_info["error_sub_class"]
+            )
+            setattr(cls, method_name, blocked_method)
+
+        yield
+    finally:
+        # Restore original methods
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            original_method = original_methods[(cls, method_name)]
+            setattr(cls, method_name, original_method)
diff --git a/python/pyspark/pipelines/cli.py b/python/pyspark/pipelines/cli.py
index 2a0cf880d10c6..f739e055f48ac 100644
--- a/python/pyspark/pipelines/cli.py
+++ b/python/pyspark/pipelines/cli.py
@@ -32,6 +32,7 @@
 
 from pyspark.errors import PySparkException, PySparkTypeError
 from pyspark.sql import SparkSession
+from pyspark.pipelines.block_session_mutations import block_session_mutations
 from pyspark.pipelines.graph_element_registry import (
     graph_element_registration_context,
     GraphElementRegistry,
@@ -192,7 +193,8 @@ def register_definitions(
                     assert (
                         module_spec.loader is not None
                     ), f"Module spec has no loader for {file}"
-                    module_spec.loader.exec_module(module)
+                    with block_session_mutations():
+                        module_spec.loader.exec_module(module)
                 elif file.suffix == ".sql":
                     log_with_curr_timestamp(f"Registering SQL file {file}...")
                     with file.open("r") as f:
diff --git a/python/pyspark/pipelines/tests/test_block_session_mutations.py b/python/pyspark/pipelines/tests/test_block_session_mutations.py
new file mode 100644
index 0000000000000..771321d73832b
--- /dev/null
+++ b/python/pyspark/pipelines/tests/test_block_session_mutations.py
@@ -0,0 +1,259 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from pyspark.errors import PySparkException
+from pyspark.sql.types import StringType
+from pyspark.testing.connectutils import (
+    ReusedConnectTestCase,
+    should_test_connect,
+    connect_requirement_message,
+)
+
+from pyspark.pipelines.block_session_mutations import (
+    block_session_mutations,
+    BLOCKED_METHODS,
+    ERROR_CLASS,
+)
+
+
+@unittest.skipIf(not should_test_connect, connect_requirement_message or "Connect not available")
+class BlockImperativeConfSetConnectTests(ReusedConnectTestCase):
+    def test_blocks_runtime_conf_set(self):
+        """Test that spark.conf.set() is blocked."""
+        config = self.spark.conf
+
+        test_cases = [
+            ("spark.test.string", "string_value"),
+            ("spark.test.int", 42),
+            ("spark.test.bool", True),
+        ]
+
+        for key, value in test_cases:
+            with self.subTest(key=key, value=value):
+                with block_session_mutations():
+                    with self.assertRaises(PySparkException) as context:
+                        config.set(key, value)
+
+                self.assertEqual(
+                    context.exception.getCondition(),
+                    f"{ERROR_CLASS}.SET_RUNTIME_CONF",
+                )
+                self.assertIn("'RuntimeConf.set'", str(context.exception))
+
+    def test_blocks_catalog_set_current_catalog(self):
+        """Test that spark.catalog.setCurrentCatalog() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.setCurrentCatalog("test_catalog")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.SET_CURRENT_CATALOG",
+        )
+        self.assertIn("'Catalog.setCurrentCatalog'", str(context.exception))
+
+    def test_blocks_catalog_set_current_database(self):
+        """Test that spark.catalog.setCurrentDatabase() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.setCurrentDatabase("test_db")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.SET_CURRENT_DATABASE",
+        )
+        self.assertIn("'Catalog.setCurrentDatabase'", str(context.exception))
+
+    def test_blocks_catalog_drop_temp_view(self):
+        """Test that spark.catalog.dropTempView() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.dropTempView("test_view")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.DROP_TEMP_VIEW",
+        )
+        self.assertIn("'Catalog.dropTempView'", str(context.exception))
+
+    def test_blocks_catalog_drop_global_temp_view(self):
+        """Test that spark.catalog.dropGlobalTempView() is blocked."""
+        catalog = self.spark.catalog
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                catalog.dropGlobalTempView("test_view")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.DROP_GLOBAL_TEMP_VIEW",
+        )
+        self.assertIn("'Catalog.dropGlobalTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_temp_view(self):
+        """Test that DataFrame.createTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createTempView("test_view")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.CREATE_TEMP_VIEW",
+        )
+        self.assertIn("'DataFrame.createTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_or_replace_temp_view(self):
+        """Test that DataFrame.createOrReplaceTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createOrReplaceTempView("test_view")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.CREATE_OR_REPLACE_TEMP_VIEW",
+        )
+        self.assertIn("'DataFrame.createOrReplaceTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_global_temp_view(self):
+        """Test that DataFrame.createGlobalTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createGlobalTempView("test_view")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.CREATE_GLOBAL_TEMP_VIEW",
+        )
+        self.assertIn("'DataFrame.createGlobalTempView'", str(context.exception))
+
+    def test_blocks_dataframe_create_or_replace_global_temp_view(self):
+        """Test that DataFrame.createOrReplaceGlobalTempView() is blocked."""
+        df = self.spark.range(1)
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                df.createOrReplaceGlobalTempView("test_view")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.CREATE_OR_REPLACE_GLOBAL_TEMP_VIEW",
+        )
+        self.assertIn("'DataFrame.createOrReplaceGlobalTempView'", str(context.exception))
+
+    def test_blocks_udf_register(self):
+        """Test that spark.udf.register() is blocked."""
+        udf_registry = self.spark.udf
+
+        def test_func(x):
+            return x + 1
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                udf_registry.register("test_udf", test_func, StringType())
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.REGISTER_UDF",
+        )
+        self.assertIn("'UDFRegistration.register'", str(context.exception))
+
+    def test_blocks_udf_register_java_function(self):
+        """Test that spark.udf.registerJavaFunction() is blocked."""
+        udf_registry = self.spark.udf
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                udf_registry.registerJavaFunction(
+                    "test_java_udf", "com.example.TestUDF", StringType()
+                )
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.REGISTER_JAVA_UDF",
+        )
+        self.assertIn("'UDFRegistration.registerJavaFunction'", str(context.exception))
+
+    def test_blocks_udf_register_java_udaf(self):
+        """Test that spark.udf.registerJavaUDAF() is blocked."""
+        udf_registry = self.spark.udf
+
+        with block_session_mutations():
+            with self.assertRaises(PySparkException) as context:
+                udf_registry.registerJavaUDAF("test_java_udaf", "com.example.TestUDAF")
+
+        self.assertEqual(
+            context.exception.getCondition(),
+            f"{ERROR_CLASS}.REGISTER_JAVA_UDAF",
+        )
+        self.assertIn("'UDFRegistration.registerJavaUDAF'", str(context.exception))
+
+    def test_restores_original_methods_after_context(self):
+        """Test that all methods are properly restored after context manager exits."""
+        # Store original methods
+        original_methods = {}
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            original_methods[(cls, method_name)] = getattr(cls, method_name)
+
+        # Verify methods are originally set correctly
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            with self.subTest(class_method=f"{cls.__name__}.{method_name}"):
+                self.assertIs(getattr(cls, method_name), original_methods[(cls, method_name)])
+
+        # Verify methods are replaced during context
+        with block_session_mutations():
+            for method_info in BLOCKED_METHODS:
+                cls = method_info["class"]
+                method_name = method_info["method"]
+                with self.subTest(class_method=f"{cls.__name__}.{method_name}"):
+                    self.assertIsNot(
+                        getattr(cls, method_name), original_methods[(cls, method_name)]
+                    )
+
+        # Verify methods are restored after context
+        for method_info in BLOCKED_METHODS:
+            cls = method_info["class"]
+            method_name = method_info["method"]
+            with self.subTest(class_method=f"{cls.__name__}.{method_name}"):
+                self.assertIs(getattr(cls, method_name), original_methods[(cls, method_name)])
+
+
+if __name__ == "__main__":
+    try:
+        import xmlrunner  # type: ignore
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)