diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 0fc32b7ba9d3e..0ff8400dc37c8 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5002,6 +5002,12 @@
     ],
     "sqlState" : "42000"
   },
+  "PIPELINE_STORAGE_ROOT_INVALID" : {
+    "message" : [
+      "Pipeline storage root must be an absolute path with a URI scheme (e.g., file://, s3a://, hdfs://). Got: `<storage_root>`."
+    ],
+    "sqlState" : "42K03"
+  },
   "PIPE_OPERATOR_AGGREGATE_EXPRESSION_CONTAINS_NO_AGGREGATE_FUNCTION" : {
     "message" : [
       "Non-grouping expression <expr> is provided as an argument to the |> AGGREGATE pipe operator but does not contain any aggregate function; please update it to include an aggregate function and then retry the query again."
diff --git a/python/pyspark/pipelines/init_cli.py b/python/pyspark/pipelines/init_cli.py
index ffe5d3c12b636..f8149b19263f8 100644
--- a/python/pyspark/pipelines/init_cli.py
+++ b/python/pyspark/pipelines/init_cli.py
@@ -19,7 +19,7 @@
 
 SPEC = """
 name: {{ name }}
-storage: storage-root
+storage: {{ storage_root }}
 libraries:
   - glob:
       include: transformations/**
@@ -46,10 +46,18 @@ def init(name: str) -> None:
     project_dir = Path.cwd() / name
     project_dir.mkdir(parents=True, exist_ok=False)
 
+    # Create the storage directory
+    storage_dir = project_dir / "pipeline-storage"
+    storage_dir.mkdir(parents=True)
+
+    # Create absolute file URI for storage path
+    storage_path = f"file://{storage_dir.resolve()}"
+
     # Write the spec file to the project directory
     spec_file = project_dir / "pipeline.yml"
     with open(spec_file, "w") as f:
-        f.write(SPEC.replace("{{ name }}", name))
+        spec_content = SPEC.replace("{{ name }}", name).replace("{{ storage_root }}", storage_path)
+        f.write(spec_content)
 
     # Create the transformations directory
     transformations_dir = project_dir / "transformations"
diff --git a/python/pyspark/pipelines/tests/test_init_cli.py b/python/pyspark/pipelines/tests/test_init_cli.py
index 49c949200821a..43c553eddc387 100644
--- a/python/pyspark/pipelines/tests/test_init_cli.py
+++ b/python/pyspark/pipelines/tests/test_init_cli.py
@@ -51,6 +51,14 @@ def test_init(self):
         spec_path = find_pipeline_spec(Path.cwd())
         spec = load_pipeline_spec(spec_path)
         assert spec.name == project_name
+
+        # Verify that the storage path is an absolute URI with file scheme
+        expected_storage_path = f"file://{Path.cwd() / 'pipeline-storage'}"
+        self.assertEqual(spec.storage, expected_storage_path)
+
+        # Verify that the storage directory was created
+        self.assertTrue((Path.cwd() / "pipeline-storage").exists())
+
         registry = LocalGraphElementRegistry()
         register_definitions(spec_path, registry, spec)
         self.assertEqual(len(registry.outputs), 1)
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 680755afdca21..1b747705e9ad7 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -437,7 +437,7 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
     val pipelineUpdateContext = new PipelineUpdateContextImpl(
       new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
       (_: PipelineEvent) => None,
-      storageRoot = "test_storage_root")
+      storageRoot = "file:///test_storage_root")
     sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
     assert(
       sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
index a3d851c1ce7b9..94deb83f6ad43 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionManagerSuite.scala
@@ -161,7 +161,7 @@ class SparkConnectSessionManagerSuite extends SharedSparkSession with BeforeAndA
     val pipelineUpdateContext = new PipelineUpdateContextImpl(
       new DataflowGraph(Seq(), Seq(), Seq(), Seq()),
       (_: PipelineEvent) => None,
-      storageRoot = "test_storage_root")
+      storageRoot = "file:///test_storage_root")
     sessionHolder.cachePipelineExecution(graphId, pipelineUpdateContext)
     assert(
       sessionHolder.getPipelineExecution(graphId).nonEmpty,
diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
index bb2009b259124..7dfffbd9d662c 100644
--- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
+++ b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImpl.scala
@@ -17,6 +17,9 @@
 
 package org.apache.spark.sql.pipelines.graph
 
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
 import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.pipelines.logging.{FlowProgressEventLogger, PipelineEvent}
 
@@ -36,6 +39,8 @@ class PipelineUpdateContextImpl(
     override val storageRoot: String
 ) extends PipelineUpdateContext {
 
+  PipelineUpdateContextImpl.validateStorageRoot(storageRoot)
+
   override val spark: SparkSession = SparkSession.getActiveSession.getOrElse(
     throw new IllegalStateException("SparkSession is not available")
   )
@@ -45,3 +50,19 @@ class PipelineUpdateContextImpl(
 
   override val resetCheckpointFlows: FlowFilter = NoFlows
 }
+
+object PipelineUpdateContextImpl {
+  def validateStorageRoot(storageRoot: String): Unit = {
+    // Use the same validation logic as streaming checkpoint directories
+    val path = new Path(storageRoot)
+
+    val uri = path.toUri
+    if (!path.isAbsolute || uri.getScheme == null || uri.getScheme.isEmpty) {
+      throw new SparkException(
+        errorClass = "PIPELINE_STORAGE_ROOT_INVALID",
+        messageParameters = Map("storage_root" -> storageRoot),
+        cause = null
+      )
+    }
+  }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImplSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImplSuite.scala
new file mode 100644
index 0000000000000..b22b6a9967e04
--- /dev/null
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/PipelineUpdateContextImplSuite.scala
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.pipelines.graph
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class PipelineUpdateContextImplSuite extends PipelineTest with SharedSparkSession {
+
+  test("validateStorageRoot should accept valid URIs with schemes") {
+    val validStorageRoots = Seq(
+      "file:///tmp/test",
+      "hdfs://localhost:9000/pipelines",
+      "s3a://my-bucket/pipelines",
+      "abfss://container@account.dfs.core.windows.net/pipelines"
+    )
+
+    validStorageRoots.foreach(PipelineUpdateContextImpl.validateStorageRoot)
+  }
+
+  test("validateStorageRoot should reject relative paths") {
+    val invalidStorageRoots = Seq(
+      "relative/path",
+      "./relative/path",
+      "../relative/path",
+      "pipelines"
+    )
+
+    invalidStorageRoots.foreach { storageRoot =>
+      val exception = intercept[SparkException] {
+        PipelineUpdateContextImpl.validateStorageRoot(storageRoot)
+      }
+      assert(exception.getCondition == "PIPELINE_STORAGE_ROOT_INVALID")
+      assert(exception.getMessageParameters.get("storage_root") == storageRoot)
+    }
+  }
+
+  test("validateStorageRoot should reject absolute paths without URI scheme") {
+    val invalidStorageRoots = Seq(
+      "/tmp/test",
+      "/absolute/path",
+      "/pipelines/storage"
+    )
+
+    invalidStorageRoots.foreach { storageRoot =>
+      val exception = intercept[SparkException] {
+        PipelineUpdateContextImpl.validateStorageRoot(storageRoot)
+      }
+      assert(exception.getCondition == "PIPELINE_STORAGE_ROOT_INVALID")
+      assert(exception.getMessageParameters.get("storage_root") == storageRoot)
+    }
+  }
+
+  test("PipelineUpdateContextImpl constructor should validate storage root") {
+    val session = spark
+    import session.implicits._
+
+    class TestPipeline extends TestGraphRegistrationContext(spark) {
+      registerPersistedView("test", query = dfFlowFunc(Seq(1).toDF("value")))
+    }
+    val graph = new TestPipeline().resolveToDataflowGraph()
+
+    val validStorageRoot = "file:///tmp/test"
+    val context = new PipelineUpdateContextImpl(
+      unresolvedGraph = graph,
+      eventCallback = _ => {},
+      storageRoot = validStorageRoot
+    )
+    assert(context.storageRoot == validStorageRoot)
+
+    val invalidStorageRoot = "/tmp/test"
+    val exception = intercept[SparkException] {
+      new PipelineUpdateContextImpl(
+        unresolvedGraph = graph,
+        eventCallback = _ => {},
+        storageRoot = invalidStorageRoot
+      )
+    }
+    assert(exception.getCondition == "PIPELINE_STORAGE_ROOT_INVALID")
+    assert(exception.getMessageParameters.get("storage_root") == invalidStorageRoot)
+  }
+}
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala
index 958ef5a80fd57..9e6010a611e97 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/SinkExecutionSuite.scala
@@ -115,7 +115,7 @@ class SinkExecutionSuite extends ExecutionTest with SharedSparkSession {
       sinkIdentifier: TableIdentifier,
       flowIdentifier: TableIdentifier): Unit = {
     val expectedCheckpointLocation = new Path(
-      "file://" + rootDirectory + s"/_checkpoints/${sinkIdentifier.table}/${flowIdentifier.table}/0"
+      rootDirectory + s"/_checkpoints/${sinkIdentifier.table}/${flowIdentifier.table}/0"
     )
     val streamingQuery = graphExecution
       .flowExecutions(flowIdentifier)
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
index 420e2c6ad0e91..7c998c6df3d53 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/StorageRootMixin.scala
@@ -38,7 +38,7 @@ trait StorageRootMixin extends BeforeAndAfterEach { self: Suite =>
   override protected def beforeEach(): Unit = {
     super.beforeEach()
     storageRoot =
-      Files.createTempDirectory(getClass.getSimpleName).normalize.toString
+      s"file://${Files.createTempDirectory(getClass.getSimpleName).normalize.toString}"
   }
 
   override protected def afterEach(): Unit = {