diff --git a/core/src/main/scala/org/apache/spark/internal/io/BatchFileNamingProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/BatchFileNamingProtocol.scala new file mode 100644 index 000000000000..f90e4315bcf9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/BatchFileNamingProtocol.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +/** + * A [[FileNamingProtocol]] implementation to write output data in batch processing. + */ +class BatchFileNamingProtocol(jobId: String) extends FileNamingProtocol with Serializable { + + override def getTaskTempPath( + taskContext: TaskAttemptContext, fileContext: FileContext): String = { + // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. + val split = taskContext.getTaskAttemptID.getTaskID.getId + val prefix = fileContext.prefix.getOrElse("") + val ext = fileContext.ext + val filename = f"${prefix}part-$split%05d-$jobId$ext" + + fileContext.relativeDir.map { + d => new Path(d, filename).toString + }.getOrElse(filename) + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index d9d7b06cdb8c..45a72e48e39d 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -73,35 +73,31 @@ abstract class FileCommitProtocol extends Logging { * Notifies the commit protocol to add a new file, and gets back the full path that should be * used. Must be called on the executors when running tasks. * - * Note that the returned temp file may have an arbitrary path. The commit protocol only - * promises that the file will be at the location specified by the arguments after job commit. + * Note that the "relativePath" parameter specifies the relative path of the returned temp file. + * The full path is left to the commit protocol to decide. The commit protocol only promises that + * the file will be at the location specified by the relative path after the job commits. - * - * A full file path consists of the following parts: - * 1. the base path - * 2. some sub-directory within the base path, used to specify partitioning - * 3. file prefix, usually some unique job id with the task id - * 4. bucket id - * 5. source specific file extension, e.g.
".snappy.parquet" - * - * The "dir" parameter specifies 2, and "ext" parameter specifies both 4 and 5, and the rest - * are left to the commit protocol implementation to decide. - * - * Important: it is the caller's responsibility to add uniquely identifying content to "ext" - * if a task is going to write out multiple files to the same dir. The file commit protocol only - * guarantees that files written by different tasks will not conflict. + * Important: it is the caller's responsibility to add uniquely identifying content to + * "relativePath" if a task is going to write out multiple files to the same directory. The file + * commit protocol only guarantees that files written by different tasks will not conflict. */ - def newTaskTempFile(taskContext: TaskAttemptContext, dir: Option[String], ext: String): String + def newTaskTempFile(taskContext: TaskAttemptContext, relativePath: String): String /** * Similar to newTaskTempFile(), but allows files to committed to an absolute output location. * Depending on the implementation, there may be weaker guarantees around adding files this way. * - * Important: it is the caller's responsibility to add uniquely identifying content to "ext" - * if a task is going to write out multiple files to the same dir. The file commit protocol only - * guarantees that files written by different tasks will not conflict. + * "relativePath" parameter specifies the relative path of returned temp file, and "finalPath" + * parameter specifies the full path of file after job commit. The commit protocol promises that + * the file will be at the location specified by the "finalPath" after job commits. + * + * Important: it is the caller's responsibility to add uniquely identifying content to + * "relativePath" and "finalPath" if a task is going to write out multiple files to the same + * directory. The file commit protocol only guarantees that files written by different tasks will + * not conflict. */ def newTaskTempFileAbsPath( - taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String + taskContext: TaskAttemptContext, relativePath: String, finalPath: String): String /** * Commits a task after the writes succeed. Must be called on the executors when running tasks. diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileNamingProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileNamingProtocol.scala new file mode 100644 index 000000000000..404e44ce7eb4 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileNamingProtocol.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.internal.io + +import org.apache.hadoop.mapreduce.TaskAttemptContext + +/** + * An interface to define how a single Spark job names its outputs. Two notes: + * + * 1. Implementations must be serializable, as the instance instantiated on the driver + * will be used for tasks on executors. + * 2. An instance should not be reused across multiple Spark jobs. + * + * The proper way to use it is: + * + * As part of each task's execution, whenever a new output file needs to be created, the executor + * calls [[getTaskTempPath]] to get a valid relative file path before the commit. + */ +abstract class FileNamingProtocol { + + /** + * Gets the relative path that should be used for the output file. + * + * Important: it is the caller's responsibility to add uniquely identifying content to + * "fileContext" if a task is going to write out multiple files to the same directory. The file + * naming protocol only guarantees that files written by different tasks will not conflict. + */ + def getTaskTempPath(taskContext: TaskAttemptContext, fileContext: FileContext): String +} + +/** + * The context for a Spark output file, used by [[FileNamingProtocol]] to create the file path. + * + * @param ext Source specific file extension, e.g. ".snappy.parquet". + * @param relativeDir Relative directory of the file. Can be used for writing dynamic partitions. + * E.g., "a=1/b=2" is the directory for partition (a=1, b=2). + * @param prefix File prefix. + */ +final case class FileContext( + ext: String, + relativeDir: Option[String], + prefix: Option[String]) diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index c061d617fce4..ea2342a3fb4c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -18,7 +18,7 @@ package org.apache.spark.internal.io import java.io.IOException -import java.util.{Date, UUID} +import java.util.Date import scala.collection.mutable import scala.util.Try @@ -104,7 +104,7 @@ class HadoopMapReduceCommitProtocol( * The staging directory of this write job. Spark uses it to deal with files with absolute output * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. */ - protected def stagingDir = getStagingDir(path, jobId) + protected def stagingDir: Path = getStagingDir(path, jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.getConstructor().newInstance() @@ -116,50 +116,30 @@ class HadoopMapReduceCommitProtocol( format.getOutputCommitter(context) } - override def newTaskTempFile( - taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - val filename = getFilename(taskContext, ext) - + override def newTaskTempFile(taskContext: TaskAttemptContext, relativePath: String): String = { val stagingDir: Path = committer match { // For FileOutputCommitter it has its own staging path called "work path".
case f: FileOutputCommitter => if (dynamicPartitionOverwrite) { - assert(dir.isDefined, + val dir = new Path(relativePath).getParent.toString + assert(dir.nonEmpty, "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") - partitionPaths += dir.get + partitionPaths += dir } new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) case _ => new Path(path) } - dir.map { d => - new Path(new Path(stagingDir, d), filename).toString - }.getOrElse { - new Path(stagingDir, filename).toString - } + new Path(stagingDir, relativePath).toString } override def newTaskTempFileAbsPath( - taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { - val filename = getFilename(taskContext, ext) - val absOutputPath = new Path(absoluteDir, filename).toString - - // Include a UUID here to prevent file collisions for one task writing to different dirs. - // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString - - addedAbsPathFiles(tmpOutputPath) = absOutputPath + taskContext: TaskAttemptContext, relativePath: String, finalPath: String): String = { + val tmpOutputPath = new Path(stagingDir, relativePath).toString + addedAbsPathFiles(tmpOutputPath) = finalPath tmpOutputPath } - protected def getFilename(taskContext: TaskAttemptContext, ext: String): String = { - // The file name looks like part-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003-c000.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. - val split = taskContext.getTaskAttemptID.getTaskID.getId - f"part-$split%05d-$jobId$ext" - } - override def setupJob(jobContext: JobContext): Unit = { // Setup IDs val jobId = SparkHadoopWriterUtils.createJobID(new Date, 0) diff --git a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala index 2ca50878485c..1529d6f9dcd9 100644 --- a/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala +++ b/hadoop-cloud/src/hadoop-3/main/scala/org/apache/spark/internal/io/cloud/PathOutputCommitProtocol.scala @@ -121,24 +121,14 @@ class PathOutputCommitProtocol( * Create a temporary file for a task. 
* * @param taskContext task context - * @param dir optional subdirectory - * @param ext file extension - * @return a path as a string + * @param relativePath relative path as a string for file + * @return the full path as a string for file */ - override def newTaskTempFile( - taskContext: TaskAttemptContext, - dir: Option[String], - ext: String): String = { - - val workDir = committer.getWorkPath - val parent = dir.map { - d => new Path(workDir, d) - }.getOrElse(workDir) - val file = new Path(parent, getFilename(taskContext, ext)) - logTrace(s"Creating task file $file for dir $dir and ext $ext") + override def newTaskTempFile(taskContext: TaskAttemptContext, relativePath: String): String = { + val file = new Path(committer.getWorkPath, relativePath) + logTrace(s"Creating task file $file with relative path $relativePath") file.toString } - } object PathOutputCommitProtocol { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala index c7da75883f80..d2f90ab6e6a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryExecutionErrors.scala @@ -1405,9 +1405,9 @@ object QueryExecutionErrors { s"Multiple streaming queries are concurrently using $path", e) } - def addFilesWithAbsolutePathUnsupportedError(commitProtocol: String): Throwable = { + def addFilesWithAbsolutePathUnsupportedError(protocol: String): Throwable = { new UnsupportedOperationException( - s"$commitProtocol does not support adding files with an absolute path") + s"$protocol does not support adding files with an absolute path") } def microBatchUnsupportedByDataSourceError(srcName: String): Throwable = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 7e5a8cce2783..afa8e56f18eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -16,13 +16,15 @@ */ package org.apache.spark.sql.execution.datasources +import java.util.UUID + import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.{FileCommitProtocol, FileContext, FileNamingProtocol} import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -122,7 +124,8 @@ class EmptyDirectoryDataWriter( class SingleDirectoryDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) + committer: FileCommitProtocol, + namingProtocol: FileNamingProtocol) extends FileFormatDataWriter(description, taskAttemptContext, committer) { private var fileCounter: Int = _ private var recordsInFile: Long = _ @@ -133,11 +136,11 @@ class SingleDirectoryDataWriter( recordsInFile = 0 releaseResources() - val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) + val ext = f"-c$fileCounter%03d" + + 
description.outputWriterFactory.getFileExtension(taskAttemptContext) val currentPath = committer.newTaskTempFile( taskAttemptContext, - None, - f"-c$fileCounter%03d" + ext) + namingProtocol.getTaskTempPath(taskAttemptContext, FileContext(ext, None, None))) currentWriter = description.outputWriterFactory.newInstance( path = currentPath, @@ -169,7 +172,8 @@ class SingleDirectoryDataWriter( abstract class BaseDynamicPartitionDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) + committer: FileCommitProtocol, + namingProtocol: FileNamingProtocol) extends FileFormatDataWriter(description, taskAttemptContext, committer) { /** Flag saying whether or not the data to be written out is partitioned. */ @@ -262,9 +266,15 @@ abstract class BaseDynamicPartitionDataWriter( description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) } val currentPath = if (customPath.isDefined) { - committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) + // Include a UUID here to prevent file collisions for one task writing to different dirs. + val relativePath = UUID.randomUUID().toString + "-" + + namingProtocol.getTaskTempPath(taskAttemptContext, FileContext(ext, None, None)) + val finalPath = new Path(customPath.get, relativePath).toString + committer.newTaskTempFileAbsPath(taskAttemptContext, relativePath, finalPath) } else { - committer.newTaskTempFile(taskAttemptContext, partDir, ext) + val relativePath = namingProtocol.getTaskTempPath( + taskAttemptContext, FileContext(ext, partDir, None)) + committer.newTaskTempFile(taskAttemptContext, relativePath) } currentWriter = description.outputWriterFactory.newInstance( @@ -314,8 +324,10 @@ abstract class BaseDynamicPartitionDataWriter( class DynamicPartitionDataSingleWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) - extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + committer: FileCommitProtocol, + namingProtocol: FileNamingProtocol) + extends BaseDynamicPartitionDataWriter( + description, taskAttemptContext, committer, namingProtocol) { private var currentPartitionValues: Option[UnsafeRow] = None private var currentBucketId: Option[Int] = None @@ -361,8 +373,10 @@ class DynamicPartitionDataConcurrentWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol, + namingProtocol: FileNamingProtocol, concurrentOutputWriterSpec: ConcurrentOutputWriterSpec) - extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) + extends BaseDynamicPartitionDataWriter( + description, taskAttemptContext, committer, namingProtocol) with Logging { /** Wrapper class to index a unique concurrent output writer. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 6839a4db0bc2..b89eef4cb3d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -27,7 +27,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.{FileCommitProtocol, FileNamingProtocol, SparkHadoopWriterUtils} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow @@ -96,7 +96,7 @@ object FileFormatWriter extends Logging { sparkSession: SparkSession, plan: SparkPlan, fileFormat: FileFormat, - committer: FileCommitProtocol, + protocols: (FileCommitProtocol, FileNamingProtocol), outputSpec: OutputSpec, hadoopConf: Configuration, partitionColumns: Seq[Attribute], @@ -105,6 +105,7 @@ object FileFormatWriter extends Logging { options: Map[String, String]) : Set[String] = { + val committer = protocols._1 val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) @@ -225,6 +226,7 @@ object FileFormatWriter extends Logging { sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE, committer, + protocols._2, iterator = iter, concurrentOutputWriterSpec = concurrentOutputWriterSpec) }, @@ -260,6 +262,7 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, + namingProtocol: FileNamingProtocol, iterator: Iterator[InternalRow], concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]): WriteTaskResult = { @@ -287,14 +290,15 @@ object FileFormatWriter extends Logging { // In case of empty job, leave first partition to save meta for file format like parquet. 
new EmptyDirectoryDataWriter(description, taskAttemptContext, committer) } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { - new SingleDirectoryDataWriter(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer, namingProtocol) } else { concurrentOutputWriterSpec match { case Some(spec) => new DynamicPartitionDataConcurrentWriter( - description, taskAttemptContext, committer, spec) + description, taskAttemptContext, committer, namingProtocol, spec) case _ => - new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + new DynamicPartitionDataSingleWriter( + description, taskAttemptContext, committer, namingProtocol) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 267b360b474c..a226cab02738 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.{BatchFileNamingProtocol, FileCommitProtocol} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -110,6 +110,7 @@ case class InsertIntoHadoopFsRelationCommand( jobId = jobId, outputPath = outputPath.toString, dynamicPartitionOverwrite = dynamicPartitionOverwrite) + val namingProtocol = new BatchFileNamingProtocol(jobId) val doInsertion = if (mode == SaveMode.Append) { true @@ -176,7 +177,7 @@ case class InsertIntoHadoopFsRelationCommand( sparkSession = sparkSession, plan = child, fileFormat = fileFormat, - committer = committer, + protocols = (committer, namingProtocol), outputSpec = FileFormatWriter.OutputSpec( committerOutputPath.toString, customPartitionLocations, outputColumns), hadoopConf = hadoopConf, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala index 7227e48bc9a1..847dc2b58374 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileBatchWrite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.hadoop.mapreduce.Job import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.{FileCommitProtocol, FileNamingProtocol} import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage} import org.apache.spark.sql.execution.datasources.{WriteJobDescription, WriteTaskResult} import org.apache.spark.sql.execution.datasources.FileFormatWriter.processStats @@ -28,7 +28,8 @@ import org.apache.spark.util.Utils class FileBatchWrite( job: Job, description: WriteJobDescription, - committer: FileCommitProtocol) + committer: FileCommitProtocol, + namingProtocol: 
FileNamingProtocol) extends BatchWrite with Logging { override def commit(messages: Array[WriterCommitMessage]): Unit = { val results = messages.map(_.asInstanceOf[WriteTaskResult]) @@ -47,7 +48,7 @@ class FileBatchWrite( } override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = { - FileWriterFactory(description, committer) + FileWriterFactory(description, committer, namingProtocol) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala index 4f736cbd8970..2d01643b98a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWrite.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.{BatchFileNamingProtocol, FileCommitProtocol} import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} @@ -57,15 +57,17 @@ trait FileWrite extends Write { // Hadoop Configurations are case sensitive. val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) val job = getJobInstance(hadoopConf, path) + val jobId = java.util.UUID.randomUUID().toString val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, - jobId = java.util.UUID.randomUUID().toString, + jobId = jobId, outputPath = paths.head) + val namingProtocol = new BatchFileNamingProtocol(jobId) lazy val description = createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, options.asScala.toMap) committer.setupJob(job) - new FileBatchWrite(job, description, committer) + new FileBatchWrite(job, description, committer, namingProtocol) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index d827e8362357..8cbbae894a70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -21,21 +21,23 @@ import java.util.Date import org.apache.hadoop.mapreduce.{TaskAttemptID, TaskID, TaskType} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl -import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} +import org.apache.spark.internal.io.{FileCommitProtocol, FileNamingProtocol, SparkHadoopWriterUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataSingleWriter, SingleDirectoryDataWriter, WriteJobDescription} case class FileWriterFactory ( description: WriteJobDescription, - committer: FileCommitProtocol) extends DataWriterFactory { + committer: FileCommitProtocol, + namingProtocol: FileNamingProtocol) extends DataWriterFactory { override def createWriter(partitionId: Int, realTaskId: Long): DataWriter[InternalRow] = { val taskAttemptContext = createTaskAttemptContext(partitionId) 
committer.setupTask(taskAttemptContext) if (description.partitionColumns.isEmpty) { - new SingleDirectoryDataWriter(description, taskAttemptContext, committer) + new SingleDirectoryDataWriter(description, taskAttemptContext, committer, namingProtocol) } else { - new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + new DynamicPartitionDataSingleWriter( + description, taskAttemptContext, committer, namingProtocol) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 0066678a5d96..cd4d5e3180b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -152,6 +152,7 @@ class FileStreamSink( manifestCommitter.setupManifestOptions(fileLog, batchId) case _ => // Do nothing } + val namingProtocol = new StreamingFileNamingProtocol(batchId.toString) // Get the actual partition columns as attributes after matching them by name with // the given columns names. @@ -167,7 +168,7 @@ class FileStreamSink( sparkSession = sparkSession, plan = qe.executedPlan, fileFormat = fileFormat, - committer = committer, + protocols = (committer, namingProtocol), outputSpec = FileFormatWriter.OutputSpec(path, Map.empty, qe.analyzed.output), hadoopConf = hadoopConf, partitionColumns = partitionColumns, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala index 46ce33687890..f79b952db4de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ManifestFileCommitProtocol.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution.streaming import java.io.IOException -import java.util.UUID import scala.collection.mutable.ArrayBuffer @@ -111,27 +110,14 @@ class ManifestFileCommitProtocol(jobId: String, path: String) addedFiles = new ArrayBuffer[String] } - override def newTaskTempFile( - taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { - // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet - // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, - // the file name is fine and won't overflow. 
- val split = taskContext.getTaskAttemptID.getTaskID.getId - val uuid = UUID.randomUUID.toString - val filename = f"part-$split%05d-$uuid$ext" - - val file = dir.map { d => - new Path(new Path(path, d), filename).toString - }.getOrElse { - new Path(path, filename).toString - } - + override def newTaskTempFile(taskContext: TaskAttemptContext, relativePath: String): String = { + val file = new Path(path, relativePath).toString addedFiles += file file } override def newTaskTempFileAbsPath( - taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + taskContext: TaskAttemptContext, relativePath: String, finalPath: String): String = { throw QueryExecutionErrors.addFilesWithAbsolutePathUnsupportedError(this.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingFileNamingProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingFileNamingProtocol.scala new file mode 100644 index 000000000000..abc8f25b2e19 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingFileNamingProtocol.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.util.UUID + +import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapreduce.TaskAttemptContext + +import org.apache.spark.internal.io.{FileContext, FileNamingProtocol} + +/** + * A [[FileNamingProtocol]] implementation to write output data in streaming processing. + */ +class StreamingFileNamingProtocol(jobId: String) extends FileNamingProtocol with Serializable { + + override def getTaskTempPath( + taskContext: TaskAttemptContext, fileContext: FileContext): String = { + // The file name looks like part-r-00000-2dd664f9-d2c4-4ffe-878f-c6c70c1fb0cb_00003.gz.parquet + // Note that %05d does not truncate the split number, so if we have more than 100000 tasks, + // the file name is fine and won't overflow. 
+ val split = taskContext.getTaskAttemptID.getTaskID.getId + val uuid = UUID.randomUUID.toString + val ext = fileContext.ext + val filename = f"part-$split%05d-$uuid$ext" + + fileContext.relativeDir.map { d => + new Path(d, filename).toString + }.getOrElse(filename) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index b9266429f81a..0b80f1276030 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -39,7 +39,7 @@ private class OnlyDetectCustomPathFileCommitProtocol(jobId: String, path: String with Serializable with Logging { override def newTaskTempFileAbsPath( - taskContext: TaskAttemptContext, absoluteDir: String, ext: String): String = { + taskContext: TaskAttemptContext, relativePath: String, finalPath: String): String = { throw new Exception("there should be no custom partition path") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index ec189344f4fa..c063d73b920e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.hive.ql.exec.TaskRunner -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.internal.io.{BatchFileNamingProtocol, FileCommitProtocol} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute @@ -79,16 +79,19 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } + val jobId = java.util.UUID.randomUUID().toString + val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, - jobId = java.util.UUID.randomUUID().toString, + jobId = jobId, outputPath = outputLocation) + val namingProtocol = new BatchFileNamingProtocol(jobId) FileFormatWriter.write( sparkSession = sparkSession, plan = plan, fileFormat = new HiveFileFormat(fileSinkConf), - committer = committer, + protocols = (committer, namingProtocol), outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations, outputColumns), hadoopConf = hadoopConf,