Skip to content

Commit ddb8cbe

Browse files
committed
Moved setting the thread local to where TaskContext is instantiated.
1 parent a7d5e23 commit ddb8cbe

File tree

3 files changed

+16
-13
lines changed

3 files changed

+16
-13
lines changed

core/src/main/java/org/apache/spark/TaskContext.java

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,10 @@
3939
@DeveloperApi
4040
public class TaskContext implements Serializable {
4141

42-
private Integer stageId;
43-
private Integer partitionId;
44-
private Long attemptId;
45-
private Boolean runningLocally;
42+
private int stageId;
43+
private int partitionId;
44+
private long attemptId;
45+
private boolean runningLocally;
4646
private TaskMetrics taskMetrics;
4747

4848
/**
@@ -63,7 +63,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean
6363
this.runningLocally = runningLocally;
6464
this.stageId = stageId;
6565
this.taskMetrics = taskMetrics;
66-
taskContext.set(this);
6766
}
6867

6968

@@ -84,7 +83,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId,
8483
this.runningLocally = runningLocally;
8584
this.stageId = stageId;
8685
this.taskMetrics = TaskMetrics.empty();
87-
taskContext.set(this);
8886
}
8987

9088

@@ -103,12 +101,15 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) {
103101
this.runningLocally = false;
104102
this.stageId = stageId;
105103
this.taskMetrics = TaskMetrics.empty();
106-
taskContext.set(this);
107104
}
108105

109106
private static ThreadLocal<TaskContext> taskContext =
110107
new ThreadLocal<TaskContext>();
111108

109+
public static void setTaskContext(TaskContext tc) {
110+
taskContext.set(tc);
111+
}
112+
112113
public static TaskContext get() {
113114
return taskContext.get();
114115
}
@@ -212,19 +213,19 @@ public void markInterrupted() {
212213
interrupted = true;
213214
}
214215

215-
public Integer stageId() {
216+
public int stageId() {
216217
return stageId;
217218
}
218219

219-
public Integer partitionId() {
220+
public int partitionId() {
220221
return partitionId;
221222
}
222223

223-
public Long attemptId() {
224+
public long attemptId() {
224225
return attemptId;
225226
}
226227

227-
public Boolean runningLocally() {
228+
public boolean runningLocally() {
228229
return runningLocally;
229230
}
230231

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,8 @@ class DAGScheduler(
634634
val rdd = job.finalStage.rdd
635635
val split = rdd.partitions(job.partitions(0))
636636
val taskContext =
637-
new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true)
637+
new TaskContext(job.finalStage.id, job.partitions(0), 0, true)
638+
TaskContext.setTaskContext(taskContext)
638639
try {
639640
val result = job.func(taskContext, rdd.iterator(split, taskContext))
640641
job.listener.taskSucceeded(0, result)

core/src/main/scala/org/apache/spark/scheduler/Task.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,8 @@ import org.apache.spark.util.Utils
4545
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
4646

4747
final def run(attemptId: Long): T = {
48-
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
48+
context = new TaskContext(stageId, partitionId, attemptId, false)
49+
TaskContext.setTaskContext(context)
4950
context.taskMetrics.hostname = Utils.localHostName()
5051
taskThread = Thread.currentThread()
5152
if (_killed) {

0 commit comments

Comments
 (0)