60 changes: 57 additions & 3 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,26 +17,80 @@

package org.apache.spark

import java.util.Properties

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem

/** A [[TaskContext]] with extra info and tooling for a barrier stage. */
trait BarrierTaskContext extends TaskContext {
class BarrierTaskContext(
override val stageId: Int,
override val stageAttemptNumber: Int,
override val partitionId: Int,
override val taskAttemptId: Long,
override val attemptNumber: Int,
override val taskMemoryManager: TaskMemoryManager,
localProperties: Properties,
@transient private val metricsSystem: MetricsSystem,
// The default value is only used in tests.
override val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
taskMemoryManager, localProperties, metricsSystem, taskMetrics) {

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* the MPI_Barrier function in MPI, the barrier() call blocks until all tasks in the same stage
* have reached this routine.
*
* CAUTION! In a barrier stage, each task must make the same number of barrier() calls, in all
* possible code branches. Otherwise the job may hang, or a SparkException may be thrown after a
* timeout. Some examples of misuse are listed below:
* 1. Calling barrier() from only a subset of the tasks in the same barrier stage will lead to
* a timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { (iter, context) =>
* if (context.partitionId() == 0) {
* // Do nothing.
* } else {
* context.barrier()
* }
* iter
* }
* }}}
*
* 2. Placing a barrier() call inside a try-catch block may lead to a timeout of the second
* barrier() call, because the call inside the try block is skipped when an exception is thrown.
* {{{
* rdd.barrier().mapPartitions { (iter, context) =>
* try {
* // Do something that might throw an Exception.
* doSomething()
* context.barrier()
* } catch {
* case e: Exception => logWarning("...", e)
* }
* context.barrier()
* iter
* }
* }}}
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit
def barrier(): Unit = {
// TODO SPARK-24817 implement global barrier.
}

/**
* :: Experimental ::
* Returns all task infos in this barrier stage, ordered by partitionId.
*/
@Experimental
@Since("2.4.0")
def getTaskInfos(): Array[BarrierTaskInfo]
def getTaskInfos(): Array[BarrierTaskInfo] = {
val addressesStr = localProperties.getProperty("addresses", "")
addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_))
}
}
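
For reference, a minimal sketch of correct usage of the new class, assuming the (iter, context) closure form used in the Scaladoc examples above; `rdd` and `doSomething()` are placeholders, and it is an assumption that BarrierTaskInfo exposes the executor address as `address`. Every task reaches barrier() exactly once, and getTaskInfos() is consulted only after the barrier:
{{{
rdd.barrier().mapPartitions { (iter, context) =>
  // Every task runs the same code path and calls barrier() exactly once,
  // so all tasks in the stage rendezvous here instead of timing out.
  doSomething()
  context.barrier()

  // Task infos are built from the "addresses" local property and are
  // ordered by partitionId; peers.head belongs to partition 0.
  val peers: Array[BarrierTaskInfo] = context.getTaskInfos()
  val coordinator = peers.head.address

  iter
}
}}}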
49 changes: 0 additions & 49 deletions core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala

This file was deleted.

2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala
@@ -28,7 +28,7 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) {

/**
* :: Experimental ::
* Maps partitions together with a provided BarrierTaskContext.
* Maps partitions together with a provided [[org.apache.spark.BarrierTaskContext]].
*
* `preservesPartitioning` indicates whether the input function preserves the partitioner, which
* should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys.
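
A hedged sketch of how `preservesPartitioning` is meant to be used with the barrier mapPartitions, assuming the (Iterator[T], BarrierTaskContext) => Iterator[S] signature implied by the Scaladoc; `pairRdd` is a placeholder pair RDD:
{{{
val result = pairRdd.barrier().mapPartitions(
  (iter, context) => {
    context.barrier()                        // synchronize all tasks in the stage first
    iter.map { case (k, v) => (k, v + 1) }   // values change, keys do not
  },
  preservesPartitioning = true               // safe only because the keys are untouched
)
}}}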
2 changes: 1 addition & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -83,7 +83,7 @@ private[spark] abstract class Task[T](
// TODO SPARK-24874 Allow creating a BarrierTaskContext based on partitions, instead of on
// whether the stage is a barrier stage.
context = if (isBarrier) {
new BarrierTaskContextImpl(
new BarrierTaskContext(
stageId,
stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal
partitionId,
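
The hunk is truncated here; below is a sketch of the likely full construction site, assuming the non-barrier branch mirrors this one with a plain TaskContextImpl (the else branch is an assumption, not shown in the diff):
{{{
context = if (isBarrier) {
  // Barrier stages get the richer context with barrier() and getTaskInfos().
  new BarrierTaskContext(
    stageId, stageAttemptId, partitionId, taskAttemptId, attemptNumber,
    taskMemoryManager, localProperties, metricsSystem, metrics)
} else {
  // Regular stages keep the plain task context.
  new TaskContextImpl(
    stageId, stageAttemptId, partitionId, taskAttemptId, attemptNumber,
    taskMemoryManager, localProperties, metricsSystem, metrics)
}
}}}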