diff --git a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala index 37b08980db877..a8249e123fa00 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SchedulerIntegrationSuite.scala @@ -553,10 +553,10 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor */ testScheduler("multi-stage job") { - def stageToOutputParts(stageId: Int): Int = { - stageId match { + def shuffleIdToOutputParts(shuffleId: Int): Int = { + shuffleId match { case 0 => 10 - case 2 => 20 + case 1 => 20 case _ => 30 } } @@ -577,11 +577,12 @@ class BasicSchedulerIntegrationSuite extends SchedulerIntegrationSuite[SingleCor // b/c the stage numbering is non-deterministic, so stage number alone doesn't tell // us what to check } - (task.stageId, task.stageAttemptId, task.partitionId) match { case (stage, 0, _) if stage < 4 => + val shuffleId = + scheduler.stageIdToStage(stage).asInstanceOf[ShuffleMapStage].shuffleDep.shuffleId backend.taskSuccess(taskDescription, - DAGSchedulerSuite.makeMapStatus("hostA", stageToOutputParts(stage))) + DAGSchedulerSuite.makeMapStatus("hostA", shuffleIdToOutputParts(shuffleId))) case (4, 0, partition) => backend.taskSuccess(taskDescription, 4321 + partition) }