diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 4c358629dee9..ba303680d1a0 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,20 +17,71 @@ package org.apache.spark +import java.util.Properties + import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem /** A [[TaskContext]] with extra info and tooling for a barrier stage. */ -trait BarrierTaskContext extends TaskContext { +class BarrierTaskContext( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) { /** * :: Experimental :: * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to * MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same * stage have reached this routine. + * + * CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all + * possible code branches. Otherwise, you may get the job hanging or a SparkException after + * timeout. Some examples of misuses listed below: + * 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it + * shall lead to timeout of the function call. + * {{{ + * rdd.barrier().mapPartitions { (iter, context) => + * if (context.partitionId() == 0) { + * // Do nothing. + * } else { + * context.barrier() + * } + * iter + * } + * }}} + * + * 2. Include barrier() function in a try-catch code block, this may lead to timeout of the + * second function call. + * {{{ + * rdd.barrier().mapPartitions { (iter, context) => + * try { + * // Do something that might throw an Exception. + * doSomething() + * context.barrier() + * } catch { + * case e: Exception => logWarning("...", e) + * } + * context.barrier() + * iter + * } + * }}} */ @Experimental @Since("2.4.0") - def barrier(): Unit + def barrier(): Unit = { + // TODO SPARK-24817 implement global barrier. + } /** * :: Experimental :: @@ -38,5 +89,8 @@ trait BarrierTaskContext extends TaskContext { */ @Experimental @Since("2.4.0") - def getTaskInfos(): Array[BarrierTaskInfo] + def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = localProperties.getProperty("addresses", "") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } } diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala deleted file mode 100644 index 8ac705757a38..000000000000 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.util.Properties - -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.metrics.MetricsSystem - -/** A [[BarrierTaskContext]] implementation. */ -private[spark] class BarrierTaskContextImpl( - override val stageId: Int, - override val stageAttemptNumber: Int, - override val partitionId: Int, - override val taskAttemptId: Long, - override val attemptNumber: Int, - override val taskMemoryManager: TaskMemoryManager, - localProperties: Properties, - @transient private val metricsSystem: MetricsSystem, - // The default value is only used in tests. - override val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, - taskMemoryManager, localProperties, metricsSystem, taskMetrics) - with BarrierTaskContext { - - // TODO SPARK-24817 implement global barrier. - override def barrier(): Unit = {} - - override def getTaskInfos(): Array[BarrierTaskInfo] = { - val addressesStr = localProperties.getProperty("addresses", "") - addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala index 85565d16e271..71f38bf6967b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDBarrier.scala @@ -28,7 +28,7 @@ class RDDBarrier[T: ClassTag](rdd: RDD[T]) { /** * :: Experimental :: - * Maps partitions together with a provided BarrierTaskContext. + * Maps partitions together with a provided [[org.apache.spark.BarrierTaskContext]]. * * `preservesPartitioning` indicates whether the input function preserves the partitioner, which * should be `false` unless `rdd` is a pair RDD and the input function doesn't modify the keys. diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 89ff2038e5f8..11f85fd91ba0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -83,7 +83,7 @@ private[spark] abstract class Task[T]( // TODO SPARK-24874 Allow create BarrierTaskContext based on partitions, instead of whether // the stage is barrier. context = if (isBarrier) { - new BarrierTaskContextImpl( + new BarrierTaskContext( stageId, stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId,