Commit a51e5f6

[SPARK-3543] Clean up Java TaskContext implementation.
1 parent: 5e34855

5 files changed: +22, -29 lines

core/src/main/java/org/apache/spark/TaskContext.java

Lines changed: 14 additions & 19 deletions
@@ -56,7 +56,7 @@ public class TaskContext implements Serializable {
    * @param taskMetrics performance metrics of the task
    */
   @DeveloperApi
-  public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally,
+  public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally,
       TaskMetrics taskMetrics) {
     this.attemptId = attemptId;
     this.partitionId = partitionId;
@@ -65,7 +65,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean
     this.taskMetrics = taskMetrics;
   }

-
   /**
    * :: DeveloperApi ::
    * Contextual information about a task which can be read or mutated during execution.
@@ -76,16 +75,14 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean
    * @param runningLocally whether the task is running locally in the driver JVM
    */
   @DeveloperApi
-  public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
-      Boolean runningLocally) {
+  public TaskContext(int stageId, int partitionId, long attemptId, boolean runningLocally) {
     this.attemptId = attemptId;
     this.partitionId = partitionId;
     this.runningLocally = runningLocally;
     this.stageId = stageId;
     this.taskMetrics = TaskMetrics.empty();
   }

-
   /**
    * :: DeveloperApi ::
    * Contextual information about a task which can be read or mutated during execution.
@@ -95,7 +92,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
    * @param attemptId the number of attempts to execute this task
    */
   @DeveloperApi
-  public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
+  public TaskContext(int stageId, int partitionId, long attemptId) {
     this.attemptId = attemptId;
     this.partitionId = partitionId;
     this.runningLocally = false;
@@ -107,9 +104,9 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
       new ThreadLocal<TaskContext>();

   /**
-    * :: Internal API ::
-    * This is spark internal API, not intended to be called from user programs.
-    */
+   * :: Internal API ::
+   * This is spark internal API, not intended to be called from user programs.
+   */
   public static void setTaskContext(TaskContext tc) {
     taskContext.set(tc);
   }
@@ -118,10 +115,8 @@ public static TaskContext get() {
     return taskContext.get();
   }

-  /**
-   * :: Internal API ::
-   */
-  public static void remove() {
+  /** :: Internal API :: */
+  public static void unset() {
     taskContext.remove();
   }

@@ -130,22 +125,22 @@ public static void remove() {
       new ArrayList<TaskCompletionListener>();

   // Whether the corresponding task has been killed.
-  private volatile Boolean interrupted = false;
+  private volatile boolean interrupted = false;

   // Whether the task has completed.
-  private volatile Boolean completed = false;
+  private volatile boolean completed = false;

   /**
    * Checks whether the task has completed.
    */
-  public Boolean isCompleted() {
+  public boolean isCompleted() {
     return completed;
   }

   /**
    * Checks whether the task has been killed.
    */
-  public Boolean isInterrupted() {
+  public boolean isInterrupted() {
     return interrupted;
   }

@@ -246,12 +241,12 @@ public long attemptId() {
   }

   @Deprecated
-  /** Deprecated: use getRunningLocally() */
+  /** Deprecated: use isRunningLocally() */
   public boolean runningLocally() {
     return runningLocally;
   }

-  public boolean getRunningLocally() {
+  public boolean isRunningLocally() {
     return runningLocally;
   }

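The diff above changes the public surface of TaskContext: the constructors now take Java primitives (int, long, boolean) instead of boxed types, the thread-local helper remove() is renamed to unset(), and getRunningLocally() is replaced by isRunningLocally() (runningLocally() remains but stays deprecated). A minimal sketch of how calling code could exercise the cleaned-up API from Scala; the example object name is made up for illustration and is not part of this commit:

import org.apache.spark.TaskContext

object TaskContextSketch {  // hypothetical example, not in the Spark codebase
  def main(args: Array[String]): Unit = {
    // Constructor parameters are now primitives; this overload defaults to TaskMetrics.empty().
    val tc = new TaskContext(/* stageId */ 0, /* partitionId */ 0, /* attemptId */ 0L,
      /* runningLocally */ false)
    TaskContext.setTaskContext(tc)           // bind the context to the current thread
    assert(!TaskContext.get().isCompleted)   // returns boolean, no longer java.lang.Boolean
    assert(!tc.isRunningLocally)             // renamed from getRunningLocally()
    TaskContext.unset()                      // renamed from remove()
  }
}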
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -641,7 +641,7 @@ class DAGScheduler(
         job.listener.taskSucceeded(0, result)
       } finally {
         taskContext.markTaskCompleted()
-        TaskContext.remove()
+        TaskContext.unset()
       }
     } catch {
       case e: Exception =>

core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala

Lines changed: 1 addition & 5 deletions
@@ -58,11 +58,7 @@ private[spark] class ResultTask[T, U](
       ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)

     metrics = Some(context.taskMetrics)
-    try {
-      func(context, rdd.iterator(partition, context))
-    } finally {
-      context.markTaskCompleted()
-    }
+    func(context, rdd.iterator(partition, context))
   }

   // This is only callable on the driver side.

core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala

Lines changed: 0 additions & 2 deletions
@@ -78,8 +78,6 @@ private[spark] class ShuffleMapTask(
             log.debug("Could not stop writer", e)
         }
         throw e
-    } finally {
-      context.markTaskCompleted()
     }
   }

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 6 additions & 2 deletions
@@ -52,7 +52,12 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     if (_killed) {
       kill(interruptThread = false)
     }
-    runTask(context)
+    try {
+      runTask(context)
+    } finally {
+      context.markTaskCompleted()
+      TaskContext.unset()
+    }
   }

   def runTask(context: TaskContext): T
@@ -93,7 +98,6 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex
     if (interruptThread && taskThread != null) {
       taskThread.interrupt()
     }
-    TaskContext.remove()
   }
 }

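With the Task.scala change above, marking the task completed and clearing the thread-local context happen in one try/finally around runTask, rather than being repeated in ResultTask, ShuffleMapTask, and kill(). A hedged sketch of that pattern; the helper name runWithContext is invented for illustration and does not exist in this commit:

import org.apache.spark.TaskContext

// Hypothetical helper (not in the Spark codebase) showing the centralized lifecycle:
// however the body exits, completion listeners fire and the thread-local is cleared once.
object TaskLifecycleSketch {
  def runWithContext[T](context: TaskContext)(body: TaskContext => T): T = {
    TaskContext.setTaskContext(context)
    try {
      body(context)                // stands in for runTask(context)
    } finally {
      context.markTaskCompleted()  // previously called by each task type and DAGScheduler
      TaskContext.unset()          // renamed from remove(); previously also called in kill()
    }
  }
}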