diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 6cb4f04ac7f7..854cd9fe08ae 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2775,6 +2775,13 @@ object SparkContext extends Logging { private[spark] val RDD_SCOPE_KEY = "spark.rdd.scope" private[spark] val RDD_SCOPE_NO_OVERRIDE_KEY = "spark.rdd.scope.noOverride" + // Used to record the temporary output directory and job attempt id when writing to HDFS or Hive + private[spark] val MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR = + "mapreduce.output.fileoutputformat.outputdir" + private[spark] val MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID = + "mapreduce.job.application.attempt.id" + + /** * Executor id for the driver. In earlier versions of Spark, this was ``, but this was * changed to `driver` because the angle brackets caused escaping issues in URLs and XML (see diff --git a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala index 4eeec6386c0b..80511832a4f6 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/SparkHadoopWriter.scala @@ -31,6 +31,7 @@ TaskAttemptContext => NewTaskAttemptContext, TaskAttemptID => NewTaskAttemptID, import org.apache.hadoop.mapreduce.task.{TaskAttemptContextImpl => NewTaskAttemptContextImpl} import org.apache.spark.{SerializableWritable, SparkConf, SparkException, TaskContext} +import org.apache.spark.SparkContext.{MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage @@ -78,6 +79,13 @@ object SparkHadoopWriter extends Logging { val committer = config.createCommitter(commitJobId) committer.setupJob(jobContext) + rdd.context.setLocalProperty(MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR, + jobContext.getConfiguration().get(MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR)) + rdd.context.setLocalProperty(MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, + jobContext.getConfiguration().getInt(MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, 0).toString) + + rdd.setResultStageAllowToRetry(true) + // Try to write all RDD partitions as a Hadoop OutputFormat.
try { val ret = sparkContext.runJob(rdd, (context: TaskContext, iter: Iterator[(K, V)]) => { diff --git a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala index 34c04f4025a9..b3f2fbd5e9f0 100644 --- a/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala +++ b/core/src/main/scala/org/apache/spark/partial/ApproximateActionListener.scala @@ -56,6 +56,10 @@ private[spark] class ApproximateActionListener[T, U, R]( } } + override def stageFailed(): Unit = { + finishedTasks = 0 + } + override def jobFailed(exception: Exception): Unit = { synchronized { failure = Some(exception) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index b7284d251224..838bfaf509d4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -46,7 +46,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.resource.ResourceProfile import org.apache.spark.storage.{RDDBlockId, StorageLevel} -import org.apache.spark.util.{BoundedPriorityQueue, Utils} +import org.apache.spark.util.{AccumulatorV2, BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.{ExternalAppendOnlyMap, OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, @@ -2084,6 +2084,21 @@ abstract class RDD[T: ClassTag]( } } + private[spark] var isResultStageRetryAllowed = false + + private[spark] def setResultStageAllowToRetry(isRetryAllowed: Boolean): Unit = { + isResultStageRetryAllowed = isRetryAllowed + } + + private[spark] var totalNumRowsAccumulator: Option[AccumulatorV2[_, _]] = None + + private[spark] def reset(): Unit = { + totalNumRowsAccumulator match { + case Some(accumulatorV2) => accumulatorV2.reset() + case _ => + } + } + } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 9bd4a6f4478b..3d7367a42f50 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.NotSerializableException +import java.io.{IOException, NotSerializableException} import java.util.Properties import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeoutException, TimeUnit } import java.util.concurrent.atomic.AtomicInteger @@ -30,8 +30,10 @@ import scala.concurrent.duration._ import scala.util.control.NonFatal import com.google.common.util.concurrent.{Futures, SettableFuture} +import org.apache.hadoop.fs.Path import org.apache.spark._ +import org.apache.spark.SparkContext.{MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR} import org.apache.spark.broadcast.Broadcast import org.apache.spark.errors.SparkCoreErrors import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} @@ -416,6 +418,67 @@ private[spark] class DAGScheduler( cacheLocs.clear() } + private def unregisterAllResultOutput(rs: ResultStage): Unit = { + // cleanup finished partitions + rs.activeJob.get.resetAllPartitions() + // cleanup job listener state + rs.activeJob.get.listener.stageFailed() + // cleanup stage commit messages + outputCommitCoordinator.stageEnd(rs.id) + // cleanup temp 
directory used when writing to Hive tables or HDFS + cleanupJobAttemptPath() + // cleanup accumulator state for datasource v2 commands + rs.rdd.reset() + } + + private def cleanupJobAttemptPath(): Unit = { + val outputDir = sc.getLocalProperty(MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR) + val jobAttemptId = sc.getLocalProperty(MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID) + + if (outputDir != null && outputDir.nonEmpty) { + val jobAttemptPath = new Path(getPendingJobAttemptsPath(new Path(outputDir)), + String.valueOf(jobAttemptId)) + + val fs = jobAttemptPath.getFileSystem(sc.hadoopConfiguration) + + if (fs.exists(jobAttemptPath)) { + var attempts = 0 + val maxAttempts = 10 + while (!fs.delete(jobAttemptPath, true)) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException(s"Job attempt dir: ${jobAttemptPath.getName} " + + s"failed to be deleted after $maxAttempts attempts!") + } + logWarning(s"Job attempt dir: ${jobAttemptPath.getName} " + + s"failed to be deleted at retry $attempts, max attempts: $maxAttempts not reached yet," + + s" will retry again in 1000 ms") + Thread.sleep(1000) + } + attempts = 0 + while (!fs.mkdirs(jobAttemptPath)) { + attempts += 1 + if (attempts > maxAttempts) { + throw new IOException(s"Job attempt dir: ${jobAttemptPath.getName} " + + s"failed to be recreated after $maxAttempts attempts!") + } + logWarning(s"Job attempt dir: ${jobAttemptPath.getName} " + + s"failed to be recreated at retry $attempts, " + + s"max attempts: $maxAttempts not reached yet, will retry again in 1000 ms") + Thread.sleep(1000) + } + logInfo(s"Job attempt dir: ${jobAttemptPath.getName} has been cleaned") + } else { + logInfo(s"Job attempt dir: ${jobAttemptPath.getName} does not exist " + + s"and does not need to be cleaned") + } + } + } + + def getPendingJobAttemptsPath(out: Path): Path = { + new Path(out, "_temporary") + } + /** * Gets a shuffle map stage if one exists in shuffleIdToMapStage. Otherwise, if the * shuffle map stage doesn't already exist, this method will create the shuffle map stage in @@ -1964,8 +2027,9 @@ private[spark] class DAGScheduler( def generateErrorMessage(stage: Stage): String = { "A shuffle map stage with indeterminate output was failed and retried. " + s"However, Spark cannot rollback the $stage to re-process the input data, " + - "and has to fail this job. Please eliminate the indeterminacy by " + - "checkpointing the RDD before repartition and try again." + "and has to fail this job in the scenario of writing to a database. " + + "Please eliminate the indeterminacy by checkpointing the RDD " + + "before repartition and try again." } activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) @@ -1991,8 +2055,18 @@ private[spark] class DAGScheduler( case resultStage: ResultStage if resultStage.activeJob.isDefined => val numMissingPartitions = resultStage.findMissingPartitions().length if (numMissingPartitions < resultStage.numTasks) { + if (resultStage.rdd.isResultStageRetryAllowed) { + rollingBackStages += resultStage + // The FetchFailed came from an indeterminate map stage, + // so all tasks of the result stage should be rerun. + // If the FetchFailed came from a determinate map stage, + // the result stage should not roll back all partitions. + unregisterAllResultOutput(resultStage) + } else { + // TODO: support rolling back result tasks + // in the scenario of writing to a database. 
+ abortStage(resultStage, generateErrorMessage(resultStage), None) + } } case _ => diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala index e0f7c8f02132..c61724566cc3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -24,5 +24,6 @@ package org.apache.spark.scheduler */ private[spark] trait JobListener { def taskSucceeded(index: Int, result: Any): Unit + def stageFailed(): Unit def jobFailed(exception: Exception): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index feed83162084..524e887d3d5b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -63,6 +63,10 @@ private[spark] class JobWaiter[T]( } } + override def stageFailed(): Unit = { + finishedTasks.getAndSet(0) + } + override def jobFailed(exception: Exception): Unit = { if (!jobPromise.tryFailure(exception)) { logWarning("Ignore failure", exception) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index fc7aa06e41ef..b4743a54cf46 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.File import java.util.Properties import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, TimeUnit} import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference} @@ -25,6 +26,7 @@ import scala.annotation.meta.param import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.util.control.NonFatal +import org.apache.hadoop.fs.{FileSystem, Path} import org.mockito.Mockito._ import org.roaringbitmap.RoaringBitmap import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} @@ -32,6 +34,7 @@ import org.scalatest.exceptions.TestFailedException import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.SparkContext.{MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR} import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.executor.ExecutorMetrics import org.apache.spark.internal.config @@ -304,6 +307,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti var failure: Exception = _ val jobListener = new JobListener() { override def taskSucceeded(index: Int, result: Any) = results.put(index, result) + override def stageFailed(): Unit = results.clear() override def jobFailed(exception: Exception) = { failure = exception } } @@ -312,6 +316,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti val results = new HashMap[Int, Any] var failure: Exception = null override def taskSucceeded(index: Int, result: Any): Unit = results.put(index, result) + override def stageFailed(): Unit = results.clear() override def jobFailed(exception: Exception): Unit = { failure = exception } } @@ -699,6 +704,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti var failureReason: Option[Exception] = None val fakeListener = new JobListener() { override def taskSucceeded(partition: Int, value: 
Any): Unit = numResults += 1 + override def stageFailed(): Unit = numResults = 0 + override def jobFailed(exception: Exception): Unit = { + failureReason = Some(exception) + } @@ -1852,6 +1858,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti class FailureRecordingJobListener() extends JobListener { var failureMessage: String = _ override def taskSucceeded(index: Int, result: Any): Unit = {} + override def stageFailed(): Unit = {} override def jobFailed(exception: Exception): Unit = { failureMessage = exception.getMessage } } val listener1 = new FailureRecordingJobListener() @@ -3022,6 +3029,136 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assertDataStructuresEmpty() } + test("SPARK-25342: `isResultStageRetryAllowed` indicates whether the result stage can be retried") { + // When writing to file systems, `isResultStageRetryAllowed` will be set to true + // and the result stage will be retried + + // 1. Abort the job since the result stage of finalRdd does not support retry: + // the RDD's `isResultStageRetryAllowed` is false + val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true) + assertResultStageFailToRollback(shuffleMapRdd) + + + // 2. Allow the result stage to retry since the RDD's `isResultStageRetryAllowed` is set to true + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + val shuffleDep = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + finalRdd.setResultStageAllowToRetry(true) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Finish the first task of the result stage + runEvent(makeCompletionEvent( + taskSets.last.tasks(0), Success, 42, + Seq.empty, Array.empty, createFakeTaskInfoWithId(0))) + + // Fail the second task with FetchFailed. 
+ runEvent(makeCompletionEvent( + taskSets.last.tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), + null)) + + assert(scheduler.failedStages.size == 2 && + scheduler.failedStages.map(p => p.id).toSeq === Seq(2, 3) && + scheduler.failedStages.exists(p => p.isInstanceOf[ShuffleMapStage]) && + scheduler.failedStages.exists(p => p.isInstanceOf[ResultStage])) + } + + test("SPARK-25342: cleanup temp messages before retrying result stage") { + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) + + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) + + val acc = Some(sc.longAccumulator) + acc.get.add(1) + assert(acc.get.value == 1) + + // Allow the result stage of finalRdd to retry + finalRdd.setResultStageAllowToRetry(true) + finalRdd.totalNumRowsAccumulator = acc + + // Create a temporary directory as the temporary output of the job + val localFs = FileSystem.getLocal(sc.hadoopConfiguration) + val rootOutputPath = localFs.makeQualified( + new Path(System.getProperty("java.io.tmpdir") + File.separator + "output")) + if (!localFs.exists(rootOutputPath)) { + localFs.mkdirs(rootOutputPath) + } + val attemptOutputPath = localFs.makeQualified( + new Path(rootOutputPath.toUri.getPath + File.separator + "_temporary/0")) + if (!localFs.exists(attemptOutputPath)) { + localFs.mkdirs(attemptOutputPath) + } + val taskOutputFile = localFs.makeQualified( + new Path(attemptOutputPath.toUri.getPath + File.separator + "r_000000_0")) + if (!localFs.exists(taskOutputFile)) { + localFs.createNewFile(taskOutputFile) + } + assert(localFs.listStatus(attemptOutputPath).length == 1) + + sc.setLocalProperty(MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR, rootOutputPath.toUri.getPath) + sc.setLocalProperty(MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, String.valueOf(0)) + + submit(finalRdd, Array(0, 1), properties = sc.localProperties.get) + + // Finish the first shuffle map stage. + completeShuffleMapStageSuccessfully(0, 0, 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + // Finish the second shuffle map stage. + completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostC", "hostD")) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // The first task of the final stage succeeds + runEvent(makeCompletionEvent( + taskSets(2).tasks(0), Success, 11, + Seq.empty, Array.empty, createFakeTaskInfoWithId(0))) + + // The second task of the final stage fails with a fetch failure + runEvent(makeCompletionEvent( + taskSets(2).tasks(1), + FetchFailed(makeBlockManagerId("hostC"), shuffleId2, 0L, 0, 0, "ignored"), + null)) + + // Stages 1 and 2 will be retried and all of their tasks rerun + assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2) && + scheduler.failedStages.exists(p => p.isInstanceOf[ShuffleMapStage]) && + scheduler.failedStages.exists(p => p.isInstanceOf[ResultStage])) + + // Resubmit failed stages + scheduler.resubmitFailedStages() + + // The first shuffle map stage is resubmitted and reruns all tasks. 
+ assert(taskSets(3).stageId == 1) + assert(taskSets(3).stageAttemptId == 1) + assert(taskSets(3).tasks.length == 2) + + // Finish map stage 1 + completeShuffleMapStageSuccessfully(1, 1, 2, Seq("hostE", "hostF")) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // The result stage succeeds and the job ends. + complete(taskSets(4), Seq((Success, 11), (Success, 12))) + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + + assert(localFs.exists(attemptOutputPath) && localFs.listStatus(attemptOutputPath).length == 0) + assert(acc.get == finalRdd.totalNumRowsAccumulator.get && acc.get.value == 0) + localFs.delete(rootOutputPath.getParent, true) + } + test("SPARK-25341: continuous indeterminate stage roll back") { // shuffleMapRdd1/2/3 are all indeterminate. val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true) @@ -3118,6 +3255,30 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) } + private def findAllStagesToRetry(mapRdd: MyRDD): Unit = { + val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + finalRdd.setResultStageAllowToRetry(true) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(taskSets.length - 1, 0, numShufflePartitions = 2) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Finish the first task of the result stage + runEvent(makeCompletionEvent( + taskSets.last.tasks(0), Success, 42, + Seq.empty, Array.empty, createFakeTaskInfoWithId(0))) + + // Fail the second task with FetchFailed. + runEvent(makeCompletionEvent( + taskSets.last.tasks(1), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), + null)) + } + test("SPARK-23207: cannot rollback a result stage") { val shuffleMapRdd = new MyRDD(sc, 2, Nil, indeterminate = true) assertResultStageFailToRollback(shuffleMapRdd) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index d826c7685742..5e199adfec78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ +import org.apache.spark.SparkContext.{MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR} import org.apache.spark.internal.Logging import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.shuffle.FetchFailedException @@ -199,6 +200,11 @@ object FileFormatWriter extends Logging { // V1 write command will be empty). 
if (Utils.isTesting) outputOrderingMatched = orderingMatched + sparkSession.sparkContext.setLocalProperty(MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR, + job.getConfiguration().get(MAPREDUCE_OUTPUT_FILEOUTPUTFORMAT_OUTPUTDIR)) + sparkSession.sparkContext.setLocalProperty(MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, + job.getConfiguration().getInt(MAPREDUCE_JOB_APPLICATION_ATTEMPT_ID, 0).toString) + try { val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { (empty2NullPlan.execute(), None) @@ -231,6 +237,8 @@ object FileFormatWriter extends Logging { rdd } + rddWithNonEmptyPartitions.setResultStageAllowToRetry(true) + val jobIdInstant = new Date().getTime val ret = new Array[WriteTaskResult](rddWithNonEmptyPartitions.partitions.length) sparkSession.sparkContext.runJob( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index d23a9e51f658..a7d30841b764 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric, SQLMetrics} import org.apache.spark.sql.util.CaseInsensitiveStringMap -import org.apache.spark.util.{LongAccumulator, Utils} +import org.apache.spark.util.Utils /** * Deprecated logical plan for writing data into data source v2. This is being replaced by more @@ -365,7 +365,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { PhysicalWriteInfoImpl(rdd.getNumPartitions)) val useCommitCoordinator = batchWrite.useCommitCoordinator val messages = new Array[WriterCommitMessage](rdd.partitions.length) - val totalNumRowsAccumulator = new LongAccumulator() + val totalNumRowsAccumulator = Some(sparkContext.longAccumulator) logInfo(s"Start processing data source write support: $batchWrite. " + s"The input RDD has ${messages.length} partitions.") @@ -373,6 +373,9 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { // Avoid object not serializable issue. val writeMetrics: Map[String, SQLMetric] = customMetrics + rdd.setResultStageAllowToRetry(true) + rdd.totalNumRowsAccumulator = totalNumRowsAccumulator + try { sparkContext.runJob( rdd, @@ -383,7 +386,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { (index, result: DataWritingSparkTaskResult) => { val commitMessage = result.writerCommitMessage messages(index) = commitMessage - totalNumRowsAccumulator.add(result.numRows) + totalNumRowsAccumulator.get.add(result.numRows) batchWrite.onDataWriterCommit(commitMessage) } ) @@ -391,7 +394,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode { logInfo(s"Data source write support $batchWrite is committing.") batchWrite.commit(messages) logInfo(s"Data source write support $batchWrite committed.") - commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.value)) + commitProgress = Some(StreamWriterCommitProgress(totalNumRowsAccumulator.get.value)) } catch { case cause: Throwable => logError(s"Data source write support $batchWrite is aborting.")