
Commit a6fc300

advancedxy authored and cloud-fan committed
[SPARK-22897][CORE] Expose stageAttemptId in TaskContext
## What changes were proposed in this pull request?

stageAttemptId is added to TaskContext, with the corresponding constructor modifications.

## How was this patch tested?

Added a new test in TaskContextSuite; two cases are tested:

1. Normal case without failure
2. Exception case with resubmitted stages

Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897)

Author: Xianjin YE <[email protected]>

Closes #20082 from advancedxy/SPARK-22897.
1 parent e0c090f commit a6fc300
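
For orientation, here is a hedged sketch of how user code can read the newly exposed field once this commit is in place. Only `TaskContext.get().stageAttemptNumber()` comes from this patch; the job itself (local master string, toy dataset, object name) is illustrative:

```scala
import org.apache.spark.{SparkContext, TaskContext}

object StageAttemptNumberExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[2]", "stage-attempt-example")
    try {
      val info = sc.parallelize(1 to 4, 2).mapPartitions { _ =>
        val tc = TaskContext.get()
        // stageAttemptNumber() counts retries of the whole stage (0 on the
        // first run); attemptNumber() counts retries of this single task.
        Iterator.single((tc.stageId(), tc.stageAttemptNumber(), tc.attemptNumber()))
      }.collect()
      info.foreach(println)
    } finally {
      sc.stop()
    }
  }
}
```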

File tree

13 files changed (+54, -11 lines)


core/src/main/scala/org/apache/spark/TaskContext.scala

Lines changed: 8 additions & 1 deletion

@@ -66,7 +66,7 @@ object TaskContext {
    * An empty task context that does not represent an actual task. This is only used in tests.
    */
   private[spark] def empty(): TaskContextImpl = {
-    new TaskContextImpl(0, 0, 0, 0, null, new Properties, null)
+    new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
   }
 }
 
@@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable {
    */
   def stageId(): Int
 
+  /**
+   * How many times the stage that this task belongs to has been attempted. The first stage attempt
+   * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt
+   * numbers.
+   */
+  def stageAttemptNumber(): Int
+
   /**
    * The ID of the RDD partition that is computed by this task.
    */

core/src/main/scala/org/apache/spark/TaskContextImpl.scala

Lines changed: 3 additions & 2 deletions

@@ -41,8 +41,9 @@ import org.apache.spark.util._
  * `TaskMetrics` & `MetricsSystem` objects are not thread safe.
  */
 private[spark] class TaskContextImpl(
-    val stageId: Int,
-    val partitionId: Int,
+    override val stageId: Int,
+    override val stageAttemptNumber: Int,
+    override val partitionId: Int,
     override val taskAttemptId: Long,
     override val attemptNumber: Int,
     override val taskMemoryManager: TaskMemoryManager,
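
`TaskContextImpl` is `private[spark]`, so only Spark internals and tests construct it. As a hedged sketch of the new positional order, mirroring the patched `TaskContext.empty()` above (same placeholder zeros and nulls; compiles only inside the `org.apache.spark` package tree):

```scala
import java.util.Properties

// Positional arguments after this commit: stageId, stageAttemptNumber,
// partitionId, taskAttemptId, attemptNumber, then the task memory manager,
// local properties, and metrics system (nulled out here, as in the tests).
val ctx = new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null)
```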

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 1 addition & 0 deletions

@@ -79,6 +79,7 @@ private[spark] abstract class Task[T](
     SparkEnv.get.blockManager.registerTask(taskAttemptId)
     context = new TaskContextImpl(
       stageId,
+      stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
       partitionId,
       taskAttemptId,
       attemptNumber,

core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java

Lines changed: 2 additions & 0 deletions

@@ -38,6 +38,7 @@ public static void test() {
     tc.attemptNumber();
     tc.partitionId();
     tc.stageId();
+    tc.stageAttemptNumber();
     tc.taskAttemptId();
   }
 
@@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) {
     context.isCompleted();
     context.isInterrupted();
     context.stageId();
+    context.stageAttemptNumber();
     context.partitionId();
     context.addTaskCompletionListener(this);
   }
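
The Java file above is only a compile-time check that the new getter is callable from Java. As a usage illustration, here is a minimal Scala completion listener reading the same fields; the class name and log format are invented for the example:

```scala
import org.apache.spark.TaskContext
import org.apache.spark.util.TaskCompletionListener

// Logs which stage attempt each task finished under.
class StageAttemptLogger extends TaskCompletionListener {
  override def onTaskCompletion(context: TaskContext): Unit = {
    println(s"task ${context.taskAttemptId()} done: stage ${context.stageId()}, " +
      s"stage attempt ${context.stageAttemptNumber()}")
  }
}

// Registered from inside a running task:
//   TaskContext.get().addTaskCompletionListener(new StageAttemptLogger)
```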

core/src/test/scala/org/apache/spark/ShuffleSuite.scala

Lines changed: 3 additions & 3 deletions

@@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data1 = (1 to 10).map { x => x -> x}
 
     // second attempt -- also successful. We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
     val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
     val data2 = (11 to 20).map { x => x -> x}
 
     // interleave writes of both attempts -- we want to test that both attempts can occur
@@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
-      new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
+      new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
     val readData = reader.read().toIndexedSeq
     assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)

core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala

Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@ object MemoryTestingUtils {
     val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0)
     new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = 0,
       attemptNumber = 0,

core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala

Lines changed: 27 additions & 2 deletions

@@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager
 import org.apache.spark.metrics.source.JvmSource
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.rdd.RDD
+import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.util._
 
 class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
@@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
     assert(attemptIdsWithFailedTask.toSet === Set(0, 1))
   }
 
+  test("TaskContext.stageAttemptNumber getter") {
+    sc = new SparkContext("local[1,2]", "test")
+
+    // Check stageAttemptNumbers are 0 for initial stage
+    val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ =>
+      Seq(TaskContext.get().stageAttemptNumber()).iterator
+    }.collect()
+    assert(stageAttemptNumbers.toSet === Set(0))
+
+    // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException
+    val stageAttemptNumbersWithFailedStage =
+      sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ =>
+        val stageAttemptNumber = TaskContext.get().stageAttemptNumber()
+        if (stageAttemptNumber < 2) {
+          // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception
+          // will only trigger task resubmission in the same stage.
+          throw new FetchFailedException(null, 0, 0, 0, "Fake")
+        }
+        Seq(stageAttemptNumber).iterator
+      }.collect()
+
+    assert(stageAttemptNumbersWithFailedStage.toSet === Set(2))
+  }
+
   test("accumulators are updated on exception failures") {
     // This means use 1 core and 4 max task failures
     sc = new SparkContext("local[1,4]", "test")
@@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.empty
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
@@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext {
     // accumulator updates from it.
     val taskMetrics = TaskMetrics.registered
     val task = new Task[Int](0, 0, 0) {
-      context = new TaskContextImpl(0, 0, 0L, 0,
+      context = new TaskContextImpl(0, 0, 0, 0L, 0,
         new TaskMemoryManager(SparkEnv.get.memoryManager, 0L),
         new Properties,
         SparkEnv.get.metricsSystem,
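
The new test throws `FetchFailedException` precisely because an ordinary exception only retries the failing task within the same stage attempt. A hedged companion sketch of that contrast, modeled on the suite's existing attempt-number test (master string and data are illustrative, not part of the patch):

```scala
import org.apache.spark.{SparkContext, TaskContext}

val sc = new SparkContext("local[1,4]", "retry-contrast")  // 1 core, 4 task failures allowed

// An ordinary exception is retried within the same stage, so attemptNumber()
// climbs while stageAttemptNumber() stays at 0.
val observed = sc.parallelize(Seq(1), 1).mapPartitions { _ =>
  val tc = TaskContext.get()
  if (tc.attemptNumber() < 2) {
    throw new RuntimeException("transient task failure")  // task-level retry only
  }
  Iterator.single((tc.attemptNumber(), tc.stageAttemptNumber()))
}.collect()

assert(observed.toSeq == Seq((2, 0)))
sc.stop()
```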

core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach {
   private def withTaskId[T](taskAttemptId: Long)(block: => T): T = {
     try {
       TaskContext.setTaskContext(
-        new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null))
+        new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null))
       block
     } finally {
       TaskContext.unset()

project/MimaExcludes.scala

Lines changed: 3 additions & 0 deletions

@@ -36,6 +36,9 @@ object MimaExcludes {
 
   // Exclude rules for 2.3.x
   lazy val v23excludes = v22excludes ++ Seq(
+    // [SPARK-22897] Expose stageAttemptId in TaskContext
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"),
+
     // SPARK-22789: Map-only continuous processing execution
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"),
     ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"),

sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala

Lines changed: 1 addition & 0 deletions

@@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite
 
     TaskContext.setTaskContext(new TaskContextImpl(
       stageId = 0,
+      stageAttemptNumber = 0,
       partitionId = 0,
       taskAttemptId = Random.nextInt(10000),
       attemptNumber = 0,
