diff --git a/core/src/main/scala/org/apache/spark/SparkException.scala b/core/src/main/scala/org/apache/spark/SparkException.scala
index 4ad9a0cc4b10..2ca354c7fb79 100644
--- a/core/src/main/scala/org/apache/spark/SparkException.scala
+++ b/core/src/main/scala/org/apache/spark/SparkException.scala
@@ -43,3 +43,9 @@ private[spark] case class SparkUserAppException(exitCode: Int)
  */
 private[spark] case class ExecutorDeadException(message: String)
   extends SparkException(message)
+
+/**
+ * Exception thrown when concurrent InsertIntoHadoopFsRelation operations conflict.
+ */
+private[spark] case class InsertFileSourceConflictException(message: String)
+  extends SparkException(message)
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index 0746e43babf9..55980fcd06f6 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -23,7 +23,6 @@ import org.apache.hadoop.mapreduce._
 import org.apache.spark.internal.Logging
 import org.apache.spark.util.Utils
 
-
 /**
  * An interface to define how a single Spark job commits its outputs. Three notes:
  *
@@ -31,7 +30,8 @@ import org.apache.spark.util.Utils
  *    will be used for tasks on executors.
  * 2. Implementations should have a constructor with 2 or 3 arguments:
  *    (jobId: String, path: String) or
- *    (jobId: String, path: String, dynamicPartitionOverwrite: Boolean)
+ *    (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) or
+ *    (jobId: String, path: String, fileSourceWriteDesc: Option[FileSourceWriteDesc])
  * 3. A committer should not be reused across multiple Spark jobs.
  *
  * The proper call sequence is:
@@ -169,4 +169,34 @@ object FileCommitProtocol extends Logging {
       ctor.newInstance(jobId, outputPath)
     }
   }
+
+  /**
+   * Instantiates a FileCommitProtocol with a file source write description.
+   */
+  def instantiate(
+      className: String,
+      jobId: String,
+      outputPath: String,
+      fileSourceWriteDesc: Option[FileSourceWriteDesc]): FileCommitProtocol = {
+
+    logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" +
+      s" fileSourceWriteDesc=$fileSourceWriteDesc")
+    val clazz = Utils.classForName[FileCommitProtocol](className)
+    // First try the constructor with arguments (jobId: String, outputPath: String,
+    //   fileSourceWriteDesc: Option[FileSourceWriteDesc]).
+    // If that doesn't exist, fall back to `FileCommitProtocol.instantiate(className,
+    //   jobId, outputPath, dynamicPartitionOverwrite)`.
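To make the constructor contract concrete: a committer opts into the richer write description simply by declaring the matching constructor, and the reflective lookup below resolves it before falling back to the Boolean form. A minimal sketch, assuming a hypothetical committer class and made-up ids:

```scala
import org.apache.spark.internal.io.{FileCommitProtocol, FileSourceWriteDesc, HadoopMapReduceCommitProtocol}

// Hypothetical committer: declaring this constructor is all that is needed for the
// (String, String, Option[FileSourceWriteDesc]) lookup to succeed.
class AuditingCommitProtocol(
    jobId: String,
    path: String,
    fileSourceWriteDesc: Option[FileSourceWriteDesc])
  extends HadoopMapReduceCommitProtocol(jobId, path, fileSourceWriteDesc)

val committer = FileCommitProtocol.instantiate(
  classOf[AuditingCommitProtocol].getName,
  jobId = "job-0",                 // made-up id and path
  outputPath = "/tmp/out",
  fileSourceWriteDesc = Some(new FileSourceWriteDesc(isInsertIntoHadoopFsRelation = true)))
// Committers that only declare (String, String, Boolean) keep working through the fallback.
```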
+ try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], + classOf[Option[FileSourceWriteDesc]]) + logDebug("Using (String, String, Option[FileSourceWriteDesc]) constructor") + ctor.newInstance(jobId, outputPath, fileSourceWriteDesc) + } catch { + case _: NoSuchMethodException => + logDebug("Falling back to invoke instance(className, JobId, outputPath," + + " dynamicPartitionOverwrite)") + instantiate(className, jobId, outputPath, + fileSourceWriteDesc.map(_.dynamicPartitionOverwrite).getOrElse(false)) + } + } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileSourceWriteDesc.scala b/core/src/main/scala/org/apache/spark/internal/io/FileSourceWriteDesc.scala new file mode 100644 index 000000000000..15587126481e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/internal/io/FileSourceWriteDesc.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.internal.io + +/** + * A class to describe the properties for file source write operation. 
+ * + * @param isInsertIntoHadoopFsRelation whether is a InsertIntoHadoopFsRelation operation + * @param dynamicPartitionOverwrite dynamic overwrite is enabled, the save mode is overwrite and + * not all partition keys are specified + * @param escapedStaticPartitionKVs static partition key and value pairs, which have been escaped + */ +class FileSourceWriteDesc( + val isInsertIntoHadoopFsRelation: Boolean = false, + val dynamicPartitionOverwrite: Boolean = false, + val escapedStaticPartitionKVs: Seq[(String, String)] = Seq.empty[(String, String)]) + extends Serializable diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 11ce608f52ee..a943b86ac73d 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -17,18 +17,19 @@ package org.apache.spark.internal.io -import java.io.IOException +import java.io.{File, FileNotFoundException, IOException} import java.util.{Date, UUID} import scala.collection.mutable -import scala.util.Try +import scala.util.{Failure, Success, Try} import org.apache.hadoop.conf.Configurable -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat} import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl +import org.apache.spark.SparkEnv import org.apache.spark.internal.Logging import org.apache.spark.mapred.SparkHadoopMapRedUtil @@ -40,22 +41,31 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop - * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime - * dynamically, i.e., we first write files under a staging - * directory with partition path, e.g. - * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, - * we first clean up the corresponding partition directories at - * destination path, e.g. /path/to/destination/a=1/b=1, and move - * files from staging directory to the corresponding partition - * directories under destination path. + * @param fileSourceWriteDesc a description for file source write operation */ class HadoopMapReduceCommitProtocol( jobId: String, path: String, - dynamicPartitionOverwrite: Boolean = false) + fileSourceWriteDesc: Option[FileSourceWriteDesc]) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ + import HadoopMapReduceCommitProtocol._ + + def this(jobId: String, path: String, dynamicPartitionOverwrite: Boolean = false) = + this(jobId, path, Some(new FileSourceWriteDesc(dynamicPartitionOverwrite = + dynamicPartitionOverwrite))) + + /** + * If true, Spark will overwrite partition directories at runtime dynamically, i.e., we first + * write files under a staging directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. + * When committing the job, we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move files from staging directory to + * the corresponding partition directories under destination path. 
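For a statement like `INSERT OVERWRITE TABLE t PARTITION (p1=1, p2) ...` with dynamic overwrite enabled, the descriptor that reaches the committer might look like the following sketch (the key/value literals are made up and assumed to be already path-escaped):

```scala
import org.apache.spark.internal.io.FileSourceWriteDesc

val desc = new FileSourceWriteDesc(
  isInsertIntoHadoopFsRelation = true,
  dynamicPartitionOverwrite = true,              // overwrite mode, p2 left dynamic
  escapedStaticPartitionKVs = Seq("p1" -> "1"))  // only the static prefix is recorded

// The committer derives its behaviour from the descriptor, e.g.
// dynamicPartitionOverwrite == desc.dynamicPartitionOverwrite, and defaults to false
// when no descriptor is supplied.
```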
+ */ + def dynamicPartitionOverwrite: Boolean = + fileSourceWriteDesc.map(_.dynamicPartitionOverwrite).getOrElse(false) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ @@ -91,7 +101,63 @@ class HadoopMapReduceCommitProtocol( */ private def stagingDir = new Path(path, ".spark-staging-" + jobId) + /** + * For InsertIntoHadoopFsRelation operation, we support concurrent write to different partitions + * in a same table. + */ + def supportConcurrent: Boolean = + fileSourceWriteDesc.map(_.isInsertIntoHadoopFsRelation).getOrElse(false) + + /** + * Get escaped static partition key and value pairs, the default is empty. + */ + private def escapedStaticPartitionKVs = + fileSourceWriteDesc.map(_.escapedStaticPartitionKVs).getOrElse(Seq.empty) + + /** + * The staging root directory for InsertIntoHadoopFsRelation operation. + */ + @transient private var insertStagingDir: Path = null + + /** + * The staging output path for InsertIntoHadoopFsRelation operation. + */ + @transient private var stagingOutputPath: Path = null + + /** + * Get the desired output path for the job. The output will be [[path]] when current operation + * is not a InsertIntoHadoopFsRelation operation. Otherwise, we choose a sub path composed of + * [[escapedStaticPartitionKVs]] under [[insertStagingDir]] over [[path]] to mark this operation + * and we can detect whether there is a operation conflict with current by checking the existence + * of relative output path. + * + * @return Path the desired output path. + */ + protected def getOutputPath(context: TaskAttemptContext): Path = { + if (supportConcurrent) { + val insertStagingPath = ".spark-staging-" + escapedStaticPartitionKVs.size + insertStagingDir = new Path(path, insertStagingPath) + val appId = SparkEnv.get.conf.getAppId + val outputPath = new Path(path, Array(insertStagingPath, + getEscapedStaticPartitionPath(escapedStaticPartitionKVs), appId, jobId) + .mkString(File.separator)) + insertStagingDir.getFileSystem(context.getConfiguration).makeQualified(outputPath) + outputPath + } else { + new Path(path) + } + } + protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { + if (supportConcurrent) { + stagingOutputPath = getOutputPath(context) + context.getConfiguration.set(FileOutputFormat.OUTDIR, stagingOutputPath.toString) + logDebug("Set file output committer algorithm version to 2 implicitly," + + " for that the task output would be committed to staging output path firstly," + + " which is equivalent to algorithm 1.") + context.getConfiguration.setInt(FileOutputCommitter.FILEOUTPUTCOMMITTER_ALGORITHM_VERSION, 2) + } + val format = context.getOutputFormatClass.getConstructor().newInstance() // If OutputFormat is Configurable, we should set conf to it. format match { @@ -200,6 +266,16 @@ class HadoopMapReduceCommitProtocol( } fs.rename(new Path(stagingDir, part), finalPartPath) } + } else if (supportConcurrent) { + // For InsertIntoHadoopFsRelation operation, the result has been committed to staging + // output path, merge it to destination path. + mergeStagingPath(fs, stagingOutputPath, new Path(path)) + } + + if (supportConcurrent) { + // For InsertIntoHadoopFsRelation operation, try to delete its staging output path. 
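A minimal sketch of the layout that `getOutputPath` produces, assuming a hypothetical table location, one escaped static partition pair, and made-up appId/jobId values:

```scala
import java.io.File

val tablePath = "/warehouse/db.db/ta"        // hypothetical table location
val kvs = Seq("p1" -> "1")                   // escaped static partition key/values
val appId = "application_1234"               // hypothetical application id
val jobId = "job-uuid"                       // hypothetical job UUID

val stagingOutputPath = (Seq(tablePath, ".spark-staging-" + kvs.size) ++
  kvs.map { case (k, v) => s"$k=$v" } ++ Seq(appId, jobId)).mkString(File.separator)
// => /warehouse/db.db/ta/.spark-staging-1/p1=1/application_1234/job-uuid
```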
+ deleteStagingInsertOutputPath(fs, insertStagingDir, stagingOutputPath, + escapedStaticPartitionKVs) } fs.delete(stagingDir, true) @@ -224,6 +300,8 @@ class HadoopMapReduceCommitProtocol( if (hasValidPath) { val fs = stagingDir.getFileSystem(jobContext.getConfiguration) fs.delete(stagingDir, true) + deleteStagingInsertOutputPath(fs, insertStagingDir, stagingOutputPath, + escapedStaticPartitionKVs) } } catch { case e: IOException => @@ -272,3 +350,121 @@ class HadoopMapReduceCommitProtocol( } } } + +object HadoopMapReduceCommitProtocol extends Logging { + + /** + * Get a path according to specified partition key-value pairs. + */ + def getEscapedStaticPartitionPath(staticPartitionKVs: Iterable[(String, String)]): String = { + staticPartitionKVs.map{kv => + kv._1 + "=" + kv._2 + }.mkString(File.separator) + } + + /** + * Delete the staging output path of current InsertIntoHadoopFsRelation operation. This output + * path is used to mark a InsertIntoHadoopFsRelation operation and we can detect conflict when + * there are several operations write same partition or a non-partitioned table concurrently. + * + * The output path is a multi level path and is composed of specified partition key value pairs + * formatted `.spark-staging-${depth}/p1=v1/p2=v2/.../pn=vn/appId/jobId`. When deleting the + * staging output path, delete the last level with recursive firstly. Then try to delete upper + * level without recursive, if success, then delete upper level with same way, until delete the + * insertStagingDir. + */ + def deleteStagingInsertOutputPath( + fs: FileSystem, + insertStagingDir: Path, + stagingOutputDir: Path, + escapedStaticPartitionKVs: Seq[(String, String)]): Unit = { + if (insertStagingDir == null || stagingOutputDir ==null || !fs.isDirectory(stagingOutputDir)) { + return + } + + // Firstly, delete the staging output dir with recursive, because it is unique. + deleteSilently(fs, stagingOutputDir, true) + + var currentLevelPath = stagingOutputDir.getParent + while (currentLevelPath != insertStagingDir) { + deleteSilently(fs, currentLevelPath, false) + currentLevelPath = currentLevelPath.getParent + } + + deleteSilently(fs, insertStagingDir, false) + } + + private def deleteSilently(fs: FileSystem, path: Path, recursive: Boolean): Unit = { + try { + if (!fs.delete(path, recursive)) { + logWarning(s"Failed to delete path:$path with recursive:$recursive") + } + } catch { + case e: Exception => + logWarning(s"Exception occurred when deleting dir: $path.", e) + } + } + + /** + * Merge files under staging output path to destination path. Before merging, we need delete the + * succeeded file under staging output path and regenerate it after merging completed. + */ + private def mergeStagingPath( + fs: FileSystem, + stagingOutputPath: Path, + destPath: Path): Unit = { + val stagingMarkerPath = new Path(stagingOutputPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fs.delete(stagingMarkerPath, true) + + doMergePaths(fs, fs.getFileStatus(stagingOutputPath), destPath) + + val markerPath = new Path(destPath, FileOutputCommitter.SUCCEEDED_FILE_NAME) + fs.create(markerPath, true).close() + } + + /** + * This is a reflected implementation of [[FileOutputCommitter]]'s mergePaths. + * Just remove some unnecessary operations to improve performance. 
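A compressed sketch of the commit-time merge performed by `mergeStagingPath` and `doMergePaths` (happy path only; it assumes an already-obtained FileSystem handle and keeps existing destination directories so other partitions are untouched):

```scala
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

def merge(fs: FileSystem, stagingOut: Path, dest: Path): Unit = {
  // Drop the staging _SUCCESS marker; the destination gets a fresh one after the merge.
  fs.delete(new Path(stagingOut, FileOutputCommitter.SUCCEEDED_FILE_NAME), true)

  def move(src: FileStatus, to: Path): Unit = {
    if (src.isFile) {
      fs.rename(src.getPath, to)            // files are renamed straight into place
    } else {
      if (!fs.exists(to)) fs.mkdirs(to)     // keep existing dest dirs (other partitions)
      fs.listStatus(src.getPath).foreach { child =>
        move(child, new Path(to, child.getPath.getName))
      }
    }
  }
  move(fs.getFileStatus(stagingOut), dest)

  fs.create(new Path(dest, FileOutputCommitter.SUCCEEDED_FILE_NAME), true).close()
}
```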
+ */ + @throws[IOException] + private def doMergePaths(fs: FileSystem, from: FileStatus, to: Path): Unit = { + logDebug(s"Merging data from $from to $to") + + val toStat: FileStatus = Try { + fs.getFileStatus(to) + } match { + case Success(stat) => stat + case Failure(_: FileNotFoundException) => null + case Failure(e) => throw e + } + + if (from.isFile) { + if (toStat != null && !fs.delete(to, true)) { + throw new IOException(s"Failed to delete $to" ) + } + rename(fs, from, to) + } else if (from.isDirectory) { + if (toStat != null) { + if (!toStat.isDirectory) { + if (!fs.delete(to, true)) { + throw new IOException(s"Failed to delete $to") + } + rename(fs, from, to) + } else { + fs.listStatus(from.getPath).foreach { fileToMove => + doMergePaths(fs, fileToMove, new Path(to, fileToMove.getPath.getName)) + } + } + } else { + rename(fs, from, to) + } + } + } + + @throws[IOException] + private def rename(fs: FileSystem, from: FileStatus, to: Path): Unit = { + if (!fs.rename(from.getPath, to)) { + throw new IOException(s"Failed to rename $from to $to") + } + } +} diff --git a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala index 2bd32fc927e2..2e43b39d9847 100644 --- a/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala +++ b/core/src/test/scala/org/apache/spark/internal/io/FileCommitProtocolInstantiationSuite.scala @@ -53,6 +53,15 @@ class FileCommitProtocolInstantiationSuite extends SparkFunSuite { "Wrong constructor argument count") } + + test("With file source write desc arg constructor has priority when file source" + + " write description specified") { + val instance = instantiateWithFileSourceWriteDesc( + Some(new FileSourceWriteDesc(true, false, Seq.empty))) + assert(3 == instance.argCount, "Wrong constructor argument count") + assert("with file source write desc" == instance.msg, "Wrong constructor invoked") + } + test("The protocol must be of the correct class") { intercept[ClassCastException] { FileCommitProtocol.instantiate( @@ -75,7 +84,7 @@ class FileCommitProtocolInstantiationSuite extends SparkFunSuite { /** * Create a classic two-arg protocol instance. - * @param dynamic dyanmic partitioning mode + * @param dynamic dynamic partitioning mode * @return the instance */ private def instantiateClassic(dynamic: Boolean): ClassicConstructorCommitProtocol = { @@ -88,7 +97,7 @@ class FileCommitProtocolInstantiationSuite extends SparkFunSuite { /** * Create a three-arg protocol instance. - * @param dynamic dyanmic partitioning mode + * @param dynamic dynamic partitioning mode * @return the instance */ private def instantiateNew( @@ -100,6 +109,19 @@ class FileCommitProtocolInstantiationSuite extends SparkFunSuite { dynamic).asInstanceOf[FullConstructorCommitProtocol] } + /** + * Create a four-arg protocol instance. 
+ * @param desc file source write description + * @return the instance + */ + private def instantiateWithFileSourceWriteDesc( + desc: Option[FileSourceWriteDesc]): FullConstructorCommitProtocol = { + FileCommitProtocol.instantiate( + classOf[FullConstructorCommitProtocol].getCanonicalName, + "job", + "path", + None).asInstanceOf[FullConstructorCommitProtocol] + } } /** @@ -119,16 +141,21 @@ private class ClassicConstructorCommitProtocol(arg1: String, arg2: String) private class FullConstructorCommitProtocol( arg1: String, arg2: String, - b: Boolean, - val argCount: Int) - extends HadoopMapReduceCommitProtocol(arg1, arg2, b) { + desc: Option[FileSourceWriteDesc], + val argCount: Int, + val msg: String = "") + extends HadoopMapReduceCommitProtocol(arg1, arg2, desc) { def this(arg1: String, arg2: String) = { - this(arg1, arg2, false, 2) + this(arg1, arg2, None, 2) } def this(arg1: String, arg2: String, b: Boolean) = { - this(arg1, arg2, false, 3) + this(arg1, arg2, Some(new FileSourceWriteDesc(dynamicPartitionOverwrite = b)), 3) + } + + def this(arg1: String, arg2: String, desc: Option[FileSourceWriteDesc]) { + this(arg1, arg2, desc, 3, "with file source write desc") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index fbe874b3e8bc..3e792b1a15e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -17,13 +17,18 @@ package org.apache.spark.sql.execution.datasources -import java.io.IOException +import java.io.{File, IOException} +import java.util.Date + +import scala.collection.mutable +import scala.collection.mutable.ListBuffer import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.internal.io.FileCommitProtocol +import org.apache.spark.{InsertFileSourceConflictException, SparkEnv} +import org.apache.spark.internal.io.{FileCommitProtocol, FileSourceWriteDesc, HadoopMapReduceCommitProtocol} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTablePartition, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -60,6 +65,10 @@ case class InsertIntoHadoopFsRelationCommand( extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName + // Staging dirs may be created for InsertHadoopFsRelation operation. 
+ var insertStagingDir: Path = null + var stagingOutputDir: Path = null + override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that SchemaUtils.checkColumnNameDuplication( @@ -103,103 +112,133 @@ case class InsertIntoHadoopFsRelationCommand( val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && staticPartitions.size < partitionColumns.length + val appId = SparkEnv.get.conf.getAppId + val jobId = java.util.UUID.randomUUID().toString + + val escapedStaticPartitionKVs = partitionColumns + .filter(c => staticPartitions.contains(c.name)) + .map { attr => + val escapedKey = ExternalCatalogUtils.escapePathName(attr.name) + val escapedValue = ExternalCatalogUtils.escapePathName(staticPartitions.get(attr.name).get) + (escapedKey, escapedValue) + } + + val fileSourceWriteDesc = Some(new FileSourceWriteDesc( + isInsertIntoHadoopFsRelation = true, + dynamicPartitionOverwrite = dynamicPartitionOverwrite, + escapedStaticPartitionKVs = escapedStaticPartitionKVs)) + val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, - jobId = java.util.UUID.randomUUID().toString, + jobId = jobId, outputPath = outputPath.toString, - dynamicPartitionOverwrite = dynamicPartitionOverwrite) + fileSourceWriteDesc = fileSourceWriteDesc) - val doInsertion = if (mode == SaveMode.Append) { - true - } else { - val pathExists = fs.exists(qualifiedOutputPath) - (mode, pathExists) match { - case (SaveMode.ErrorIfExists, true) => - throw new AnalysisException(s"path $qualifiedOutputPath already exists.") - case (SaveMode.Overwrite, true) => - if (ifPartitionNotExists && matchingPartitions.nonEmpty) { - false - } else if (dynamicPartitionOverwrite) { - // For dynamic partition overwrite, do not delete partition directories ahead. - true - } else { - deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) + try { + var doDeleteMatchingPartitions: Boolean = false + val doInsertion = if (mode == SaveMode.Append) { + true + } else { + val pathExists = fs.exists(qualifiedOutputPath) + (mode, pathExists) match { + case (SaveMode.ErrorIfExists, true) => + throw new AnalysisException(s"path $qualifiedOutputPath already exists.") + case (SaveMode.Overwrite, true) => + if (ifPartitionNotExists && matchingPartitions.nonEmpty) { + false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true + } else { + doDeleteMatchingPartitions = true + true + } + case (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => true - } - case (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) => - true - case (SaveMode.Ignore, exists) => - !exists - case (s, exists) => - throw new IllegalStateException(s"unsupported save mode $s ($exists)") + case (SaveMode.Ignore, exists) => + !exists + case (s, exists) => + throw new IllegalStateException(s"unsupported save mode $s ($exists)") + } } - } - if (doInsertion) { + if (doInsertion) { + // For insertion operation, detect whether there is a conflict. 
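As a small worked example of the escaping above: partition names and values are passed through `ExternalCatalogUtils.escapePathName`, so characters that are illegal or ambiguous in paths (for example `/`) cannot break the staging directory layout. A sketch with made-up values:

```scala
import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils

val staticPartitions = Map("p1" -> "2020/01/01")   // hypothetical static spec
val partitionColumnNames = Seq("p1", "p2")         // p2 stays dynamic

val escapedStaticPartitionKVs = partitionColumnNames
  .filter(staticPartitions.contains)
  .map { name =>
    ExternalCatalogUtils.escapePathName(name) ->
      ExternalCatalogUtils.escapePathName(staticPartitions(name))
  }
// => Seq("p1" -> "2020%2F01%2F01"): the slashes are escaped so the value cannot
//    introduce extra levels under .spark-staging-1/
```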
+ detectConflict(committer, fs, outputPath, escapedStaticPartitionKVs, appId, jobId) - def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = { - val updatedPartitions = updatedPartitionPaths.map(PartitioningUtils.parsePathFragment) - if (partitionsTrackedByCatalog) { - val newPartitions = updatedPartitions -- initialMatchingPartitions - if (newPartitions.nonEmpty) { - AlterTableAddPartitionCommand( - catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), - ifNotExists = true).run(sparkSession) - } - // For dynamic partition overwrite, we never remove partitions but only update existing - // ones. - if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { - val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions - if (deletedPartitions.nonEmpty) { - AlterTableDropPartitionCommand( - catalogTable.get.identifier, deletedPartitions.toSeq, - ifExists = true, purge = false, - retainData = true /* already deleted */).run(sparkSession) + if (doDeleteMatchingPartitions) { + deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) + } + + def refreshUpdatedPartitions(updatedPartitionPaths: Set[String]): Unit = { + val updatedPartitions = updatedPartitionPaths.map(PartitioningUtils.parsePathFragment) + if (partitionsTrackedByCatalog) { + val newPartitions = updatedPartitions -- initialMatchingPartitions + if (newPartitions.nonEmpty) { + AlterTableAddPartitionCommand( + catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), + ifNotExists = true).run(sparkSession) + } + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { + val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions + if (deletedPartitions.nonEmpty) { + AlterTableDropPartitionCommand( + catalogTable.get.identifier, deletedPartitions.toSeq, + ifExists = true, purge = false, + retainData = true /* already deleted */).run(sparkSession) + } } } } - } - val updatedPartitionPaths = - FileFormatWriter.write( - sparkSession = sparkSession, - plan = child, - fileFormat = fileFormat, - committer = committer, - outputSpec = FileFormatWriter.OutputSpec( - qualifiedOutputPath.toString, customPartitionLocations, outputColumns), - hadoopConf = hadoopConf, - partitionColumns = partitionColumns, - bucketSpec = bucketSpec, - statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), - options = options) - - - // update metastore partition metadata - if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty - && partitionColumns.length == staticPartitions.size) { - // Avoid empty static partition can't loaded to datasource table. 
- val staticPathFragment = - PartitioningUtils.getPathFragment(staticPartitions, partitionColumns) - refreshUpdatedPartitions(Set(staticPathFragment)) - } else { - refreshUpdatedPartitions(updatedPartitionPaths) - } + val updatedPartitionPaths = + FileFormatWriter.write( + sparkSession = sparkSession, + plan = child, + fileFormat = fileFormat, + committer = committer, + outputSpec = FileFormatWriter.OutputSpec( + qualifiedOutputPath.toString, customPartitionLocations, outputColumns), + hadoopConf = hadoopConf, + partitionColumns = partitionColumns, + bucketSpec = bucketSpec, + statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)), + options = options) + - // refresh cached files in FileIndex - fileIndex.foreach(_.refresh()) - // refresh data cache if table is cached - sparkSession.catalog.refreshByPath(outputPath.toString) + // update metastore partition metadata + if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty + && partitionColumns.length == staticPartitions.size) { + // Avoid empty static partition can't loaded to datasource table. + val staticPathFragment = + PartitioningUtils.getPathFragment(staticPartitions, partitionColumns) + refreshUpdatedPartitions(Set(staticPathFragment)) + } else { + refreshUpdatedPartitions(updatedPartitionPaths) + } + + // refresh cached files in FileIndex + fileIndex.foreach(_.refresh()) + // refresh data cache if table is cached + sparkSession.catalog.refreshByPath(outputPath.toString) + + if (catalogTable.nonEmpty) { + CommandUtils.updateTableStats(sparkSession, catalogTable.get) + } - if (catalogTable.nonEmpty) { - CommandUtils.updateTableStats(sparkSession, catalogTable.get) + } else { + logInfo("Skipping insertion into a relation that already exists.") } - } else { - logInfo("Skipping insertion into a relation that already exists.") + Seq.empty[Row] + } catch { + case e: Exception => + HadoopMapReduceCommitProtocol.deleteStagingInsertOutputPath(fs, insertStagingDir, + stagingOutputDir, escapedStaticPartitionKVs) + throw e } - - Seq.empty[Row] } /** @@ -266,4 +305,143 @@ case class InsertIntoHadoopFsRelationCommand( } }.toMap } + + /** + * Check current committer whether supports several InsertIntoHadoopFsRelation operations write + * to different partitions in a same table concurrently. If supports, then detect the conflict + * whether there are several operations write to same partition in the same table or write to + * a non-partitioned table. 
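The detection boils down to probing a small set of marker directories. For a hypothetical table with partition columns `(p1, p2)` and the static spec `p1=1`, the probed paths can be sketched as:

```scala
import java.io.File

val tablePath = "/warehouse/db.db/tc"     // hypothetical table location
val staticKVs = Seq("p1" -> "1")
val numPartitionCols = 2

// Same-depth probe: another writer targeting the same static prefix.
val sameDepthProbe = Seq(tablePath, s".spark-staging-${staticKVs.size}", "p1=1")
  .mkString(File.separator)
// => /warehouse/db.db/tc/.spark-staging-1/p1=1

// Other-depth probes: writers that pinned fewer or more partition keys but overlap this prefix.
val otherDepthProbes = (0 to numPartitionCols).filter(_ != staticKVs.size).map { i =>
  val prefix = staticKVs.take(i).map { case (k, v) => s"$k=$v" }.mkString(File.separator)
  Seq(tablePath, s".spark-staging-$i", prefix).filter(_.nonEmpty).mkString(File.separator)
}
// => /warehouse/db.db/tc/.spark-staging-0
//    /warehouse/db.db/tc/.spark-staging-2/p1=1
```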
+ */ + private def detectConflict( + commitProtocol: FileCommitProtocol, + fs: FileSystem, + path: Path, + staticPartitionKVs: Seq[(String, String)], + appId: String, + jobId: String): Unit = { + import HadoopMapReduceCommitProtocol._ + + val supportConcurrent = commitProtocol.isInstanceOf[HadoopMapReduceCommitProtocol] && + commitProtocol.asInstanceOf[HadoopMapReduceCommitProtocol].supportConcurrent + if (supportConcurrent && fs.isDirectory(outputPath)) { + var conflictedPathAndDepths = mutable.Map[Path, Int]() + try { + val insertStagingPath = ".spark-staging-" + staticPartitionKVs.size + val checkedPath = new Path(outputPath, Array(insertStagingPath, + getEscapedStaticPartitionPath(staticPartitionKVs)).mkString(File.separator)) + insertStagingDir = new Path(outputPath, insertStagingPath) + + if (fs.exists(checkedPath)) { + conflictedPathAndDepths += insertStagingDir -> staticPartitionKVs.size + } else { + stagingOutputDir = new Path(outputPath, Array(insertStagingPath, + getEscapedStaticPartitionPath(staticPartitionKVs), appId, jobId) + .mkString(File.separator)) + fs.mkdirs(stagingOutputDir) + } + + for (i <- 0 to partitionColumns.size) { + if (i != staticPartitionKVs.size) { + val stagingDir = new Path(path, ".spark-staging-" + i) + if (fs.exists(stagingDir)) { + val subPath = getEscapedStaticPartitionPath( + staticPartitionKVs.slice(0, i)) + val checkedPath = if (!subPath.isEmpty) { + new Path(stagingDir, subPath) + } else { + stagingDir + } + if (fs.exists(checkedPath)) { + conflictedPathAndDepths += stagingDir -> i + } + } + } + } + } finally { + if (!conflictedPathAndDepths.isEmpty) { + throwConflictException(fs, conflictedPathAndDepths.toMap, staticPartitionKVs) + } + } + + } + } + + private def throwConflictException( + fs: FileSystem, + conflictedPathAndDepths: Map[Path, Int], + staticPartitionKVs: Seq[(String, String)]): Unit = { + val conflictedInfo = ListBuffer[(String, String, Date)]() + + for ((stagingDir, depth) <- conflictedPathAndDepths) { + val conflictedPaths = ListBuffer[Path]() + val currentPath = if (depth == staticPartitionKVs.size || staticPartitionKVs.size == 0) { + stagingDir + } else { + new Path(stagingDir, HadoopMapReduceCommitProtocol.getEscapedStaticPartitionPath( + staticPartitionKVs.slice(0, staticPartitionKVs.size - depth))) + } + + findConflictedStagingOutputPaths(fs, currentPath, depth, conflictedPaths) + + val pathsInfo = conflictedPaths + .map { path => + val absolutePath = path.toUri.getRawPath + val relativePath = absolutePath.substring(absolutePath.lastIndexOf(stagingDir.getName)) + var appId: Option[String] = None + var modificationTime: Date = null + try { + val files = fs.listStatus(path) + if (files.size > 0) { + appId = Some(files.apply(0).getPath.getName) + } + modificationTime = new Date(fs.getFileStatus(path).getModificationTime) + } catch { + case e: Exception => logWarning("Exception occurred", e) + } + (relativePath, appId.getOrElse("Not Found"), modificationTime) + } + conflictedInfo ++= pathsInfo + } + + throw new InsertFileSourceConflictException( + s""" + | Conflict is detected, some other conflicted output path(s) under tablePath: + | ($outputPath) existed. + | Relative path, appId and last modification time information is shown as below: + | ${conflictedInfo}. + | There may be two possibilities: + | 1. Another InsertDataSource operation is executing, you need wait for it to + | complete. + | 2. This dir is belong to a killed application and not be cleaned up gracefully. 
+ | + | Please check the last modification time and use given appId to judge whether + | relative application is running now. If not, you should delete responding path + | without recursive manually. + |""".stripMargin) + } + + /** + * Find relative staging output paths, which is conflicted with current + * InsertIntoHadoopFsRelation operation. + */ + private def findConflictedStagingOutputPaths( + fs: FileSystem, + path: Path, + depth: Int, + paths: ListBuffer[Path]): Unit = { + try { + if (fs.exists(path)) { + if (depth == 0) { + paths += path + } else { + for (file <- fs.listStatus(path)) { + findConflictedStagingOutputPaths(fs, file.getPath, depth - 1, paths) + } + } + } + } catch { + case e: Exception => + logWarning("Exception occurred when finding conflicted staging output paths.", e) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 39c594a9bc61..e1e87cd900b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.spark.internal.Logging -import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol +import org.apache.spark.internal.io.{FileSourceWriteDesc, HadoopMapReduceCommitProtocol} import org.apache.spark.sql.internal.SQLConf /** @@ -32,10 +32,14 @@ import org.apache.spark.sql.internal.SQLConf class SQLHadoopMapReduceCommitProtocol( jobId: String, path: String, - dynamicPartitionOverwrite: Boolean = false) - extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + fileSourceWriteDesc: Option[FileSourceWriteDesc]) + extends HadoopMapReduceCommitProtocol(jobId, path, fileSourceWriteDesc) with Serializable with Logging { + def this(jobId: String, path: String, dynamicPartitionOverwrite: Boolean = false) = + this(jobId, path, Some(new FileSourceWriteDesc(dynamicPartitionOverwrite = + dynamicPartitionOverwrite))) + override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { var committer = super.setupCommitter(context) @@ -55,7 +59,7 @@ class SQLHadoopMapReduceCommitProtocol( // The specified output committer is a FileOutputCommitter. // So, we will use the FileOutputCommitter-specified constructor. val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext]) - committer = ctor.newInstance(new Path(path), context) + committer = ctor.newInstance(getOutputPath(context), context) } else { // The specified output committer is just an OutputCommitter. // So, we will use the no-argument constructor. 
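One consequence of passing `getOutputPath(context)` to the user-specified committer above: a `FileOutputCommitter` subclass plugged in through SQLConf is now constructed against the per-job staging directory for InsertIntoHadoopFsRelation jobs, so concurrent inserts do not share a `_temporary` directory. A minimal sketch (the class name is hypothetical; the config key is the existing `spark.sql.sources.outputCommitterClass` option):

```scala
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter

// Hypothetical user committer: with this change the `outputPath` it receives is the
// per-job staging dir (<table>/.spark-staging-<n>/<static kvs>/<appId>/<jobId>) for
// InsertIntoHadoopFsRelation jobs, not the final table path.
class StagingAwareCommitter(outputPath: Path, context: TaskAttemptContext)
  extends FileOutputCommitter(outputPath, context)

// Plugged in through the SQL conf:
// spark.conf.set("spark.sql.sources.outputCommitterClass",
//   classOf[StagingAwareCommitter].getName)
```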
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5b0de1baa553..c4de3eed8b1a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.parallel.immutable.ParVector -import org.apache.spark.{AccumulatorSuite, SparkException} +import org.apache.spark.{AccumulatorSuite, InsertFileSourceConflictException, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.util.StringUtils @@ -38,12 +38,14 @@ import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode} import org.apache.spark.sql.test.{SharedSparkSession, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval class SQLQuerySuite extends QueryTest with SharedSparkSession { + import testImplicits._ setupTestData() @@ -62,7 +64,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { def getFunctions(pattern: String): Seq[Row] = { StringUtils.filterPattern( spark.sessionState.catalog.listFunctions("default").map(_._1.funcName) - ++ FunctionsCommand.virtualOperators, pattern) + ++ FunctionsCommand.virtualOperators, pattern) .map(Row(_)) } @@ -153,14 +155,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { def unindentAndTrim(s: String): String = { s.replaceAll("\n\\s+", "\n").trim } + val beginSqlStmtRe = " > ".r val endSqlStmtRe = ";\n".r + def checkExampleSyntax(example: String): Unit = { val beginStmtNum = beginSqlStmtRe.findAllIn(example).length val endStmtNum = endSqlStmtRe.findAllIn(example).length assert(beginStmtNum === endStmtNum, "The number of ` > ` does not match to the number of `;`") } + val exampleRe = """^(.+);\n(?s)(.+)$""".r val ignoreSet = Set( // One of examples shows getting the current timestamp @@ -238,12 +243,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { } test("self join with alias in agg") { - Seq(1, 2, 3) - .map(i => (i, i.toString)) - .toDF("int", "str") - .groupBy("str") - .agg($"str", count("str").as("strCount")) - .createOrReplaceTempView("df") + Seq(1, 2, 3) + .map(i => (i, i.toString)) + .toDF("int", "str") + .groupBy("str") + .agg($"str", count("str").as("strCount")) + .createOrReplaceTempView("df") checkAnswer( sql( @@ -291,17 +296,17 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { // Since the ID is only materialized once, then all of the records // should come from the cache, not by re-computing. 
Otherwise, the ID // will be different - assert(d0.map(_(0)) === d2.map(_(0))) - assert(d0.map(_(1)) === d2.map(_(1))) + assert(d0.map(_ (0)) === d2.map(_ (0))) + assert(d0.map(_ (1)) === d2.map(_ (1))) - assert(d1.map(_(0)) === d2.map(_(0))) - assert(d1.map(_(1)) === d2.map(_(1))) + assert(d1.map(_ (0)) === d2.map(_ (0))) + assert(d1.map(_ (1)) === d2.map(_ (1))) } test("grouping on nested fields") { spark.read .json(Seq("""{"nested": {"attribute": 1}, "value": 2}""").toDS()) - .createOrReplaceTempView("rows") + .createOrReplaceTempView("rows") checkAnswer( sql( @@ -428,7 +433,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { |FROM testData3x |GROUP BY value """.stripMargin, - (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1))) + (1 to 100).map(i => Row(i.toString, i * 3, i, i, i, 3, 1))) testCodeGen( "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x", Row(100, 1, 50.5, 300, 100) :: Nil) @@ -616,10 +621,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { testData.take(10).toSeq) checkAnswer( - sql(""" - |with q1 as (select * from testData where key= '5'), - |q2 as (select * from testData where key = '4') - |select * from q1 union all select * from q2""".stripMargin), + sql( + """ + |with q1 as (select * from testData where key= '5'), + |q2 as (select * from testData where key = '4') + |select * from q1 union all select * from q2""".stripMargin), Row(5, "5") :: Row(4, "4") :: Nil) } @@ -747,14 +753,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql( """ - |SELECT * FROM - | (SELECT * FROM testdata2 WHERE a = 1) x JOIN - | (SELECT * FROM testdata2 WHERE a = 1) y - |WHERE x.a = y.a""".stripMargin), + |SELECT * FROM + | (SELECT * FROM testdata2 WHERE a = 1) x JOIN + | (SELECT * FROM testdata2 WHERE a = 1) y + |WHERE x.a = y.a""".stripMargin), Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: - Row(1, 2, 1, 1) :: - Row(1, 2, 1, 2) :: Nil) + Row(1, 1, 1, 2) :: + Row(1, 2, 1, 1) :: + Row(1, 2, 1, 2) :: Nil) } } @@ -834,11 +840,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { | ON leftTable.N = rightTable.N """.stripMargin), Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", 3, "C") :: - Row (4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("SPARK-11111 null-safe join should not use cartesian product") { @@ -865,11 +871,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset1 ON subset1.n = lowerCaseData.n"), Row(3, "c", 3) :: - Row(4, "d", 4) :: Nil) + Row(4, "d", 4) :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INNER JOIN subset2 ON subset2.n = lowerCaseData.n"), Row(1, "a", 1) :: - Row(2, "b", 2) :: Nil) + Row(2, "b", 2) :: Nil) } test("mixed-case keywords") { @@ -882,11 +888,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { | oN leftTable.N = rightTable.N """.stripMargin), Row(1, "A", null, null) :: - Row(2, "B", null, null) :: - Row(3, "C", 3, "C") :: - Row(4, "D", 4, "D") :: - Row(null, null, 5, "E") :: - Row(null, null, 6, "F") :: Nil) + Row(2, "B", null, null) :: + Row(3, "C", 3, "C") :: + Row(4, "D", 4, "D") :: + Row(null, null, 5, "E") :: + Row(null, null, 6, "F") :: Nil) } test("select with table name as qualifier") { @@ -956,14 +962,14 @@ class SQLQuerySuite 
extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"), Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: - Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(2, "b") :: Row(3, "c") :: Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"), Row(1, "a") :: Row(1, "a") :: Row(2, "b") :: Row(2, "b") :: Row(3, "c") :: Row(3, "c") :: - Row(4, "d") :: Row(4, "d") :: Nil) + Row(4, "d") :: Row(4, "d") :: Nil) } test("UNION with column mismatches") { @@ -971,7 +977,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"), Row(1, "A") :: Row(1, "a") :: Row(2, "B") :: Row(2, "b") :: Row(3, "C") :: Row(3, "c") :: - Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) + Row(4, "D") :: Row(4, "d") :: Row(5, "E") :: Row(6, "F") :: Nil) // Column type mismatches are not allowed, forcing a type coercion. checkAnswer( sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"), @@ -987,9 +993,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"), Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil) checkAnswer( @@ -1010,9 +1016,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM lowerCaseData"), Row(1, "a") :: - Row(2, "b") :: - Row(3, "c") :: - Row(4, "d") :: Nil) + Row(2, "b") :: + Row(3, "c") :: + Row(4, "d") :: Nil) checkAnswer( sql("SELECT * FROM lowerCaseData INTERSECT SELECT * FROM upperCaseData"), Nil) } @@ -1114,9 +1120,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { test("apply schema") { val schema1 = StructType( StructField("f1", IntegerType, false) :: - StructField("f2", StringType, false) :: - StructField("f3", BooleanType, false) :: - StructField("f4", IntegerType, true) :: Nil) + StructField("f2", StringType, false) :: + StructField("f3", BooleanType, false) :: + StructField("f4", IntegerType, true) :: Nil) val rowRDD1 = unparsedStrings.map { r => val values = r.split(",").map(_.trim) @@ -1131,22 +1137,22 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: - Row(2, "B2", false, null) :: - Row(3, "C3", true, null) :: - Row(4, "D4", true, 2147483644) :: Nil) + Row(2, "B2", false, null) :: + Row(3, "C3", true, null) :: + Row(4, "D4", true, 2147483644) :: Nil) checkAnswer( sql("SELECT f1, f4 FROM applySchema1"), Row(1, null) :: - Row(2, null) :: - Row(3, null) :: - Row(4, 2147483644) :: Nil) + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) val schema2 = StructType( StructField("f1", StructType( StructField("f11", IntegerType, false) :: - StructField("f12", BooleanType, false) :: Nil), false) :: - StructField("f2", MapType(StringType, IntegerType, true), false) :: Nil) + StructField("f12", BooleanType, false) :: Nil), false) :: + StructField("f2", 
MapType(StringType, IntegerType, true), false) :: Nil) val rowRDD2 = unparsedStrings.map { r => val values = r.split(",").map(_.trim) @@ -1161,16 +1167,16 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: - Row(Row(2, false), Map("B2" -> null)) :: - Row(Row(3, true), Map("C3" -> null)) :: - Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) + Row(Row(2, false), Map("B2" -> null)) :: + Row(Row(3, true), Map("C3" -> null)) :: + Row(Row(4, true), Map("D4" -> 2147483644)) :: Nil) checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema2"), Row(1, null) :: - Row(2, null) :: - Row(3, null) :: - Row(4, 2147483644) :: Nil) + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) // The value of a MapType column can be a mutable map. val rowRDD3 = unparsedStrings.map { r => @@ -1187,9 +1193,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), Row(1, null) :: - Row(2, null) :: - Row(3, null) :: - Row(4, 2147483644) :: Nil) + Row(2, null) :: + Row(3, null) :: + Row(4, 2147483644) :: Nil) } test("SPARK-3423 BETWEEN") { @@ -1211,7 +1217,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { test("SPARK-17863: SELECT distinct does not work correctly if order by missing attribute") { checkAnswer( - sql("""select distinct struct.a, struct.b + sql( + """select distinct struct.a, struct.b |from ( | select named_struct('a', 1, 'b', 2, 'c', 3) as struct | union all @@ -1221,13 +1228,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { Row(1, 2) :: Nil) val error = intercept[AnalysisException] { - sql("""select distinct struct.a, struct.b - |from ( - | select named_struct('a', 1, 'b', 2, 'c', 3) as struct - | union all - | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp - |order by struct.a, struct.b - |""".stripMargin) + sql( + """select distinct struct.a, struct.b + |from ( + | select named_struct('a', 1, 'b', 2, 'c', 3) as struct + | union all + | select named_struct('a', 1, 'b', 2, 'c', 4) as struct) tmp + |order by struct.a, struct.b + |""".stripMargin) } assert(error.message contains "cannot resolve '`struct.a`' given input columns: [a, b]") @@ -1251,9 +1259,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) val personWithMeta = spark.createDataFrame(person.rdd, schemaWithMeta) + def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } + personWithMeta.createOrReplaceTempView("personWithMeta") validateMetadata(personWithMeta.select($"name")) validateMetadata(personWithMeta.select($"name")) @@ -1339,12 +1349,12 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { test("SPARK-4154 Query does not work if it has 'not between' in Spark SQL and HQL") { checkAnswer(sql("SELECT key FROM testData WHERE key not between 0 and 10 order by key"), - (11 to 100).map(i => Row(i))) + (11 to 100).map(i => Row(i))) } test("SPARK-4207 Query which has syntax like 'not like' is not working in Spark SQL") { checkAnswer(sql("SELECT key FROM testData WHERE value not like '100%' order by key"), - (1 to 99).map(i => Row(i))) + (1 to 99).map(i => Row(i))) } test("SPARK-4322 Grouping field with struct field as sub expression") { @@ -1382,7 +1392,7 @@ class SQLQuerySuite extends QueryTest with 
SharedSparkSession { rdd2.toDF().createOrReplaceTempView("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), - (1 to 2).map(i => Row(i))) + (1 to 2).map(i => Row(i))) } test("Multi-column COUNT(DISTINCT ...)") { @@ -1480,30 +1490,30 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { checkAnswer( sql( """ - |SELECT a, sum(b) - |FROM orderByData - |GROUP BY a - |ORDER BY sum(b) + 1 + |SELECT a, sum(b) + |FROM orderByData + |GROUP BY a + |ORDER BY sum(b) + 1 """.stripMargin), Row("4", 3) :: Row("1", 7) :: Row("3", 11) :: Row("2", 15) :: Nil) checkAnswer( sql( """ - |SELECT count(*) - |FROM orderByData - |GROUP BY a - |ORDER BY count(*) + |SELECT count(*) + |FROM orderByData + |GROUP BY a + |ORDER BY count(*) """.stripMargin), Row(2) :: Row(2) :: Row(2) :: Row(2) :: Nil) checkAnswer( sql( """ - |SELECT a - |FROM orderByData - |GROUP BY a - |ORDER BY a, count(*), sum(b) + |SELECT a + |FROM orderByData + |GROUP BY a + |ORDER BY a, count(*), sum(b) """.stripMargin), Row("1") :: Row("2") :: Row("3") :: Row("4") :: Nil) } @@ -1557,7 +1567,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { import org.apache.spark.unsafe.types.CalendarInterval val df = sql("select interval 3 years -3 month 7 week 123 microseconds") - checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7 * 7, 123 ))) + checkAnswer(df, Row(new CalendarInterval(12 * 3 - 3, 7 * 7, 123))) withTempPath(f => { // Currently we don't yet support saving out values of interval data type. val e = intercept[AnalysisException] { @@ -1651,22 +1661,22 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { intercept[AnalysisException] { spark.sql( s""" - |CREATE TEMPORARY VIEW db.t - |USING parquet - |OPTIONS ( - | path '$path' - |) + |CREATE TEMPORARY VIEW db.t + |USING parquet + |OPTIONS ( + | path '$path' + |) """.stripMargin) }.getMessage // If you use backticks to quote the name then it's OK. 
spark.sql( s""" - |CREATE TEMPORARY VIEW `db.t` - |USING parquet - |OPTIONS ( - | path '$path' - |) + |CREATE TEMPORARY VIEW `db.t` + |USING parquet + |OPTIONS ( + | path '$path' + |) """.stripMargin) checkAnswer(spark.table("`db.t`"), df) } @@ -1924,8 +1934,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val specialCharacterPath = sql( """ - | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM - | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp + | SELECT struct(`col$.a_`, `a.b.c.`) as `r&&b.c` FROM + | (SELECT struct(a, b) as `col$.a_`, struct(b, a) as `a.b.c.` FROM testData2) tmp """.stripMargin) withTempView("specialCharacterTable") { specialCharacterPath.createOrReplaceTempView("specialCharacterTable") @@ -1933,9 +1943,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { specialCharacterPath.select($"`r&&b.c`.*"), nestedStructData.select($"record.*")) checkAnswer( - sql( - "SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), - nestedStructData.select($"record.r1")) + sql( + "SELECT `r&&b.c`.`col$.a_` FROM specialCharacterTable"), + nestedStructData.select($"record.r1")) checkAnswer( sql("SELECT `r&&b.c`.`a.b.c.` FROM specialCharacterTable"), nestedStructData.select($"record.r2")) @@ -2067,12 +2077,13 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { // NullPropagation rule: COUNT(v) got replaced with COUNT(1) because the output column of // UNION was incorrectly considered non-nullable: checkAnswer( - sql("""SELECT count(v) FROM ( - | SELECT v FROM ( - | SELECT 'foo' AS v UNION ALL - | SELECT NULL AS v - | ) my_union WHERE isnull(v) - |) my_subview""".stripMargin), + sql( + """SELECT count(v) FROM ( + | SELECT v FROM ( + | SELECT 'foo' AS v UNION ALL + | SELECT NULL AS v + | ) my_union WHERE isnull(v) + |) my_subview""".stripMargin), Seq(Row(0))) } @@ -2195,7 +2206,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { } test("SPARK-15327: fail to compile generated code with complex data structure") { - withTempDir{ dir => + withTempDir { dir => val json = """ |{"h": {"b": {"c": [{"e": "adfgd"}], "a": [{"e": "testing", "count": 3}], @@ -2650,7 +2661,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { withTable("_tbl") { sql("CREATE TABLE `_tbl`(i INT) USING parquet") sql("INSERT INTO `_tbl` VALUES (1), (2), (3)") - checkAnswer( sql("SELECT * FROM `_tbl`"), Row(1) :: Row(2) :: Row(3) :: Nil) + checkAnswer(sql("SELECT * FROM `_tbl`"), Row(1) :: Row(2) :: Row(3) :: Nil) } } @@ -2827,11 +2838,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { val df = sql(query) val physical = df.queryExecution.sparkPlan val aggregateExpressions = physical.collectFirst { - case agg : HashAggregateExec => agg.aggregateExpressions - case agg : SortAggregateExec => agg.aggregateExpressions + case agg: HashAggregateExec => agg.aggregateExpressions + case agg: SortAggregateExec => agg.aggregateExpressions } - assert (aggregateExpressions.isDefined) - assert (aggregateExpressions.get.size == 1) + assert(aggregateExpressions.isDefined) + assert(aggregateExpressions.get.size == 1) checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) } @@ -2840,11 +2851,11 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { val df = sql(query) val physical = df.queryExecution.sparkPlan val aggregateExpressions = physical.collectFirst { - case agg : HashAggregateExec => 
agg.aggregateExpressions - case agg : SortAggregateExec => agg.aggregateExpressions + case agg: HashAggregateExec => agg.aggregateExpressions + case agg: SortAggregateExec => agg.aggregateExpressions } - assert (aggregateExpressions.isDefined) - assert (aggregateExpressions.get.size == 2) + assert(aggregateExpressions.isDefined) + assert(aggregateExpressions.get.size == 2) } test("SPARK-22356: overlapped columns between data and partition schema in data source tables") { @@ -2920,9 +2931,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { withView("spark_25084") { val count = 1000 val df = spark.range(count) - val columns = (0 until 400).map{ i => s"id as id$i" } + val columns = (0 until 400).map { i => s"id as id$i" } val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") - df.selectExpr(columns : _*).createTempView("spark_25084") + df.selectExpr(columns: _*).createTempView("spark_25084") assert( spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) } @@ -3266,9 +3277,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { test("SPARK-29239: Subquery should not cause NPE when eliminating subexpression") { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false", - SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false", - SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", - SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ConvertToLocalRelation.ruleName) { + SQLConf.SUBQUERY_REUSE_ENABLED.key -> "false", + SQLConf.CODEGEN_FACTORY_MODE.key -> "CODEGEN_ONLY", + SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ConvertToLocalRelation.ruleName) { withTempView("t1", "t2") { sql("create temporary view t1 as select * from values ('val1a', 10L) as t1(t1a, t1b)") sql("create temporary view t2 as select * from values ('val3a', 110L) as t2(t2a, t2b)") @@ -3313,6 +3324,70 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession { cubeDF.join(cubeDF, "nums"), Row(1, 0, 0) :: Row(2, 0, 0) :: Row(3, 0, 0) :: Nil) } + + test("SPARK-28945 SPARK-29037: Fix the issue that spark gives duplicate result and support" + + " concurrent file source write operations write to different partitions in the same table.") { + withSQLConf(PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.STATIC.toString) { + withTable("ta", "tb", "tc") { + // partitioned table + sql("create table ta(id int, p1 int, p2 int) using parquet partitioned by (p1, p2)") + sql("insert overwrite table ta partition(p1=1,p2) select 1, 3") + val df1 = sql("select * from ta order by p2") + checkAnswer(df1, Array(Row(1, 1, 3))) + sql("insert overwrite table ta partition(p1=1,p2) select 1, 4") + val df2 = sql("select * from ta order by p2") + checkAnswer(df2, Array(Row(1, 1, 4))) + sql("insert overwrite table ta partition(p1=1,p2=5) select 1") + val df3 = sql("select * from ta order by p2") + checkAnswer(df3, Array(Row(1, 1, 4), Row(1, 1, 5))) + sql("insert overwrite table ta select 1, 2, 3") + val df4 = sql("select * from ta order by p2") + checkAnswer(df4, Array(Row(1, 2, 3))) + sql("insert overwrite table ta select 9, 9, 9") + val df5 = sql("select * from ta order by p2") + checkAnswer(df5, Array(Row(9, 9, 9))) + sql("insert into table ta select 6, 6, 6") + val df6 = sql("select * from ta order by p2") + checkAnswer(df6, Array(Row(6, 6, 6), Row(9, 9, 9))) + + // non-partitioned table + sql("create table tb(id int) using parquet") + sql("insert into table tb select 7") + val df7 = sql("select * from tb order by id") + checkAnswer(df7, Array(Row(7))) + sql("insert overwrite table tb 
+        val df8 = sql("select * from tb order by id")
+        checkAnswer(df8, Array(Row(8)))
+        sql("insert into table tb select 9")
+        val df9 = sql("select * from tb order by id")
+        checkAnswer(df9, Array(Row(8), Row(9)))
+
+        // detect concurrent conflict
+        sql("create table tc(id int, p1 int, p2 int) using parquet partitioned by (p1, p2)")
+        sql("insert overwrite table tc partition(p1=1, p2) select 1, 3")
+
+        val warehouse = SQLConf.get.warehousePath.split(":").last
+        val tblPath = Array(warehouse, "org.apache.spark.sql.SQLQuerySuite", "tc")
+          .mkString(File.separator)
+        val staging1 = new File(Array(tblPath, ".spark-staging-1", "p1=1", "application_1234",
+          "jobUUID").mkString(File.separator))
+        staging1.mkdirs()
+
+        val msg = intercept[InsertFileSourceConflictException](
+          sql("insert overwrite table tc partition(p1=1, p2) select 1, 2")).message
+        assert(msg.contains(".spark-staging-1/p1=1") && msg.contains("application_1234"))
+        intercept[InsertFileSourceConflictException](
+          sql("insert overwrite table tc partition(p1=1, p2=2) select 1"))
+        intercept[InsertFileSourceConflictException](
+          sql("insert overwrite table tc select 1, 2, 3"))
+        intercept[InsertFileSourceConflictException](
+          sql("insert into table tc select 1, 2, 3"))
+
+        sql("insert overwrite table tc partition(p1=2, p2) select 1, 2")
+        sql("insert overwrite table tc partition(p1=2, p2=3) select 1")
+      }
+    }
+  }
 }

 case class Foo(bar: Option[String])
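The `tc` part of the test above exercises the new conflict detection: a leftover `.spark-staging-1/p1=1/application_1234/jobUUID` directory under the table location makes any overlapping insert fail with `InsertFileSourceConflictException`, while inserts into other partitions (`p1=2`) still succeed. As a rough illustration only, not code from this patch, the sketch below shows the kind of staging-directory scan that behaviour implies; `StagingConflictCheck` and `findConflicts` are hypothetical names.

```scala
import java.io.File

/**
 * Hypothetical helper (not part of the patch): scans a table directory for
 * ".spark-staging-*" directories left by in-flight writes and reports the
 * ones whose partition paths overlap the partitions about to be written.
 */
object StagingConflictCheck {

  /** Returns staging sub-paths that clash with targetPartitions, e.g. Seq("p1=1"). */
  def findConflicts(tableDir: File, targetPartitions: Seq[String]): Seq[String] = {
    val stagingDirs = Option(tableDir.listFiles())
      .getOrElse(Array.empty[File])
      .filter(f => f.isDirectory && f.getName.startsWith(".spark-staging-"))

    for {
      staging <- stagingDirs.toSeq
      partition <- Option(staging.listFiles()).getOrElse(Array.empty[File]).toSeq
      // An empty targetPartitions means the whole table is rewritten, so any
      // existing staging directory counts as a conflict.
      if targetPartitions.isEmpty || targetPartitions.contains(partition.getName)
    } yield s"${staging.getName}${File.separator}${partition.getName}"
  }
}
```

Against the directory layout the test creates, `StagingConflictCheck.findConflicts(new File(tblPath), Seq("p1=1"))` would report `.spark-staging-1/p1=1`, which is the substring the test expects to find in the exception message.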
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index ab1d1f80e739..4ca260f0c2c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -19,18 +19,23 @@ package org.apache.spark.sql.sources

 import java.io.File
 import java.sql.Timestamp
+import java.util.concurrent.Semaphore

-import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat

-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkContext, TestUtils}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.io.FileSourceWriteDesc
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}
 import org.apache.spark.util.Utils

 private class OnlyDetectCustomPathFileCommitProtocol(jobId: String, path: String)
@@ -43,9 +48,36 @@ private class OnlyDetectCustomPathFileCommitProtocol(jobId: String, path: String
   }
 }

+private class DetectCorrectOutputPathFileCommitProtocol(
+    jobId: String,
+    path: String,
+    fileSourceWriteDesc: Option[FileSourceWriteDesc])
+  extends SQLHadoopMapReduceCommitProtocol(jobId, path, fileSourceWriteDesc) with Serializable
+  with Logging {
+
+  override def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
+    val committer = super.setupCommitter(context)
+
+    val newOutputPath = context.getConfiguration.get(FileOutputFormat.OUTDIR, "")
+    if (dynamicPartitionOverwrite) {
+      assert(new Path(newOutputPath).getName.startsWith(jobId))
+    } else {
+      assert(newOutputPath == path)
+    }
+    committer
+  }
+}
+
 class PartitionedWriteSuite extends QueryTest with SharedSparkSession {
   import testImplicits._

+  // Create a SparkSession with 4 cores to support concurrent writes.
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[4]",
+      "test-partitioned-write-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
   test("write many partitions") {
     val path = Utils.createTempDir()
     path.delete()
@@ -156,4 +188,68 @@ class PartitionedWriteSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("Output path should be a staging output dir, whose last level path name is jobId," +
+    " when dynamicPartitionOverwrite is enabled") {
+    withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      withTable("t") {
+        withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+          classOf[DetectCorrectOutputPathFileCommitProtocol].getName) {
+          Seq((1, 2)).toDF("a", "b")
+            .write
+            .partitionBy("b")
+            .mode("overwrite")
+            .saveAsTable("t")
+        }
+      }
+    }
+  }
+
+  test("Concurrent write to the same table with different partitions should be possible") {
+    withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      withTable("ta", "tb") {
+        val sem = new Semaphore(0)
+        Seq((1, 2)).toDF("a", "b")
+          .write
+          .partitionBy("b")
+          .mode("overwrite")
+          .saveAsTable("ta")
+
+        spark.range(0, 10).toDF("a").write.mode("overwrite").saveAsTable("tb")
+        val stat1 = "insert overwrite table ta partition(b=1) select a from tb"
+        val stat2 = "insert overwrite table ta partition(b=2) select a from tb"
+        val stats = Seq(stat1, stat2)
+
+        var throwable: Option[Throwable] = None
+        for (i <- 0 until 2) {
+          new Thread {
+            override def run(): Unit = {
+              try {
+                val stat = stats(i)
+                sql(stat)
+              } catch {
+                case t: Throwable =>
+                  throwable = Some(t)
+              } finally {
+                sem.release()
+              }
+            }
+          }.start()
+        }
+        // Make sure the writes in both threads have finished.
+        sem.acquire(2)
+        throwable.foreach { t => throw improveStackTrace(t) }
+
+        val df1 = spark.range(0, 10).map(x => (x, 1)).toDF("a", "b")
+        val df2 = spark.range(0, 10).map(x => (x, 2)).toDF("a", "b")
+        checkAnswer(spark.sql("select a, b from ta where b = 1"), df1)
+        checkAnswer(spark.sql("select a, b from ta where b = 2"), df2)
+      }
+    }
+  }
+
+  private def improveStackTrace(t: Throwable): Throwable = {
+    t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace)
+    t
+  }
 }
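Outside the patch itself, here is a minimal sketch of how a user-defined committer built on the new `(jobId: String, path: String, fileSourceWriteDesc: Option[FileSourceWriteDesc])` constructor shape could be selected through `spark.sql.sources.commitProtocolClass` (the key behind `SQLConf.FILE_COMMIT_PROTOCOL_CLASS`), in the same way the `DetectCorrectOutputPathFileCommitProtocol` test above does. `LoggingCommitProtocol` and the demo object are hypothetical names, and the sketch assumes the three-argument `SQLHadoopMapReduceCommitProtocol` constructor introduced by this change.

```scala
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.internal.Logging
import org.apache.spark.internal.io.FileSourceWriteDesc
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol

// Hypothetical committer with the three-argument constructor shape used in the
// test suite above; it only logs which temporary file each task writes.
class LoggingCommitProtocol(
    jobId: String,
    path: String,
    fileSourceWriteDesc: Option[FileSourceWriteDesc])
  extends SQLHadoopMapReduceCommitProtocol(jobId, path, fileSourceWriteDesc)
  with Serializable with Logging {

  override def newTaskTempFile(
      taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = {
    val file = super.newTaskTempFile(taskContext, dir, ext)
    logInfo(s"job $jobId writes task file $file under $path")
    file
  }
}

object LoggingCommitProtocolExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("logging-commit-protocol")
      // Route file source writes through the committer above.
      .config("spark.sql.sources.commitProtocolClass", classOf[LoggingCommitProtocol].getName)
      .getOrCreate()

    spark.range(10).write.mode("overwrite").parquet("/tmp/logging_commit_protocol_demo")
    spark.stop()
  }
}
```

The write at the end goes through the normal file source path, so the committer's log lines show which temporary files each task produced under the destination directory.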