diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index aee92ba928b4..0693b9a00a35 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1554,6 +1554,7 @@ private[spark] class DAGScheduler(
case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
sms.shuffleDep.newShuffleMergeState()
+
case _ =>
}
@@ -1873,15 +1874,28 @@ private[spark] class DAGScheduler(
private[scheduler] def handleTaskCompletion(event: CompletionEvent): Unit = {
val task = event.task
val stageId = task.stageId
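+ // A Success from an earlier attempt of an indeterminate stage, or from an attempt whose
+ // results must be discarded, is treated as a zombie completion: it is reported to the
+ // OutputCommitCoordinator as TaskKilled and its result is ignored, because accepting it
+ // would leave some partitions holding output from a stale, potentially inconsistent attempt.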
+ val stageOption = stageIdToStage.get(task.stageId)
+ val isIndeterministicZombie = event.reason match {
+ case Success if stageOption.isDefined =>
+ val stage = stageOption.get
+ (task.stageAttemptId < stage.latestInfo.attemptNumber()
+ && stage.isIndeterminate) || stage.shouldDiscardResult(task.stageAttemptId)
+
+ case _ => false
+ }
outputCommitCoordinator.taskCompleted(
stageId,
task.stageAttemptId,
task.partitionId,
event.taskInfo.attemptNumber, // this is a task attempt number
- event.reason)
+ if (isIndeterministicZombie) {
+ TaskKilled(reason = "Indeterminate stage needs all tasks to be retried")
+ } else {
+ event.reason
+ })
- if (!stageIdToStage.contains(task.stageId)) {
+ if (stageOption.isEmpty) {
// The stage may have already finished when we get this event -- e.g. maybe it was a
// speculative task. It is important that we send the TaskEnd event in any case, so listeners
// are properly notified and can chose to handle it. For instance, some listeners are
@@ -1893,34 +1907,37 @@ private[spark] class DAGScheduler(
return
}
- val stage = stageIdToStage(task.stageId)
+ val stage = stageOption.get
// Make sure the task's accumulators are updated before any other processing happens, so that
// we can post a task end event before any jobs or stages are updated. The accumulators are
// only updated in certain cases.
event.reason match {
case Success =>
- task match {
- case rt: ResultTask[_, _] =>
- val resultStage = stage.asInstanceOf[ResultStage]
- resultStage.activeJob match {
- case Some(job) =>
- // Only update the accumulator once for each result task.
- if (!job.finished(rt.outputId)) {
- updateAccumulators(event)
- }
- case None => // Ignore update if task's job has finished.
- }
- case _ =>
- updateAccumulators(event)
+ if (!isIndeterministicZombie) {
+ task match {
+ case rt: ResultTask[_, _] =>
+ val resultStage = stage.asInstanceOf[ResultStage]
+ resultStage.activeJob match {
+ case Some(job) =>
+ // Only update the accumulator once for each result task.
+ if (!job.finished(rt.outputId)) {
+ updateAccumulators(event)
+ }
+ case _ => // Ignore update if task's job has finished.
+ }
+ case _ => updateAccumulators(event)
+ }
}
+
case _: ExceptionFailure | _: TaskKilled => updateAccumulators(event)
+
case _ =>
}
if (trackingCacheVisibility) {
// Update rdd blocks' visibility status.
blockManagerMaster.updateRDDBlockVisibility(
- event.taskInfo.taskId, visible = event.reason == Success)
+ event.taskInfo.taskId, visible = event.reason == Success && !isIndeterministicZombie)
}
postTaskEnd(event)
@@ -1936,7 +1953,7 @@ private[spark] class DAGScheduler(
}
task match {
- case rt: ResultTask[_, _] =>
+ case rt: ResultTask[_, _] if !isIndeterministicZombie =>
// Cast to ResultStage here because it's part of the ResultTask
// TODO Refactor this out to a function that accepts a ResultStage
val resultStage = stage.asInstanceOf[ResultStage]
@@ -1984,7 +2001,7 @@ private[spark] class DAGScheduler(
logInfo(log"Ignoring result from ${MDC(RESULT, rt)} because its job has finished")
}
- case smt: ShuffleMapTask =>
+ case smt: ShuffleMapTask if !isIndeterministicZombie =>
val shuffleStage = stage.asInstanceOf[ShuffleMapStage]
// Ignore task completion for old attempt of indeterminate stage
val ignoreIndeterminate = stage.isIndeterminate &&
@@ -2017,6 +2034,8 @@ private[spark] class DAGScheduler(
processShuffleMapStageCompletion(shuffleStage)
}
}
+
+ case _ => // ignore
}
case FetchFailed(bmAddress, shuffleId, _, mapIndex, reduceId, failureMessage) =>
@@ -2121,6 +2140,12 @@ private[spark] class DAGScheduler(
failedStages += failedStage
failedStages += mapStage
if (noResubmitEnqueued) {
+ def generateErrorMessage(stage: Stage): String = {
+ "A shuffle map stage with indeterminate output was failed and retried. " +
+ s"However, Spark cannot rollback the $stage to re-process the input data, " +
+ "and has to fail this job. Please eliminate the indeterminacy by " +
+ "checkpointing the RDD before repartition and try again."
+ }
// If the map stage is INDETERMINATE, which means the map tasks may return
// different result when re-try, we need to re-try all the tasks of the failed
// stage and its succeeding stages, because the input data will be changed after the
@@ -2147,13 +2172,6 @@ private[spark] class DAGScheduler(
}
}
- def generateErrorMessage(stage: Stage): String = {
- "A shuffle map stage with indeterminate output was failed and retried. " +
- s"However, Spark cannot rollback the $stage to re-process the input data, " +
- "and has to fail this job. Please eliminate the indeterminacy by " +
- "checkpointing the RDD before repartition and try again."
- }
-
activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil))
// The stages will be rolled back after checking
@@ -2171,7 +2189,12 @@ private[spark] class DAGScheduler(
abortStage(mapStage, reason, None)
} else {
rollingBackStages += mapStage
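+ // Discard every map (and merge) output of the indeterminate map stage up front, so that
+ // the retry recomputes all partitions instead of reusing outputs from the previous attempt.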
+ mapOutputTracker.unregisterAllMapAndMergeOutput(
+ mapStage.shuffleDep.shuffleId)
}
+ } else {
+ mapOutputTracker.unregisterAllMapAndMergeOutput(
+ mapStage.shuffleDep.shuffleId)
}
case resultStage: ResultStage if resultStage.activeJob.isDefined =>
@@ -2179,6 +2202,8 @@ private[spark] class DAGScheduler(
if (numMissingPartitions < resultStage.numTasks) {
// TODO: support to rollback result tasks.
abortStage(resultStage, generateErrorMessage(resultStage), None)
+ } else {
+ resultStage.markAllPartitionsMissing()
}
case _ =>
@@ -2186,6 +2211,19 @@ private[spark] class DAGScheduler(
logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
log"we will roll back and rerun below stages which include itself and all its " +
log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+ } else if (failedStage.isIndeterminate) {
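+ // The failed map stage is determinate, but the failed stage itself is indeterminate: it
+ // may also depend on other indeterminate shuffles that have to be recomputed, so its
+ // already-finished partitions cannot be trusted. Either all result partitions are re-run
+ // or the job is aborted.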
+ failedStage match {
+ case resultStage: ResultStage if resultStage.activeJob.isDefined =>
+ val numMissingPartitions = resultStage.findMissingPartitions().length
+ if (numMissingPartitions < resultStage.numTasks) {
+ // TODO: support to rollback result tasks.
+ abortStage(resultStage, generateErrorMessage(resultStage), None)
+ } else {
+ resultStage.markAllPartitionsMissing()
+ }
+
+ case _ =>
+ }
}
// We expect one executor failure to trigger many FetchFailures in rapid succession,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
index 7fdc3186e86b..92f76f9adec8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultStage.scala
@@ -38,6 +38,9 @@ private[spark] class ResultStage(
resourceProfileId: Int)
extends Stage(id, rdd, partitions.length, parents, firstJobId, callSite, resourceProfileId) {
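+ // Attempt id (and, implicitly, all earlier attempts) whose result-task completions must be
+ // discarded because every partition of this stage is being recomputed. Reset to -1 whenever
+ // a new stage attempt is created.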
+ @volatile
+ private var discardResultsForAttemptId: Int = -1
+
/**
* The active job for this result stage. Will be empty if the job has already finished
* (e.g., because the job was cancelled).
@@ -54,6 +57,14 @@ private[spark] class ResultStage(
_activeJob = None
}
+ override def makeNewStageAttempt(
+ numPartitionsToCompute: Int,
+ taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = {
+ super.makeNewStageAttempt(numPartitionsToCompute, taskLocalityPreferences)
+ // Reset the attempt id recorded by markAllPartitionsMissing(); results from the new attempt are accepted again.
+ discardResultsForAttemptId = -1
+ }
+
/**
* Returns the sequence of partition ids that are missing (i.e. needs to be computed).
*
@@ -64,5 +75,16 @@ private[spark] class ResultStage(
(0 until job.numPartitions).filter(id => !job.finished(id))
}
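+ /**
+ * Marks every partition of the active job as not finished and records the current stage
+ * attempt, so that any late Success from this attempt (or an earlier one) is discarded via
+ * [[shouldDiscardResult]] and the next attempt recomputes all partitions.
+ */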
+ def markAllPartitionsMissing(): Unit = {
+ this.discardResultsForAttemptId = this.latestInfo.attemptNumber()
+ val job = activeJob.get
+ for (id <- 0 until job.numPartitions) {
+ job.finished(id) = false
+ }
+ }
+
+ override def shouldDiscardResult(attemptId: Int): Boolean =
+ this.discardResultsForAttemptId >= attemptId
+
override def toString: String = "ResultStage " + id
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index f35beafd8748..f8420f45482f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -131,4 +131,6 @@ private[scheduler] abstract class Stage(
def isIndeterminate: Boolean = {
rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
}
+
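+ /**
+ * Whether the results of the given (or any earlier) stage attempt should be discarded.
+ * False by default; [[ResultStage]] overrides this when all of its partitions have been
+ * marked missing for a retry of an indeterminate computation.
+ */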
+ def shouldDiscardResult(attemptId: Int): Boolean = false
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 3e507df706ba..0f00ce302962 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -51,7 +51,9 @@ import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, BlockMan
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, Clock, LongAccumulator, SystemClock, ThreadUtils, Utils}
import org.apache.spark.util.ArrayImplicits._
-class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
+class DAGSchedulerEventProcessLoopTester(
+ dagScheduler: DAGScheduler,
+ dagSchedulerInterceptorOpt: Option[DagSchedulerInterceptor] = None)
extends DAGSchedulerEventProcessLoop(dagScheduler) {
dagScheduler.setEventProcessLoop(this)
@@ -64,12 +66,15 @@ class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler)
if (isProcessing) {
// `DAGSchedulerEventProcessLoop` is guaranteed to process events sequentially. So we should
// buffer events for sequent processing later instead of processing them recursively.
+ dagSchedulerInterceptorOpt.foreach(_.beforeAddingDagEventToQueue(event))
eventQueue += event
+ dagSchedulerInterceptorOpt.foreach(_.afterAddingDagEventToQueue(event))
} else {
try {
isProcessing = true
// Forward event to `onReceive` directly to avoid processing event asynchronously.
onReceive(event)
+ dagSchedulerInterceptorOpt.foreach(_.afterDirectProcessingOfDagEvent(event))
} catch {
case NonFatal(e) => onError(e)
} finally {
@@ -175,6 +180,12 @@ class DummyScheduledFuture(
class DAGSchedulerSuiteDummyException extends Exception
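+/**
+ * Test-only hook into [[DAGSchedulerEventProcessLoopTester]] that lets a suite observe or
+ * inject events at three points: just before an event is buffered in the queue, just after it
+ * is buffered, and after an event has been processed synchronously.
+ */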
+trait DagSchedulerInterceptor {
+ def beforeAddingDagEventToQueue(event: DAGSchedulerEvent): Unit = {}
+ def afterAddingDagEventToQueue(event: DAGSchedulerEvent): Unit = {}
+ def afterDirectProcessingOfDagEvent(event: DAGSchedulerEvent): Unit = {}
+}
+
class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with TimeLimits {
import DAGSchedulerSuite._
@@ -300,6 +311,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
}
var sparkListener: EventInfoRecordingListener = null
+ var dagSchedulerInterceptor: DagSchedulerInterceptor = null
var blockManagerMaster: BlockManagerMaster = null
var mapOutputTracker: MapOutputTrackerMaster = null
@@ -444,7 +456,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
blockManagerMaster,
sc.env))
- dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+ dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler,
+ Option(dagSchedulerInterceptor))
}
override def afterEach(): Unit = {
@@ -453,6 +466,7 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
dagEventProcessLoopTester.stop()
mapOutputTracker.stop()
broadcastManager.stop()
+ this.dagSchedulerInterceptor = null
} finally {
super.afterEach()
}
@@ -3153,25 +3167,43 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
null))
(shuffleId1, shuffleId2)
}
+ private def constructTwoIndeterminateStage(): (Int, Int) =
+ constructTwoStages(stage1InDeterminate = true, stage2InDeterminate = true)
- private def constructTwoIndeterminateStage(): (Int, Int) = {
- val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
+ private def constructTwoStages(
+ stage1InDeterminate: Boolean,
+ stage2InDeterminate: Boolean,
+ isDependencyBetweenStagesTransitive: Boolean = true): (Int, Int) = {
+ val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = stage1InDeterminate)
val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2))
val shuffleId1 = shuffleDep1.shuffleId
- val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker,
- indeterminate = true)
+ val shuffleMapRdd2 = if (isDependencyBetweenStagesTransitive) {
+ new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker,
+ indeterminate = stage2InDeterminate)
+ } else {
+ new MyRDD(sc, 2, Nil, tracker = mapOutputTracker, indeterminate = stage2InDeterminate)
+ }
val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2))
val shuffleId2 = shuffleDep2.shuffleId
- val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+
+ val finalRdd = if (isDependencyBetweenStagesTransitive) {
+ new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+ } else {
+ new MyRDD(sc, 2, List(shuffleDep1, shuffleDep2), tracker = mapOutputTracker)
+ }
submit(finalRdd, Array(0, 1))
+ val stageId1 = this.scheduler.shuffleIdToMapStage(shuffleId1).id
// Finish the first shuffle map stage.
- completeShuffleMapStageSuccessfully(0, 0, 2)
- assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
-
+ completeShuffleMapStageSuccessfully(stageId1, 0, 2)
+ import org.scalatest.concurrent.Eventually._
+ import org.scalatest.matchers.should.Matchers._
+ import org.scalatest.time.SpanSugar._
+ eventually(timeout(1.minutes), interval(500.milliseconds)) {
+ mapOutputTracker.findMissingPartitions(shuffleId1) should equal(Some(Nil))
+ }
(shuffleId1, shuffleId2)
}
@@ -3185,6 +3217,157 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
"Spark can only do this while using the new shuffle block fetching protocol"))
}
+
+ test("SPARK-51272: retry all the partitions of result stage, if the first result task" +
+ " has failed and failing ShuffleMap stage is inDeterminate") {
+ this.dagSchedulerInterceptor = createDagInterceptorForSpark51272(
+ () => taskSets.find(_.shuffleId.isEmpty).get.tasks(1), "RELEASE_LATCH")
+
+ val numPartitions = 2
+ // constructTwoStages below creates two shuffle map stages plus a result stage and completes
+ // the first shuffle map stage itself.
+ val (shuffleId1, shuffleId2) = constructTwoStages(
+ stage1InDeterminate = false,
+ stage2InDeterminate = true,
+ isDependencyBetweenStagesTransitive = false)
+ val shuffleStage1 = this.scheduler.shuffleIdToMapStage(shuffleId1)
+ val shuffleStage2 = this.scheduler.shuffleIdToMapStage(shuffleId2)
+ completeShuffleMapStageSuccessfully(shuffleStage2.id, 0, numPartitions)
+ val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]
+ val activeJob = resultStage.activeJob
+ assert(activeJob.isDefined)
+ // The result stage is still waiting for its 2 tasks to complete
+ assert(resultStage.findMissingPartitions() == Seq.tabulate(numPartitions)(i => i))
+
+ // The event below initiates the retry of the upstream shuffle stages and of all result
+ // tasks. But before the "ResubmitFailedStages" event is added to the scheduler's queue, a
+ // successful completion of one result task from the stale attempt is added to the event
+ // queue. In this scenario the bug surfaced: instead of retrying all partitions of the
+ // result stage (2 tasks in total), only some (1 task) would get retried.
+ runEvent(
+ makeCompletionEvent(
+ taskSets.find(_.stageId == resultStage.id).get.tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
+ null))
+
+ import org.scalatest.concurrent.Eventually._
+ import org.scalatest.matchers.should.Matchers._
+ import org.scalatest.time.SpanSugar._
+ eventually(timeout(3.minutes), interval(500.milliseconds)) {
+ shuffleStage1.latestInfo.attemptNumber() should equal(1)
+ }
+ completeShuffleMapStageSuccessfully(shuffleStage1.id, 1, numPartitions)
+
+ eventually(timeout(3.minutes), interval(500.milliseconds)) {
+ shuffleStage2.latestInfo.attemptNumber() should equal(1)
+ }
+ completeShuffleMapStageSuccessfully(shuffleStage2.id, 1, numPartitions)
+ eventually(timeout(3.minutes), interval(500.milliseconds)) {
+ resultStage.latestInfo.attemptNumber() should equal(1)
+ }
+ org.scalatest.Assertions.assert(resultStage.latestInfo.numTasks == numPartitions)
+ org.scalatest.Assertions.assert(resultStage.findMissingPartitions().size == numPartitions)
+ }
+
+ test("SPARK-51272: retry all the partitions of result stage, if the first result task" +
+ " has failed with failing ShuffleStage determinate but result stage has another ShuffleStage" +
+ " which is indeterminate") {
+ this.dagSchedulerInterceptor = createDagInterceptorForSpark51272(
+ () => taskSets.find(_.shuffleId.isEmpty).get.tasks(1), "RELEASE_LATCH")
+
+ val numPartitions = 2
+ // constructTwoStages below creates two shuffle map stages plus a result stage and completes
+ // the first shuffle map stage itself.
+ val (detShuffleId1, indetShuffleId2) = constructTwoStages(
+ stage1InDeterminate = false,
+ stage2InDeterminate = true,
+ isDependencyBetweenStagesTransitive = false)
+ val detShuffleStage1 = this.scheduler.shuffleIdToMapStage(detShuffleId1)
+ val inDetshuffleStage2 = this.scheduler.shuffleIdToMapStage(indetShuffleId2)
+ completeShuffleMapStageSuccessfully(inDetshuffleStage2.id, 0, numPartitions)
+ assert(mapOutputTracker.findMissingPartitions(indetShuffleId2) === Some(Seq.empty))
+ val resultStage = scheduler.stageIdToStage(2).asInstanceOf[ResultStage]
+ val activeJob = resultStage.activeJob
+ assert(activeJob.isDefined)
+ // The result stage is still waiting for its 2 tasks to complete
+ assert(resultStage.findMissingPartitions() == Seq.tabulate(numPartitions)(i => i))
+
+ // The event below causes the first task of the result stage to fail.
+ // Expected behaviour: since the result stage depends on two shuffles, one of which is
+ // indeterminate, the retry of the ResultStage must re-run both tasks, even though the
+ // failed shuffle stage is deterministic. At this point there is no guarantee whether the
+ // indeterminate shuffle stage 2 has also failed or not. If it had hypothetically failed for
+ // result partition 1 but succeeded for result partition 2, re-executing the indeterminate
+ // shuffle stage 2 would produce wrong results. To avoid this, once an indeterminate result
+ // stage is being retried, no successful completion from a stale attempt should be accepted.
+ runEvent(
+ makeCompletionEvent(
+ taskSets.find(_.shuffleId.isEmpty).get.tasks(0),
+ FetchFailed(makeBlockManagerId("hostA"), detShuffleId1, 0L, 0, 0, "ignored"),
+ null))
+
+ import org.scalatest.concurrent.Eventually._
+ import org.scalatest.matchers.should.Matchers._
+ import org.scalatest.time.SpanSugar._
+ eventually(timeout(3.minutes), interval(500.milliseconds)) {
+ detShuffleStage1.latestInfo.attemptNumber() should equal(1)
+ }
+ completeShuffleMapStageSuccessfully(detShuffleStage1.id, 1, numPartitions)
+
+ // Although inDetShuffleStage2 has not suffered any failure itself, the DAGScheduler removes
+ // shuffle outputs hosted on the lost BlockManager, which in this case also invalidates one
+ // partition of shuffle 2, so the stage will be re-attempted.
+ // That re-attempt must recompute all partitions!
+ eventually(timeout(3.minutes), interval(500.milliseconds)) {
+ inDetshuffleStage2.latestInfo.attemptNumber() should equal(1)
+ }
+ org.scalatest.Assertions.assert(inDetshuffleStage2.latestInfo.numTasks == 2)
+ org.scalatest.Assertions.assert(inDetshuffleStage2.findMissingPartitions().size == 2)
+ completeShuffleMapStageSuccessfully(inDetshuffleStage2.id, 1, numPartitions)
+ eventually(timeout(3.minutes), interval(500.milliseconds)) {
+ resultStage.latestInfo.attemptNumber() should equal(1)
+ }
+ org.scalatest.Assertions.assert(resultStage.latestInfo.numTasks == numPartitions)
+ }
+
+ test("SPARK-51272: retry all the partitions of Shuffle stage, if any task of ShuffleStage " +
+ " has failed and failing ShuffleMap stage is inDeterminate") {
+ val numPartitions = 2
+ this.dagSchedulerInterceptor = createDagInterceptorForSpark51272(
+ () => taskSets.filter(_.shuffleId.isDefined).maxBy(_.shuffleId.get).tasks(1),
+ makeMapStatus(host = "hostZZZ", reduces = numPartitions))
+ // constructTwoStages below creates a determinate and an indeterminate shuffle stage plus a
+ // result stage, and completes the first shuffle map stage itself.
+ val (shuffleId1, shuffleId2) = constructTwoStages(
+ stage1InDeterminate = false,
+ stage2InDeterminate = true,
+ isDependencyBetweenStagesTransitive = false
+ )
+ // This triggers resubmission of the failed stage; just before the resubmit message is added
+ // to the queue, a successful partition completion event arrives.
+ runEvent(
+ makeCompletionEvent(
+ taskSets.filter(_.shuffleId.isDefined).maxBy(_.shuffleId.get).tasks(0),
+ FetchFailed(makeBlockManagerId("hostB"), shuffleId2, 0L, 0, 0, "ignored"),
+ null))
+
+ val shuffleStage2 = scheduler.shuffleIdToMapStage(shuffleId2)
+ import org.scalatest.concurrent.Eventually._
+ import org.scalatest.matchers.should.Matchers._
+ import org.scalatest.time.SpanSugar._
+
+ eventually(timeout(30.seconds), interval(500.milliseconds)) {
+ shuffleStage2.latestInfo.attemptNumber() should equal(1)
+ }
+ org.scalatest.Assertions.assert(shuffleStage2.findMissingPartitions().size == numPartitions)
+ }
+
test("SPARK-25341: retry all the succeeding stages when the map stage is indeterminate") {
val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
@@ -3192,9 +3375,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
val failedStages = scheduler.failedStages.toSeq
assert(failedStages.map(_.id) == Seq(1, 2))
// Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry.
+ // As the ShuffleMapStage is indeterminate, all of its partitions need to be retried
assert(failedStages.collect {
case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
- }.head.findMissingPartitions() == Seq(0))
+ }.head.findMissingPartitions() == Seq(0, 1))
// The result stage is still waiting for its 2 tasks to complete
assert(failedStages.collect {
case stage: ResultStage => stage
@@ -4163,9 +4347,10 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
val failedStages = scheduler.failedStages.toSeq
assert(failedStages.map(_.id) == Seq(1, 2))
// Shuffle blocks of "hostC" is lost, so first task of the `shuffleMapRdd2` needs to retry.
+ // As the ShuffleMapStage is indeterminate, all of its partitions need to be retried
assert(failedStages.collect {
case stage: ShuffleMapStage if stage.shuffleDep.shuffleId == shuffleId2 => stage
- }.head.findMissingPartitions() == Seq(0))
+ }.head.findMissingPartitions() == Seq(0, 1))
// The result stage is still waiting for its 2 tasks to complete
assert(failedStages.collect {
case stage: ResultStage => stage
@@ -5135,6 +5320,54 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
}
CompletionEvent(task, reason, result, accumUpdates ++ extraAccumUpdates, metricPeaks, taskInfo)
}
+
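+ /**
+ * Builds an interceptor reproducing the SPARK-51272 race: just before ResubmitFailedStages is
+ * queued, a spurious Success for `latchReleaseTask` (a task of the stale attempt) is injected,
+ * and the thread that processed the FetchFailed is held on a latch until that spurious
+ * completion has been queued, so ResubmitFailedStages is always processed after it.
+ */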
+ private def createDagInterceptorForSpark51272(latchReleaseTask: () => Task[_], taskResult: Any):
+ DagSchedulerInterceptor = {
+ new DagSchedulerInterceptor {
+ val latch = new CountDownLatch(1)
+ override def beforeAddingDagEventToQueue(event: DAGSchedulerEvent): Unit = {
+ event match {
+ case ResubmitFailedStages =>
+ // Before the ResubmitFailedStages is added to the queue, add the successful
+ // partition task completion.
+ runEvent(makeCompletionEvent(latchReleaseTask(), Success, taskResult))
+
+ case _ =>
+ }
+ }
+
+ override def afterAddingDagEventToQueue(event: DAGSchedulerEvent): Unit = {
+ event match {
+ case CompletionEvent(_, reason, result, _, _, _) =>
+ reason match {
+ case Success if result == taskResult => latch.countDown()
+
+ case _ =>
+ }
+
+ case _ =>
+ }
+ }
+
+ override def afterDirectProcessingOfDagEvent(event: DAGSchedulerEvent): Unit = {
+ event match {
+ case CompletionEvent(_, reason, _, _, _, _) =>
+ reason match {
+ case FetchFailed(_, _, _, _, _, _) =>
+ // Do not let this thread exit until the spurious successful task (the latchRelease
+ // task) has been added to the queue. This ensures that the ResubmitFailedStages
+ // event is always processed after the spurious task completion.
+ latch.await(50, TimeUnit.SECONDS)
+
+ case _ =>
+ }
+
+ case _ =>
+ }
+ }
+ }
+ }
}
class DAGSchedulerAbortStageOffSuite extends DAGSchedulerSuite {
diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml
index 0629c24c56dc..5b825dbe103e 100644
--- a/resource-managers/yarn/pom.xml
+++ b/resource-managers/yarn/pom.xml
@@ -37,6 +37,12 @@
spark-core_${scala.binary.version}
${project.version}
+
+ org.apache.spark
+ spark-sql_${scala.binary.version}
+ ${project.version}
+ test
+
org.apache.spark
spark-network-yarn_${scala.binary.version}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 3bf6a6e84a88..7d1f4d7a989a 100644
--- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -167,7 +167,9 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers {
extraJars: Seq[String] = Nil,
extraConf: Map[String, String] = Map(),
extraEnv: Map[String, String] = Map(),
- outFile: Option[File] = None): SparkAppHandle.State = {
+ outFile: Option[File] = None,
+ testTimeOutParams: TimeoutParams = TimeoutParams.DEFAULT
+ ): SparkAppHandle.State = {
val deployMode = if (clientMode) "client" else "cluster"
val propsFile = createConfFile(extraClassPath = extraClassPath, extraConf = extraConf)
val env = Map(
@@ -181,10 +183,12 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers {
launcher.setMainClass(klass)
launcher.setAppResource(fakeSparkJar.getAbsolutePath())
}
+
+ val numExecsOpt = extraConf.get(EXECUTOR_INSTANCES.key)
launcher.setSparkHome(sys.props("spark.test.home"))
.setMaster("yarn")
.setDeployMode(deployMode)
- .setConf(EXECUTOR_INSTANCES.key, "1")
+ .setConf(EXECUTOR_INSTANCES.key, numExecsOpt.getOrElse("1"))
.setConf(SparkLauncher.DRIVER_DEFAULT_JAVA_OPTIONS,
s"-Djava.net.preferIPv6Addresses=${Utils.preferIPv6}")
.setPropertiesFile(propsFile)
@@ -210,7 +214,8 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers {
val handle = launcher.startApplication()
try {
- eventually(timeout(3.minutes), interval(1.second)) {
+ eventually(timeout(testTimeOutParams.testTimeOut),
+ interval(testTimeOutParams.timeOutIntervalCheck)) {
assert(handle.getState().isFinal())
}
} finally {
@@ -295,3 +300,9 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers {
}
}
+
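+// Timeout and polling interval used by runSpark when waiting for the launched application to
+// reach a final state; suites that expect many executor kills and stage retries (such as
+// SparkHASuite) pass values far larger than the 3-minute default.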
+case class TimeoutParams(testTimeOut: Duration, timeOutIntervalCheck: Duration)
+
+object TimeoutParams {
+ val DEFAULT = TimeoutParams(3.minutes, 1.seconds)
+}
diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/SparkHASuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/SparkHASuite.scala
new file mode 100644
index 000000000000..4081d48fead8
--- /dev/null
+++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/SparkHASuite.scala
@@ -0,0 +1,275 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.io.File
+import java.nio.charset.StandardCharsets
+
+import scala.concurrent.duration._
+
+import com.google.common.io.Files
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.{SparkContext, SparkException}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.{DRIVER_MEMORY, EXECUTOR_CORES, EXECUTOR_INSTANCES, EXECUTOR_MEMORY}
+import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart, SparkListenerStageCompleted, SparkListenerStageSubmitted}
+import org.apache.spark.sql.{DataFrame, Encoders, SparkSession}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.IntegerType
+import org.apache.spark.tags.ExtendedYarnTest
+
+
+@ExtendedYarnTest
+class SparkHASuite extends BaseYarnClusterSuite {
+ override def newYarnConfig(): YarnConfiguration = new YarnConfiguration()
+ test("bug SPARK-51016 and SPARK-51272: Indeterminate stage retry giving wrong results") {
+ testBasicYarnApp(
+ Map(
+ SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "2",
+ "spark.task.maxFailures" -> "8",
+ "spark.network.timeout" -> "100000s",
+ "spark.shuffle.sort.bypassMergeThreshold" -> "1",
+ "spark.sql.files.maxPartitionNum" -> "2",
+ "spark.sql.files.minPartitionNum" -> "2",
+ DRIVER_MEMORY.key -> "512m",
+ EXECUTOR_CORES.key -> "1",
+ EXECUTOR_MEMORY.key -> "512m",
+ EXECUTOR_INSTANCES.key -> "2",
+ "spark.ui.enabled" -> "false",
+ "spark.yarn.max.executor.failures" -> "100000"
+ ))
+ }
+
+ private def testBasicYarnApp(conf: Map[String, String] = Map()): Unit = {
+ val result = File.createTempFile("result", null, tempDir)
+ val finalState = runSpark(
+ clientMode = true,
+ mainClassName(SparkHASuite.getClass),
+ appArgs = Seq(result.getAbsolutePath),
+ extraConf = conf,
+ testTimeOutParams = TimeoutParams(30.minutes, 30.seconds))
+ checkResult(finalState, result)
+ }
+}
+
+private object SparkHASuite extends Logging {
+
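+ // Mutable state backing the non-deterministic UDF below: every sixth call bumps the value
+ // that is handed out, so re-executing the same rows on a task retry can produce different
+ // join keys. This is the indeterminacy the suite exercises across stage retries.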
+ object Counter {
+ var counter = 0
+ var retVal = 12
+
+ def getHash(): Int = this.synchronized {
+ counter += 1
+ val x = retVal
+ if (counter % 6 == 0) {
+ retVal += 1
+ }
+ x
+ }
+ }
+
+ private def getOuterJoinDF(spark: SparkSession) = {
+ import org.apache.spark.sql.functions.udf
+ val myudf = udf(() => Counter.getHash()).asNondeterministic()
+ spark.udf.register("myudf", myudf.asNondeterministic())
+
+ val leftOuter = spark.table("outer").select(
+ col("strleft"), when(isnull(col("pkLeftt")), myudf().
+ cast(IntegerType)).
+ otherwise(col("pkLeftt")).as("pkLeft"))
+
+ val innerRight = spark.table("inner")
+
+ val outerjoin = leftOuter.hint("SHUFFLE_HASH").
+ join(innerRight, col("pkLeft") === col("pkRight"), "left_outer")
+ outerjoin
+ }
+
+ def createBaseTables(spark: SparkSession): Unit = {
+ spark.sql("drop table if exists outer ")
+ spark.sql("drop table if exists inner ")
+ val data = Seq(
+ (java.lang.Integer.valueOf(0), "aa"),
+ (java.lang.Integer.valueOf(1), "aa"),
+ (java.lang.Integer.valueOf(1), "aa"),
+ (java.lang.Integer.valueOf(0), "aa"),
+ (java.lang.Integer.valueOf(0), "aa"),
+ (java.lang.Integer.valueOf(0), "aa"),
+ (null, "bb"),
+ (null, "bb"),
+ (null, "bb"),
+ (null, "bb"),
+ (null, "bb"),
+ (null, "bb")
+ )
+ val data1 = Seq(
+ (java.lang.Integer.valueOf(0), "bb"),
+ (java.lang.Integer.valueOf(1), "bb"))
+ val outerDf = spark.createDataset(data)(
+ Encoders.tuple(Encoders.INT, Encoders.STRING)).toDF("pkLeftt", "strleft")
+ this.logInfo("saving outer table")
+ outerDf.write.format("parquet").partitionBy("strleft").saveAsTable("outer")
+ val innerDf = spark.createDataset(data1)(
+ Encoders.tuple(Encoders.INT, Encoders.STRING)).toDF("pkRight", "strright")
+ this.logInfo("saving inner table")
+ innerDf.write.format("parquet").partitionBy("strright").saveAsTable("inner")
+ }
+
+ def main(args: Array[String]): Unit = {
+ val spark = SparkSession
+ .builder()
+ .appName("Spark51016Suite")
+ .config("spark.extraListeners", classOf[JobListener].getName)
+ .getOrCreate()
+ val sc = SparkContext.getOrCreate()
+
+ val status = new File(args(0))
+ var result = "failure"
+ try {
+ createBaseTables(spark)
+ val outerjoin: DataFrame = getOuterJoinDF(spark)
+ val correctRows = outerjoin.collect()
+ JobListener.inKillMode = true
+ JobListener.killWhen = KillPosition.KILL_IN_STAGE_SUBMISSION
+ for (i <- 0 until 100) {
+ if (i > 49) {
+ JobListener.killWhen = KillPosition.KILL_IN_STAGE_COMPLETION
+ }
+ try {
+ eventually(timeout(3.minutes), interval(100.milliseconds)) {
+ assert(sc.getExecutorIds().size == 2)
+ }
+ val rowsAfterRetry = getOuterJoinDF(spark).collect()
+ if (correctRows.length != rowsAfterRetry.length) {
+ logInfo(s"encounterted test failure incorrect query result. run index = $i ")
+ }
+ assert(correctRows.length == rowsAfterRetry.length,
+ s"correct rows length = ${correctRows.length}," +
+ s" retry rows length = ${rowsAfterRetry.length}")
+ val retriedResults = rowsAfterRetry.toBuffer
+ correctRows.foreach(r => {
+ val index = retriedResults.indexWhere(x =>
+ r.getString(0) == x.getString(0) &&
+ (
+ (r.getInt(1) < 2 && r.getInt(1) == x.getInt(1) && r.getInt(2) == x.getInt(2) &&
+ r.getString(3) == x.getString(3))
+ ||
+ (r.isNullAt(2) && r.isNullAt(3) && x.isNullAt(3) && x.isNullAt(2))
+ ))
+ assert(index >= 0)
+ retriedResults.remove(index)
+ })
+ assert(retriedResults.isEmpty)
+ logInfo(s"found successful query exec on iter index = $i")
+ } catch {
+ case se: SparkException if se.getMessage.contains("Please eliminate the" +
+ " indeterminacy by checkpointing the RDD before repartition and try again") =>
+ logInfo(s"correctly encountered exception on iter index = $i")
+ // OK expected
+ }
+ }
+ result = "success"
+ } finally {
+ Files.asCharSink(status, StandardCharsets.UTF_8).write(result)
+ sc.stop()
+ }
+ }
+}
+
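+// Locates the two executor JVMs of this test application by scraping `ps -ef` output and
+// allows killing one of them by pid, simulating an executor loss.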
+object PIDGetter extends Logging {
+ def getExecutorPIds: Seq[Int] = {
+ import scala.sys.process._
+ val output = Seq("ps", "-ef").#|(Seq("grep", "java")).#|(Seq("grep", "executor-id ")).lazyLines
+ logInfo(s"pids obtained = ${output.mkString("\n")} ")
+ if (output.nonEmpty && output.size == 4) {
+ val execPidsStr = output.map(_.trim).filter(_.endsWith("--resourceProfileId 0"))
+ logInfo(s"filtered Pid String obtained = ${execPidsStr.mkString("\n")} ")
+ val pids = execPidsStr.map(str => str.split(" ")(1).toInt).sorted
+ Seq(pids.head, pids(1))
+ } else {
+ Seq.empty
+ }
+ }
+
+ def killExecutor(pid: Int): Unit = {
+ import scala.sys.process._
+ Seq("kill", "-9", pid.toString).!
+
+ }
+
+ def main(args: Array[String]): Unit = {
+ getExecutorPIds
+ }
+}
+
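+/**
+ * Listener that, when enabled, picks one of the two executors at every job start (alternating
+ * between them) and kills it either when the result stage is submitted or when a shuffle map
+ * stage completes, forcing fetch failures and stage retries in the middle of the query.
+ */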
+private[spark] class JobListener extends SparkListener with Logging {
+ private var count: Int = 0
+ @volatile
+ private var pidToKill: Option[Int] = None
+
+ override def onJobStart(jobStart: SparkListenerJobStart): Unit = {
+ if (JobListener.inKillMode) {
+ val execids = PIDGetter.getExecutorPIds
+ assert(execids.size == 2)
+ pidToKill = Option(execids(count % 2))
+ logInfo("Pid to kill = " + pidToKill)
+ count += 1
+ }
+ }
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ if (stageSubmitted.stageInfo.shuffleDepId.isEmpty && pidToKill.nonEmpty &&
+ JobListener.killWhen == KillPosition.KILL_IN_STAGE_SUBMISSION) {
+ val pid = pidToKill.get
+ pidToKill = None
+ logInfo(s"killing executor for pid = $pid")
+ PIDGetter.killExecutor(pid)
+ }
+ }
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ if (stageCompleted.stageInfo.shuffleDepId.exists(_ % 2 == count % 2) && pidToKill.nonEmpty &&
+ JobListener.killWhen == KillPosition.KILL_IN_STAGE_COMPLETION) {
+ val pid = pidToKill.get
+ pidToKill = None
+ logInfo(s"killing executor for pid = $pid")
+ PIDGetter.killExecutor(pid)
+ }
+ }
+}
+
+object KillPosition extends Enumeration {
+ type KillPosition = Value
+ val KILL_IN_STAGE_SUBMISSION, KILL_IN_STAGE_COMPLETION, NONE = Value
+}
+
+object JobListener {
+ @volatile
+ var inKillMode: Boolean = false
+
+ import KillPosition._
+ @volatile
+ var killWhen: KillPosition = NONE
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 7d4f8c3b2564..3a5beb2362af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -115,6 +115,23 @@ abstract class Expression extends TreeNode[Expression] {
*/
lazy val deterministic: Boolean = children.forall(_.deterministic)
+ /**
+ * The information conveyed by [[hasIndeterminism]] differs from [[deterministic]]: an
+ * [[Attribute]] that stands for an [[Expression]] whose [[deterministic]] flag is false has
+ * its own [[deterministic]] flag set to true (leaf expressions such as [[AttributeReference]]
+ * are always deterministic), yet it still reports [[hasIndeterminism]] as true, because the
+ * value it refers to was produced indeterministically. In other words, while [[deterministic]]
+ * describes the evaluation of this expression itself, [[hasIndeterminism]] describes the
+ * nature of the value the expression represents. For example, the attribute produced by
+ * `Alias(MonotonicallyIncreasingID(), "id")().toAttribute` is deterministic but has
+ * indeterminism.
+ * @return true if the expression's evaluated value depends on some indeterministic quantity.
+ */
+ def hasIndeterminism: Boolean = _hasIndeterminism
+
+ @transient
+ private lazy val _hasIndeterminism: Boolean = !deterministic ||
+ this.references.exists(_.hasIndeterminism)
+
def nullable: Boolean
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
index bb67c173b946..ea092093d095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ProjectionOverSchema.scala
@@ -38,7 +38,7 @@ case class ProjectionOverSchema(schema: StructType, output: AttributeSet) {
private def getProjection(expr: Expression): Option[Expression] =
expr match {
case a: AttributeReference if fieldNames.contains(a.name) && output.contains(a) =>
- Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier))
+ Some(a.copy(dataType = schema(a.name).dataType)(a.exprId, a.qualifier, a.hasIndeterminism))
case GetArrayItem(child, arrayItemOrdinal, failOnError) =>
getProjection(child).map {
projection => GetArrayItem(projection, arrayItemOrdinal, failOnError)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 2af6a1ba84ec..f6b8a96dca97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -111,6 +111,7 @@ abstract class Attribute extends LeafExpression with NamedExpression {
@transient
override lazy val references: AttributeSet = AttributeSet(this)
+ override def hasIndeterminism: Boolean = false
def withNullability(newNullability: Boolean): Attribute
def withQualifier(newQualifier: Seq[String]): Attribute
def withName(newName: String): Attribute
@@ -194,7 +195,8 @@ case class Alias(child: Expression, name: String)(
override def toAttribute: Attribute = {
if (resolved) {
- AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier)
+ AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifier,
+ this.hasIndeterminism)
} else {
UnresolvedAttribute.quoted(name)
}
@@ -274,7 +276,8 @@ case class AttributeReference(
nullable: Boolean = true,
override val metadata: Metadata = Metadata.empty)(
val exprId: ExprId = NamedExpression.newExprId,
- val qualifier: Seq[String] = Seq.empty[String])
+ val qualifier: Seq[String] = Seq.empty[String],
+ override val hasIndeterminism: Boolean = false)
extends Attribute with Unevaluable {
override lazy val treePatternBits: BitSet = AttributeReferenceTreeBits.bits
@@ -312,7 +315,8 @@ case class AttributeReference(
}
override def newInstance(): AttributeReference =
- AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier)
+ AttributeReference(name, dataType, nullable, metadata)(qualifier = qualifier,
+ hasIndeterminism = hasIndeterminism)
/**
* Returns a copy of this [[AttributeReference]] with changed nullability.
@@ -321,7 +325,8 @@ case class AttributeReference(
if (nullable == newNullability) {
this
} else {
- AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier)
+ AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifier,
+ hasIndeterminism)
}
}
@@ -329,7 +334,7 @@ case class AttributeReference(
if (name == newName) {
this
} else {
- AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier)
+ AttributeReference(newName, dataType, nullable, metadata)(exprId, qualifier, hasIndeterminism)
}
}
@@ -340,7 +345,7 @@ case class AttributeReference(
if (newQualifier == qualifier) {
this
} else {
- AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier)
+ AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifier, hasIndeterminism)
}
}
@@ -348,12 +353,12 @@ case class AttributeReference(
if (exprId == newExprId) {
this
} else {
- AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier)
+ AttributeReference(name, dataType, nullable, metadata)(newExprId, qualifier, hasIndeterminism)
}
}
override def withMetadata(newMetadata: Metadata): AttributeReference = {
- AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
+ AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier, hasIndeterminism)
}
override def withDataType(newType: DataType): AttributeReference = {
@@ -361,7 +366,7 @@ case class AttributeReference(
}
override protected final def otherCopyArgs: Seq[AnyRef] = {
- exprId :: qualifier :: Nil
+ exprId :: qualifier :: Boolean.box(hasIndeterminism) :: Nil
}
/** Used to signal the column used to calculate an eventTime watermark (e.g. a#1-T{delayMs}) */
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 184f5a2a9485..d8da04ef6456 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -76,7 +76,7 @@ trait ExpressionEvalHelper extends ScalaCheckDrivenPropertyChecks with PlanTestB
case _ => expr.mapChildren(replace)
}
- private def prepareEvaluation(expression: Expression): Expression = {
+ private[expressions] def prepareEvaluation(expression: Expression): Expression = {
val serializer = new JavaSerializer(new SparkConf()).newInstance()
val resolver = ResolveTimeZone
val expr = replace(resolver.resolveTimeZones(expression))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
index bf1c930c0bd0..29f600347ea0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NondeterministicSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, KeyGroupedPartitioning, RangePartitioning}
class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
test("MonotonicallyIncreasingID") {
@@ -31,4 +32,45 @@ class NondeterministicSuite extends SparkFunSuite with ExpressionEvalHelper {
test("InputFileName") {
checkEvaluation(InputFileName(), "")
}
+
+ test("SPARK-51016: has Indeterministic Component") {
+ def assertIndeterminancyComponent(expression: Expression): Unit =
+ assert(prepareEvaluation(expression).hasIndeterminism)
+
+ assertIndeterminancyComponent(MonotonicallyIncreasingID())
+ val alias = Alias(Multiply(MonotonicallyIncreasingID(), Literal(100L)), "al1")()
+ assertIndeterminancyComponent(alias)
+ // The attribute created from an Alias over a non-deterministic child carries that
+ // information forward from the Alias: its hasIndeterminism flag is true.
+ assertIndeterminancyComponent(alias.toAttribute)
+ // But the attribute's own deterministic flag is true, i.e. it does not itself re-evaluate
+ // the indeterministic quantity it refers to.
+ assert(prepareEvaluation(alias.toAttribute).deterministic)
+
+ assertIndeterminancyComponent(Multiply(alias.toAttribute, Literal(1000L)))
+ assertIndeterminancyComponent(
+ HashPartitioning(Seq(Multiply(MonotonicallyIncreasingID(), Literal(100L))), 5))
+ assertIndeterminancyComponent(HashPartitioning(Seq(alias.toAttribute), 5))
+ assertIndeterminancyComponent(
+ RangePartitioning(Seq(SortOrder.apply(alias.toAttribute, Descending)), 5))
+ assertIndeterminancyComponent(KeyGroupedPartitioning(Seq(alias.toAttribute), 5))
+ }
+
+ test("SPARK-51016: has Deterministic Component") {
+ def assertNoIndeterminancyComponent(expression: Expression): Unit =
+ assert(!prepareEvaluation(expression).hasIndeterminism)
+
+ assertNoIndeterminancyComponent(Literal(1000L))
+ val alias = Alias(Multiply(Literal(10000L), Literal(100L)), "al1")()
+ assertNoIndeterminancyComponent(alias)
+ assertNoIndeterminancyComponent(alias.toAttribute)
+ assertNoIndeterminancyComponent(
+ HashPartitioning(Seq(Multiply(Literal(10L), Literal(100L))), 5))
+ assertNoIndeterminancyComponent(HashPartitioning(Seq(alias.toAttribute), 5))
+ assertNoIndeterminancyComponent(
+ RangePartitioning(Seq(SortOrder.apply(alias.toAttribute, Descending)), 5))
+ assertNoIndeterminancyComponent(KeyGroupedPartitioning(Seq(alias.toAttribute), 5))
+ }
+
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index f06e6ed137cc..062a27075185 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -69,7 +69,8 @@ trait PlanTestBase extends PredicateHelper with SQLHelper { self: Suite =>
protected def rewriteNameFromAttrNullability(plan: LogicalPlan): LogicalPlan = {
plan.transformAllExpressions {
case a @ AttributeReference(name, _, false, _) =>
- a.copy(name = s"*$name")(exprId = a.exprId, qualifier = a.qualifier)
+ a.copy(name = s"*$name")(exprId = a.exprId, qualifier = a.qualifier,
+ hasIndeterminism = a.hasIndeterminism)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 31a3f53eb719..2b68b2feb07e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -22,15 +22,16 @@ import java.util.function.Supplier
import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future, Promise}
+import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.internal.config
-import org.apache.spark.rdd.{RDD, RDDOperationScope}
+import org.apache.spark.rdd.{DeterministicLevel, MapPartitionsRDD, RDD, RDDOperationScope}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleWriteMetricsReporter, ShuffleWriteProcessor}
import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.logical.Statistics
@@ -43,6 +44,7 @@ import org.apache.spark.util.{MutablePair, ThreadUtils}
import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, RecordComparator}
import org.apache.spark.util.random.XORShiftRandom
+
/**
* Common trait for all shuffle exchange implementations to facilitate pattern matching.
*/
@@ -452,19 +454,34 @@ object ShuffleExchangeExec {
rdd
}
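+ // A partitioning whose expression references an indeterministic value (e.g. a join key
+ // derived from rand()) makes the shuffled data indeterminate across retries.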
+ val isIndeterministic = newPartitioning match {
+ case expr: Expression => expr.hasIndeterminism
+ case _ => false
+ }
+
// round-robin function is order sensitive if we don't sort the input.
val isOrderSensitive = isRoundRobin && !SQLConf.get.sortBeforeRepartition
if (needToCopyObjectsBeforeShuffle(part)) {
- newRdd.mapPartitionsWithIndexInternal((_, iter) => {
- val getPartitionKey = getPartitionKeyExtractor()
- iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
- }, isOrderSensitive = isOrderSensitive)
+ this.createRddWithPartition(
+ newRdd,
+ (_, iter: Iterator[InternalRow]) => {
+ val getPartitionKey = getPartitionKeyExtractor()
+ iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
+ },
+ isIndeterministic,
+ isOrderSensitive
+ )
+
} else {
- newRdd.mapPartitionsWithIndexInternal((_, iter) => {
- val getPartitionKey = getPartitionKeyExtractor()
- val mutablePair = new MutablePair[Int, InternalRow]()
- iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
- }, isOrderSensitive = isOrderSensitive)
+ this.createRddWithPartition(
+ newRdd,
+ (_, iter: Iterator[InternalRow]) => {
+ val getPartitionKey = getPartitionKeyExtractor()
+ val mutablePair = new MutablePair[Int, InternalRow]()
+ iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }
+ },
+ isIndeterministic,
+ isOrderSensitive)
}
}
@@ -481,6 +498,25 @@ object ShuffleExchangeExec {
dependency
}
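+ /**
+ * Wraps the partitioning function into an RDD. When the partitioning is indeterministic, the
+ * returned MapPartitionsRDD forces its output deterministic level to INDETERMINATE so that,
+ * on a fetch failure, the DAGScheduler rolls back and recomputes all partitions of the stage
+ * rather than only the missing ones.
+ */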
+ private def createRddWithPartition[T: ClassTag, U: ClassTag](
+ rdd: RDD[T],
+ f: (Int, Iterator[T]) => Iterator[U],
+ isInDeterminate: Boolean,
+ isOrderSensitive: Boolean): RDD[U] = if (isInDeterminate) {
+ RDDOperationScope.withScope(rdd.sparkContext) {
+ new MapPartitionsRDD(
+ rdd,
+ (_: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter),
+ isOrderSensitive = isOrderSensitive) {
+ override protected def getOutputDeterministicLevel =
+ DeterministicLevel.INDETERMINATE
+ }
+ }
+ } else {
+ rdd.mapPartitionsWithIndexInternal(f, isOrderSensitive = isOrderSensitive)
+ }
+
+
/**
* Create a customized [[ShuffleWriteProcessor]] for SQL which wrap the default metrics reporter
* with [[SQLShuffleWriteMetricsReporter]] as new reporter for [[ShuffleWriteProcessor]].
diff --git a/sql/core/src/test/scala/org/apache/spark/scheduler/ShuffleMapStageTest.scala b/sql/core/src/test/scala/org/apache/spark/scheduler/ShuffleMapStageTest.scala
new file mode 100644
index 000000000000..4733ae38d918
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/scheduler/ShuffleMapStageTest.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.sql.Encoders
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.LongType
+
+class ShuffleMapStageTest extends SharedSparkSession {
+
+ test("SPARK-51016: ShuffleMapStage using indeterministic join keys should be INDETERMINATE") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ val leftDfBase = spark.createDataset(
+ Seq((1L, "aa")))(
+ Encoders.tuple(Encoders.scalaLong, Encoders.STRING)).toDF("pkLeftt", "strleft")
+
+ val rightDf = spark.createDataset(
+ Seq((1L, "11"), (2L, "22")))(
+ Encoders.tuple(Encoders.scalaLong, Encoders.STRING)).toDF("pkRight", "strright")
+
+ val leftDf = leftDfBase.select(
+ col("strleft"), when(isnull(col("pkLeftt")), floor(rand() * Literal(10000000L)).
+ cast(LongType)).
+ otherwise(col("pkLeftt")).as("pkLeft"))
+
+ val join = leftDf.hint("shuffle_hash").
+ join(rightDf, col("pkLeft") === col("pkRight"), "inner")
+ val shuffleStages: Array[ShuffleMapStage] = Array.ofDim(2)
+ spark.sparkContext.addSparkListener(new SparkListener() {
+ var i = 0
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ if (stageSubmitted.stageInfo.shuffleDepId.isDefined) {
+ shuffleStages(i) =
+ spark.sparkContext.dagScheduler.shuffleIdToMapStage(stageSubmitted.stageInfo.stageId)
+ i += 1
+ }
+ }
+ });
+ join.collect()
+ assert(shuffleStages.filter(_.isIndeterminate).size == 1)
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
index ec13d48d45f8..cf339559a169 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ProjectedOrderingAndPartitioningSuite.scala
@@ -17,14 +17,17 @@
package org.apache.spark.sql.execution
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{DeterministicLevel, RDD}
+import org.apache.spark.sql.Encoders
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, PartitioningCollection, UnknownPartitioning}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
+import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec
+import org.apache.spark.sql.functions.{col, floor, isnull, rand, when}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{LongType, StringType}
class ProjectedOrderingAndPartitioningSuite
extends SharedSparkSession with AdaptiveSparkPlanHelper {
@@ -210,6 +213,37 @@ class ProjectedOrderingAndPartitioningSuite
assert(outputOrdering.head.child.asInstanceOf[Attribute].name == "a")
assert(outputOrdering.head.sameOrderExpressions.size == 0)
}
+
+ test("SPARK-51016: ShuffleRDD using indeterministic join keys should be INDETERMINATE") {
+ withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") {
+ val leftDfBase = spark.createDataset(
+ Seq((1L, "aa"), (null, "aa"), (2L, "bb"), (null, "bb"), (3L, "cc"), (null, "cc")))(
+ Encoders.tupleEncoder(Encoders.LONG, Encoders.STRING)).toDF("pkLeftt", "strleft")
+
+ val rightDf = spark.createDataset(
+ Seq((1L, "11"), (2L, "22"), (3L, "33")))(
+ Encoders.tupleEncoder(Encoders.LONG, Encoders.STRING)).toDF("pkRight", "strright")
+
+ val leftDf = leftDfBase.select(
+ col("strleft"), when(isnull(col("pkLeftt")), floor(rand() * Literal(10000000L)).
+ cast(LongType)).
+ otherwise(col("pkLeftt")).as("pkLeft"))
+
+ val join = leftDf.hint("shuffle_hash").
+ join(rightDf, col("pkLeft") === col("pkRight"), "inner")
+
+ join.collect()
+ val finalPlan = join.queryExecution.executedPlan
+ val shuffleHJExec = finalPlan.children(0).asInstanceOf[ShuffledHashJoinExec]
+ assert(shuffleHJExec.left.asInstanceOf[InputAdapter].execute().outputDeterministicLevel ==
+ DeterministicLevel.INDETERMINATE)
+
+ assert(shuffleHJExec.right.asInstanceOf[InputAdapter].execute().outputDeterministicLevel ==
+ DeterministicLevel.UNORDERED)
+
+ assert(shuffleHJExec.execute().outputDeterministicLevel == DeterministicLevel.INDETERMINATE)
+ }
+ }
}
private case class DummyLeafPlanExec(output: Seq[Attribute]) extends LeafExecNode {