diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 11ce608f52ee..317e1932d6b0 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -26,7 +26,7 @@ import scala.util.Try
 import org.apache.hadoop.conf.Configurable
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter
+import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter, FileOutputFormat}
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 
 import org.apache.spark.internal.Logging
@@ -91,7 +91,31 @@ class HadoopMapReduceCommitProtocol(
    */
   private def stagingDir = new Path(path, ".spark-staging-" + jobId)
 
+  /**
+   * Get the desired output path for the job. The output will be [[path]] when
+   * dynamicPartitionOverwrite is disabled; otherwise it will be [[stagingDir]]. We choose
+   * [[stagingDir]] over [[path]] to avoid collisions between concurrent write jobs, which would
+   * otherwise all be given the same output path when dynamically writing to the same table.
+   *
+   * @return the desired output path.
+   */
+  protected def getOutputPath(context: TaskAttemptContext): Path = {
+    if (dynamicPartitionOverwrite) {
+      val conf = context.getConfiguration
+      val outputPath = stagingDir.getFileSystem(conf).makeQualified(stagingDir)
+      outputPath
+    } else {
+      new Path(path)
+    }
+  }
+
   protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
+    // Set the output path to stagingDir to avoid collisions between concurrent write jobs.
+    if (dynamicPartitionOverwrite) {
+      val newOutputPath = getOutputPath(context)
+      context.getConfiguration.set(FileOutputFormat.OUTDIR, newOutputPath.toString)
+    }
+
     val format = context.getOutputFormatClass.getConstructor().newInstance()
     // If OutputFormat is Configurable, we should set conf to it.
     format match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala
index 39c594a9bc61..ef499ee26ace 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala
@@ -55,7 +55,8 @@ class SQLHadoopMapReduceCommitProtocol(
         // The specified output committer is a FileOutputCommitter.
         // So, we will use the FileOutputCommitter-specified constructor.
         val ctor = clazz.getDeclaredConstructor(classOf[Path], classOf[TaskAttemptContext])
-        committer = ctor.newInstance(new Path(path), context)
+        val outputPath = getOutputPath(context)
+        committer = ctor.newInstance(outputPath, context)
       } else {
         // The specified output committer is just an OutputCommitter.
         // So, we will use the no-argument constructor.
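For context, here is a minimal usage sketch (not part of the patch) of the code path these two files change: with `spark.sql.sources.partitionOverwriteMode=dynamic`, the committer set up by `HadoopMapReduceCommitProtocol.setupCommitter` now points `FileOutputFormat.OUTDIR` at the per-job `.spark-staging-<jobId>` directory rather than the table path, so jobs that overwrite different partitions of the same table no longer share an output directory. The object name, app name, and table name below are illustrative, not taken from the patch.

```scala
// Illustrative sketch only: a job that exercises the dynamic-partition-overwrite path.
import org.apache.spark.sql.SparkSession

object DynamicOverwriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("dynamic-partition-overwrite-sketch") // arbitrary name
      .master("local[4]")
      // Only partitions present in the incoming data are overwritten.
      .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
      .getOrCreate()
    import spark.implicits._

    // Seed a partitioned table with partitions b=1 and b=2.
    Seq((1, 1), (2, 2)).toDF("a", "b")
      .write
      .partitionBy("b")
      .mode("overwrite")
      .saveAsTable("t")

    // Overwrites only partition b=1. With this patch, the committer stages output under
    // the job's own ".spark-staging-<jobId>" dir, so a concurrent job writing b=2 would
    // not collide on the same output directory.
    Seq((10, 1)).toDF("a", "b")
      .write
      .mode("overwrite")
      .insertInto("t")

    spark.stop()
  }
}
```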
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index ab1d1f80e739..183c27861edd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -19,10 +19,13 @@ package org.apache.spark.sql.sources
 
 import java.io.File
 import java.sql.Timestamp
+import java.util.concurrent.Semaphore
 
-import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{OutputCommitter, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
 
-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkContext, TestUtils}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{QueryTest, Row}
 import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils
@@ -30,7 +33,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode
+import org.apache.spark.sql.test.{SharedSparkSession, TestSparkSession}
 import org.apache.spark.util.Utils
 
 private class OnlyDetectCustomPathFileCommitProtocol(jobId: String, path: String)
@@ -43,9 +47,34 @@
   }
 }
 
+private class DetectCorrectOutputPathFileCommitProtocol(
+    jobId: String, path: String, dynamicPartitionOverwrite: Boolean)
+  extends SQLHadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite)
+    with Serializable with Logging {
+
+  override def setupCommitter(context: TaskAttemptContext): OutputCommitter = {
+    val committer = super.setupCommitter(context)
+
+    val newOutputPath = context.getConfiguration.get(FileOutputFormat.OUTDIR, "")
+    if (dynamicPartitionOverwrite) {
+      assert(new Path(newOutputPath).getName.startsWith(".spark-staging"))
+    } else {
+      assert(newOutputPath == path)
+    }
+    committer
+  }
+}
+
 class PartitionedWriteSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
+  // Create a SparkSession with 4 cores so that writes can run concurrently.
+  override protected def createSparkSession = new TestSparkSession(
+    new SparkContext(
+      "local[4]",
+      "test-partitioned-write-context",
+      sparkConf.set("spark.sql.testkey", "true")))
+
   test("write many partitions") {
     val path = Utils.createTempDir()
     path.delete()
@@ -156,4 +185,65 @@ class PartitionedWriteSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("Output path should be staging dir when dynamicPartitionOverwrite is enabled") {
+    withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      withTable("t") {
+        withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+          classOf[DetectCorrectOutputPathFileCommitProtocol].getName) {
+          Seq((1, 2)).toDF("a", "b")
+            .write
+            .partitionBy("b")
+            .mode("overwrite")
+            .saveAsTable("t")
+        }
+      }
+    }
+  }
+
+  test("Concurrent write to the same table with different partitions should be possible") {
+    withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) {
+      withTable("t") {
+        val sem = new Semaphore(0)
+        Seq((1, 2)).toDF("a", "b")
+          .write
+          .partitionBy("b")
+          .mode("overwrite")
+          .saveAsTable("t")
+
+        val df1 = spark.range(0, 10).map(x => (x, 1)).toDF("a", "b")
+        val df2 = spark.range(0, 10).map(x => (x, 2)).toDF("a", "b")
+        val dfs = Seq(df1, df2)
+
+        var throwable: Option[Throwable] = None
+        for (i <- 0 until 2) {
+          new Thread {
+            override def run(): Unit = {
+              try {
+                dfs(i)
+                  .write
+                  .mode("overwrite")
+                  .insertInto("t")
+              } catch {
+                case t: Throwable =>
+                  throwable = Some(t)
+              } finally {
+                sem.release()
+              }
+            }
+          }.start()
+        }
+        // Make sure both writer threads have finished.
+        sem.acquire(2)
+        throwable.foreach { t => throw improveStackTrace(t) }
+        checkAnswer(spark.sql("select a, b from t where b = 1"), df1)
+        checkAnswer(spark.sql("select a, b from t where b = 2"), df2)
+      }
+    }
+  }
+
+  private def improveStackTrace(t: Throwable): Throwable = {
+    t.setStackTrace(t.getStackTrace ++ Thread.currentThread.getStackTrace)
+    t
+  }
 }
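As a rough user-level counterpart to the new concurrency test, the sketch below (not part of the patch; object, app, and table names are illustrative) issues two dynamic-overwrite inserts into different partitions of the same table from separate threads. With this change each insert commits through its own per-job staging directory, so neither clobbers the other's output.

```scala
// Hypothetical sketch only: concurrent dynamic-overwrite inserts into different partitions.
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

import org.apache.spark.sql.SparkSession

object ConcurrentPartitionWritesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("concurrent-partition-writes-sketch") // arbitrary name
      .master("local[4]")
      .config("spark.sql.sources.partitionOverwriteMode", "dynamic")
      .getOrCreate()
    import spark.implicits._

    // Create the partitioned target table with partitions b=1 and b=2.
    Seq((1, 1), (2, 2)).toDF("a", "b")
      .write
      .partitionBy("b")
      .mode("overwrite")
      .saveAsTable("t")

    // Each future overwrites a different partition of the same table.
    val writes = Seq(1, 2).map { b =>
      Future {
        spark.range(0, 10).map(x => (x, b)).toDF("a", "b")
          .write
          .mode("overwrite")
          .insertInto("t")
      }
    }
    writes.foreach(Await.result(_, 5.minutes))

    // Both partitions should now hold 10 rows each.
    spark.table("t").groupBy("b").count().show()
    spark.stop()
  }
}
```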