From 333c7d644973c721f0b0509399579723d2a43446 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 17 Sep 2014 13:10:47 +0530 Subject: [PATCH 1/7] Translated Task context from scala to java. --- .../java/org/apache/spark/TaskContext.java | 165 ++++++++++++++++++ .../scala/org/apache/spark/TaskContext.scala | 126 ------------- .../java/org/apache/spark/JavaAPISuite.java | 2 +- .../util/JavaTaskCompletionListenerImpl.java | 4 - .../org/apache/spark/CacheManagerSuite.scala | 2 +- 5 files changed, 167 insertions(+), 132 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/TaskContext.java delete mode 100644 core/src/main/scala/org/apache/spark/TaskContext.scala diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java new file mode 100644 index 000000000000..ceffe7bdc0ec --- /dev/null +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import scala.Function0; +import scala.Function1; +import scala.Unit; +import scala.collection.JavaConversions; + +import org.apache.spark.annotation.DeveloperApi; +import org.apache.spark.executor.TaskMetrics; +import org.apache.spark.util.TaskCompletionListener; +import org.apache.spark.util.TaskCompletionListenerException; + +@DeveloperApi +public class TaskContext implements Serializable { + + public Integer stageId; + public Integer partitionId; + public Long attemptId; + public Boolean runningLocally; + public TaskMetrics taskMetrics; + + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, + TaskMetrics taskMetrics) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = taskMetrics; + } + + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, + Boolean runningLocally) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = false; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + } + + // List of callback functions to execute when the task completes. + private transient List onCompleteCallbacks = + new ArrayList(); + + // Whether the corresponding task has been killed. + private volatile Boolean interrupted = false; + + // Whether the task has completed. + private volatile Boolean completed = false; + + /** + * Checks whether the task has completed. 
+ */ + public Boolean isCompleted() { + return completed; + } + + /** + * Checks whether the task has been killed. + */ + public Boolean isInterrupted() { + return interrupted; + } + + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(TaskCompletionListener listener){ + onCompleteCallbacks.add(listener); + return this; + } + + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + * + * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(final Function1 f) { + onCompleteCallbacks.add( new TaskCompletionListener() { + + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(context); + } + }); + return this; + } + + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * @param f Callback function. + */ + @Deprecated + public void addOnCompleteCallback(final Function0 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(); + } + }); + } + + /** Marks the task as completed and triggers the listeners. 
*/ + public void markTaskCompleted() throws TaskCompletionListenerException { + completed = true; + List errorMsgs = new ArrayList(2); + // Process complete callbacks in the reverse order of registration + List revlist = + new ArrayList(onCompleteCallbacks); + Collections.reverse(revlist); + for (TaskCompletionListener tcl : revlist){ + try { + tcl.onTaskCompletion(this); + } catch (Throwable e){ + errorMsgs.add(e.getMessage()); + } + } + + if (!errorMsgs.isEmpty()) { + throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); + } + } + + /** Marks the task for interruption, i.e. cancellation. */ + public void markInterrupted() { + interrupted = true; + } +} diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala deleted file mode 100644 index 51b3e4d5e093..000000000000 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.util.{TaskCompletionListenerException, TaskCompletionListener} - - -/** - * :: DeveloperApi :: - * Contextual information about a task which can be read or mutated during execution. - * - * @param stageId stage id - * @param partitionId index of the partition - * @param attemptId the number of attempts to execute this task - * @param runningLocally whether the task is running locally in the driver JVM - * @param taskMetrics performance metrics of the task - */ -@DeveloperApi -class TaskContext( - val stageId: Int, - val partitionId: Int, - val attemptId: Long, - val runningLocally: Boolean = false, - private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends Serializable with Logging { - - @deprecated("use partitionId", "0.8.1") - def splitId = partitionId - - // List of callback functions to execute when the task completes. - @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] - - // Whether the corresponding task has been killed. - @volatile private var interrupted: Boolean = false - - // Whether the task has completed. - @volatile private var completed: Boolean = false - - /** Checks whether the task has completed. */ - def isCompleted: Boolean = completed - - /** Checks whether the task has been killed. */ - def isInterrupted: Boolean = interrupted - - // TODO: Also track whether the task has completed successfully or with exception. - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. 
- */ - def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { - onCompleteCallbacks += listener - this - } - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - def addTaskCompletionListener(f: TaskContext => Unit): this.type = { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f(context) - } - this - } - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * @param f Callback function. - */ - @deprecated("use addTaskCompletionListener", "1.1.0") - def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f() - } - } - - /** Marks the task as completed and triggers the listeners. */ - private[spark] def markTaskCompleted(): Unit = { - completed = true - val errorMsgs = new ArrayBuffer[String](2) - // Process complete callbacks in the reverse order of registration - onCompleteCallbacks.reverse.foreach { listener => - try { - listener.onTaskCompletion(this) - } catch { - case e: Throwable => - errorMsgs += e.getMessage - logError("Error in TaskCompletionListener", e) - } - } - if (errorMsgs.nonEmpty) { - throw new TaskCompletionListenerException(errorMsgs) - } - } - - /** Marks the task for interruption, i.e. cancellation. 
*/ - private[spark] def markInterrupted(): Unit = { - interrupted = true - } -} diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index b8574dfb42e6..6502c56cc0b9 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -776,7 +776,7 @@ public void persist() { @Test public void iterator() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2); - TaskContext context = new TaskContext(0, 0, 0, false, new TaskMetrics()); + TaskContext context = new TaskContext(0, 0, 0L, false, new TaskMetrics()); Assert.assertEquals(1, rdd.iterator(rdd.partitions().get(0), context).next().intValue()); } diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index af34cdb03e4d..070284a9af3d 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,10 +30,6 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); - context.stageId(); - context.partitionId(); - context.runningLocally(); - context.taskMetrics(); context.addTaskCompletionListener(this); } } diff --git a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala index 90dcadcffd09..d735010d7c9d 100644 --- a/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/CacheManagerSuite.scala @@ -94,7 +94,7 @@ class CacheManagerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar } whenExecuting(blockManager) { - val context = new TaskContext(0, 0, 0, runningLocally = true) + val context = new 
TaskContext(0, 0, 0, true) val value = cacheManager.getOrCompute(rdd, split, context, StorageLevel.MEMORY_ONLY) assert(value.toList === List(1, 2, 3, 4)) } From f716fd1b0d84bcb08889b62af2d5f7a6d14b1cab Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 17 Sep 2014 13:45:39 +0530 Subject: [PATCH 2/7] introduced thread local for getting the task context. --- core/src/main/java/org/apache/spark/TaskContext.java | 11 +++++++++++ core/src/main/scala/org/apache/spark/rdd/RDD.scala | 1 + 2 files changed, 12 insertions(+) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index ceffe7bdc0ec..89e2d244a35e 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -48,6 +48,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean this.runningLocally = runningLocally; this.stageId = stageId; this.taskMetrics = taskMetrics; + taskContext.set(this); } public TaskContext(Integer stageId, Integer partitionId, Long attemptId, @@ -57,6 +58,7 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, this.runningLocally = runningLocally; this.stageId = stageId; this.taskMetrics = TaskMetrics.empty(); + taskContext.set(this); } public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { @@ -65,6 +67,15 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { this.runningLocally = false; this.stageId = stageId; this.taskMetrics = TaskMetrics.empty(); + taskContext.set(this); + } + + + private static ThreadLocal taskContext = + new ThreadLocal(); + + public static TaskContext get() { + return taskContext.get(); } // List of callback functions to execute when the task completes. 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index a9b905b0d1a6..cb40ea703e2e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -619,6 +619,7 @@ abstract class RDD[T: ClassTag]( * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. */ @DeveloperApi + @deprecated("use TaskContext.get", "1.2.0") def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { From edf945e6765314f0b87092d460f71aa70feecdc5 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 17 Sep 2014 14:42:05 +0530 Subject: [PATCH 3/7] Code review git add -A --- .../java/org/apache/spark/TaskContext.java | 292 ++++++++++-------- .../util/JavaTaskCompletionListenerImpl.java | 4 + 2 files changed, 161 insertions(+), 135 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 89e2d244a35e..d38ec19d3f4e 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -35,142 +35,164 @@ @DeveloperApi public class TaskContext implements Serializable { - public Integer stageId; - public Integer partitionId; - public Long attemptId; - public Boolean runningLocally; - public TaskMetrics taskMetrics; - - public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, - TaskMetrics taskMetrics) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = taskMetrics; - taskContext.set(this); + private Integer stageId; + private Integer partitionId; + private Long attemptId; + private Boolean runningLocally; + private TaskMetrics taskMetrics; + + public TaskContext(Integer stageId, 
Integer partitionId, Long attemptId, Boolean runningLocally, + TaskMetrics taskMetrics) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = taskMetrics; + taskContext.set(this); + } + + public TaskContext(Integer stageId, Integer partitionId, Long attemptId, + Boolean runningLocally) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = runningLocally; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + taskContext.set(this); + } + + public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { + this.attemptId = attemptId; + this.partitionId = partitionId; + this.runningLocally = false; + this.stageId = stageId; + this.taskMetrics = TaskMetrics.empty(); + taskContext.set(this); + } + + private static ThreadLocal taskContext = + new ThreadLocal(); + + public static TaskContext get() { + return taskContext.get(); + } + + // List of callback functions to execute when the task completes. + private transient List onCompleteCallbacks = + new ArrayList(); + + // Whether the corresponding task has been killed. + private volatile Boolean interrupted = false; + + // Whether the task has completed. + private volatile Boolean completed = false; + + /** + * Checks whether the task has completed. + */ + public Boolean isCompleted() { + return completed; + } + + /** + * Checks whether the task has been killed. + */ + public Boolean isInterrupted() { + return interrupted; + } + + /** + * Add a (Java friendly) listener to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + *

+ * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { + onCompleteCallbacks.add(listener); + return this; + } + + /** + * Add a listener in the form of a Scala closure to be executed on task completion. + * This will be called in all situation - success, failure, or cancellation. + *

+ * An example use is for HadoopRDD to register a callback to close the input stream. + */ + public TaskContext addTaskCompletionListener(final Function1 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(context); + } + }); + return this; + } + + /** + * Add a callback function to be executed on task completion. An example use + * is for HadoopRDD to register a callback to close the input stream. + * Will be called in any situation - success, failure, or cancellation. + * + * @param f Callback function. + */ + @Deprecated + public void addOnCompleteCallback(final Function0 f) { + onCompleteCallbacks.add(new TaskCompletionListener() { + @Override + public void onTaskCompletion(TaskContext context) { + f.apply(); + } + }); + } + + /** + * Marks the task as completed and triggers the listeners. + */ + public void markTaskCompleted() throws TaskCompletionListenerException { + completed = true; + List errorMsgs = new ArrayList(2); + // Process complete callbacks in the reverse order of registration + List revlist = + new ArrayList(onCompleteCallbacks); + Collections.reverse(revlist); + for (TaskCompletionListener tcl : revlist) { + try { + tcl.onTaskCompletion(this); + } catch (Throwable e) { + errorMsgs.add(e.getMessage()); + } } - public TaskContext(Integer stageId, Integer partitionId, Long attemptId, - Boolean runningLocally) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = runningLocally; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - taskContext.set(this); - } - - public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { - this.attemptId = attemptId; - this.partitionId = partitionId; - this.runningLocally = false; - this.stageId = stageId; - this.taskMetrics = TaskMetrics.empty(); - taskContext.set(this); - } - - - private static ThreadLocal taskContext = - new ThreadLocal(); - - public static 
TaskContext get() { - return taskContext.get(); - } - - // List of callback functions to execute when the task completes. - private transient List onCompleteCallbacks = - new ArrayList(); - - // Whether the corresponding task has been killed. - private volatile Boolean interrupted = false; - - // Whether the task has completed. - private volatile Boolean completed = false; - - /** - * Checks whether the task has completed. - */ - public Boolean isCompleted() { - return completed; - } - - /** - * Checks whether the task has been killed. - */ - public Boolean isInterrupted() { - return interrupted; - } - - - /** - * Add a (Java friendly) listener to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public TaskContext addTaskCompletionListener(TaskCompletionListener listener){ - onCompleteCallbacks.add(listener); - return this; - } - - - /** - * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. - * - * An example use is for HadoopRDD to register a callback to close the input stream. - */ - public TaskContext addTaskCompletionListener(final Function1 f) { - onCompleteCallbacks.add( new TaskCompletionListener() { - - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(context); - } - }); - return this; - } - - /** - * Add a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * @param f Callback function. 
- */ - @Deprecated - public void addOnCompleteCallback(final Function0 f) { - onCompleteCallbacks.add(new TaskCompletionListener() { - @Override - public void onTaskCompletion(TaskContext context) { - f.apply(); - } - }); - } - - /** Marks the task as completed and triggers the listeners. */ - public void markTaskCompleted() throws TaskCompletionListenerException { - completed = true; - List errorMsgs = new ArrayList(2); - // Process complete callbacks in the reverse order of registration - List revlist = - new ArrayList(onCompleteCallbacks); - Collections.reverse(revlist); - for (TaskCompletionListener tcl : revlist){ - try { - tcl.onTaskCompletion(this); - } catch (Throwable e){ - errorMsgs.add(e.getMessage()); - } - } - - if (!errorMsgs.isEmpty()) { - throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); - } - } - - /** Marks the task for interruption, i.e. cancellation. */ - public void markInterrupted() { - interrupted = true; + if (!errorMsgs.isEmpty()) { + throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); } + taskContext.remove(); + } + + /** + * Marks the task for interruption, i.e. cancellation. 
+ */ + public void markInterrupted() { + interrupted = true; + } + + public Integer stageId() { + return stageId; + } + + public Integer partitionId() { + return partitionId; + } + + public Long attemptId() { + return attemptId; + } + + public Boolean runningLocally() { + return runningLocally; + } + + public TaskMetrics taskMetrics() { + return taskMetrics; + } } diff --git a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java index 070284a9af3d..af34cdb03e4d 100644 --- a/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java +++ b/core/src/test/java/org/apache/spark/util/JavaTaskCompletionListenerImpl.java @@ -30,6 +30,10 @@ public class JavaTaskCompletionListenerImpl implements TaskCompletionListener { public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); + context.stageId(); + context.partitionId(); + context.runningLocally(); + context.taskMetrics(); context.addTaskCompletionListener(this); } } From a7d5e23330a159e91581a35d46e1846770eb421d Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Wed, 17 Sep 2014 15:27:51 +0530 Subject: [PATCH 4/7] Added doc comments. --- .../java/org/apache/spark/TaskContext.java | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index d38ec19d3f4e..f0fa334ee886 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -32,6 +32,10 @@ import org.apache.spark.util.TaskCompletionListener; import org.apache.spark.util.TaskCompletionListenerException; +/** +* :: DeveloperApi :: +* Contextual information about a task which can be read or mutated during execution. 
+*/ @DeveloperApi public class TaskContext implements Serializable { @@ -41,6 +45,17 @@ public class TaskContext implements Serializable { private Boolean runningLocally; private TaskMetrics taskMetrics; + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + * @param taskMetrics performance metrics of the task + */ + @DeveloperApi public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally, TaskMetrics taskMetrics) { this.attemptId = attemptId; @@ -51,6 +66,17 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean taskContext.set(this); } + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. + * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + * @param runningLocally whether the task is running locally in the driver JVM + */ + @DeveloperApi public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean runningLocally) { this.attemptId = attemptId; @@ -61,6 +87,16 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, taskContext.set(this); } + + /** + * :: DeveloperApi :: + * Contextual information about a task which can be read or mutated during execution. 
+ * + * @param stageId stage id + * @param partitionId index of the partition + * @param attemptId the number of attempts to execute this task + */ + @DeveloperApi public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { this.attemptId = attemptId; this.partitionId = partitionId; From ddb8cbec6e2585535cc84848aa2357b56bd11964 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 18 Sep 2014 15:28:12 +0530 Subject: [PATCH 5/7] Moved setting the thread local to where TaskContext is instantiated. --- .../java/org/apache/spark/TaskContext.java | 23 ++++++++++--------- .../apache/spark/scheduler/DAGScheduler.scala | 3 ++- .../org/apache/spark/scheduler/Task.scala | 3 ++- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index f0fa334ee886..1928d08a6525 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -39,10 +39,10 @@ @DeveloperApi public class TaskContext implements Serializable { - private Integer stageId; - private Integer partitionId; - private Long attemptId; - private Boolean runningLocally; + private int stageId; + private int partitionId; + private long attemptId; + private boolean runningLocally; private TaskMetrics taskMetrics; /** @@ -63,7 +63,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, Boolean this.runningLocally = runningLocally; this.stageId = stageId; this.taskMetrics = taskMetrics; - taskContext.set(this); } @@ -84,7 +83,6 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId, this.runningLocally = runningLocally; this.stageId = stageId; this.taskMetrics = TaskMetrics.empty(); - taskContext.set(this); } @@ -103,12 +101,15 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { this.runningLocally = false; this.stageId = stageId; this.taskMetrics = 
TaskMetrics.empty(); - taskContext.set(this); } private static ThreadLocal taskContext = new ThreadLocal(); + public static void setTaskContext(TaskContext tc) { + taskContext.set(tc); + } + public static TaskContext get() { return taskContext.get(); } @@ -212,19 +213,19 @@ public void markInterrupted() { interrupted = true; } - public Integer stageId() { + public int stageId() { return stageId; } - public Integer partitionId() { + public int partitionId() { return partitionId; } - public Long attemptId() { + public long attemptId() { return attemptId; } - public Boolean runningLocally() { + public boolean runningLocally() { return runningLocally; } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b2774dfc4755..ccc920af1f2b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -634,7 +634,8 @@ class DAGScheduler( val rdd = job.finalStage.rdd val split = rdd.partitions(job.partitions(0)) val taskContext = - new TaskContext(job.finalStage.id, job.partitions(0), 0, runningLocally = true) + new TaskContext(job.finalStage.id, job.partitions(0), 0, true) + TaskContext.setTaskContext(taskContext) try { val result = job.func(taskContext, rdd.iterator(split, taskContext)) job.listener.taskSucceeded(0, result) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 6aa0cca06878..7ad26ef08c2f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -45,7 +45,8 @@ import org.apache.spark.util.Utils private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable { final def run(attemptId: Long): T = { - context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false) 
+ context = new TaskContext(stageId, partitionId, attemptId, false) + TaskContext.setTaskContext(context) context.taskMetrics.hostname = Utils.localHostName() taskThread = Thread.currentThread() if (_killed) { From ee8bd009fe8065c4ea018c4ff5baa32378606243 Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Thu, 18 Sep 2014 15:31:36 +0530 Subject: [PATCH 6/7] Added internal API in docs comments. --- core/src/main/java/org/apache/spark/TaskContext.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 1928d08a6525..058ac94d8d68 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -183,6 +183,7 @@ public void onTaskCompletion(TaskContext context) { } /** + * ::Internal API:: * Marks the task as completed and triggers the listeners. */ public void markTaskCompleted() throws TaskCompletionListenerException { @@ -207,6 +208,7 @@ public void markTaskCompleted() throws TaskCompletionListenerException { } /** + * ::Internal API:: * Marks the task for interruption, i.e. cancellation. 
*/ public void markInterrupted() { @@ -229,6 +231,7 @@ public boolean runningLocally() { return runningLocally; } + /** ::Internal API:: */ public TaskMetrics taskMetrics() { return taskMetrics; } From 8ae414c1ff2af5328cac7cc36b28e66f3aa6647d Mon Sep 17 00:00:00 2001 From: Shashank Sharma Date: Fri, 26 Sep 2014 15:49:13 +0530 Subject: [PATCH 7/7] CR --- .../java/org/apache/spark/TaskContext.java | 42 +++++++++++++++++-- .../apache/spark/scheduler/DAGScheduler.scala | 1 + .../org/apache/spark/scheduler/Task.scala | 3 +- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java index 058ac94d8d68..09b8ce02bd3d 100644 --- a/core/src/main/java/org/apache/spark/TaskContext.java +++ b/core/src/main/java/org/apache/spark/TaskContext.java @@ -106,6 +106,10 @@ public TaskContext(Integer stageId, Integer partitionId, Long attemptId) { private static ThreadLocal taskContext = new ThreadLocal(); + /** + * :: Internal API :: + * This is a Spark-internal API; it is not intended to be called from user programs. + */ public static void setTaskContext(TaskContext tc) { taskContext.set(tc); } @@ -114,6 +118,13 @@ public static TaskContext get() { return taskContext.get(); } + /** + * :: Internal API :: + */ + public static void remove() { + taskContext.remove(); + } + // List of callback functions to execute when the task completes. private transient List onCompleteCallbacks = new ArrayList(); @@ -151,7 +162,7 @@ public TaskContext addTaskCompletionListener(TaskCompletionListener listener) { /** * Add a listener in the form of a Scala closure to be executed on task completion. - * This will be called in all situation - success, failure, or cancellation. + * This will be called in all situations - success, failure, or cancellation. *

* An example use is for HadoopRDD to register a callback to close the input stream. */ @@ -170,6 +181,8 @@ public void onTaskCompletion(TaskContext context) { * is for HadoopRDD to register a callback to close the input stream. * Will be called in any situation - success, failure, or cancellation. * + * Deprecated: use addTaskCompletionListener + * * @param f Callback function. */ @Deprecated @@ -193,7 +206,7 @@ public void markTaskCompleted() throws TaskCompletionListenerException { List revlist = new ArrayList(onCompleteCallbacks); Collections.reverse(revlist); - for (TaskCompletionListener tcl : revlist) { + for (TaskCompletionListener tcl: revlist) { try { tcl.onTaskCompletion(this); } catch (Throwable e) { @@ -204,7 +217,6 @@ public void markTaskCompleted() throws TaskCompletionListenerException { if (!errorMsgs.isEmpty()) { throw new TaskCompletionListenerException(JavaConversions.asScalaBuffer(errorMsgs)); } - taskContext.remove(); } /** @@ -215,22 +227,46 @@ public void markInterrupted() { interrupted = true; } + @Deprecated + /** Deprecated: use getStageId() */ public int stageId() { return stageId; } + @Deprecated + /** Deprecated: use getPartitionId() */ public int partitionId() { return partitionId; } + @Deprecated + /** Deprecated: use getAttemptId() */ public long attemptId() { return attemptId; } + @Deprecated + /** Deprecated: use getRunningLocally() */ public boolean runningLocally() { return runningLocally; } + public boolean getRunningLocally() { + return runningLocally; + } + + public int getStageId() { + return stageId; + } + + public int getPartitionId() { + return partitionId; + } + + public long getAttemptId() { + return attemptId; + } + /** ::Internal API:: */ public TaskMetrics taskMetrics() { return taskMetrics; diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ccc920af1f2b..32cf29ed140e 100644 --- 
a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -641,6 +641,7 @@ class DAGScheduler( job.listener.taskSucceeded(0, result) } finally { taskContext.markTaskCompleted() + TaskContext.remove() } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7ad26ef08c2f..bf73f6f7bd0e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -93,7 +93,8 @@ private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) ex if (interruptThread && taskThread != null) { taskThread.interrupt() } - } + TaskContext.remove() + } } /**