Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T])

// Common RDD functions

// Java-friendly counterpart of RDD.barrier(): the result is wrapped back into a JavaRDD.
def barrier(): JavaRDD[T] = {
  val barrierRdd = rdd.barrier()
  wrapRDD(barrierRdd)
}

/**
* Persist this RDD with the default storage level (`MEMORY_ONLY`).
*/
Expand Down
18 changes: 18 additions & 0 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.JavaConverters._

import py4j.GatewayServer

import org.apache.spark._
import org.apache.spark.barrier.BarrierTaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.util._

Expand Down Expand Up @@ -179,6 +182,21 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
// Python version of driver
PythonRDD.writeUTF(pythonVer, dataOut)
// Write out the TaskContextInfo
val isBarrier = context.isInstanceOf[BarrierTaskContext]
dataOut.writeBoolean(isBarrier)
if (isBarrier) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this would be language dependent? would need something for R runner too?

val port = GatewayServer.DEFAULT_PORT + 2 + context.partitionId() * 2
val gatewayServer = new GatewayServer(
context.asInstanceOf[BarrierTaskContext],
port,
port + 1,
GatewayServer.DEFAULT_CONNECT_TIMEOUT,
GatewayServer.DEFAULT_READ_TIMEOUT,
null)
// TODO: When to stop it?
gatewayServer.start()
context.addTaskCompletionListener(_ => gatewayServer.shutdown())
}
dataOut.writeInt(context.stageId())
dataOut.writeInt(context.partitionId())
dataOut.writeInt(context.attemptNumber())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.barrier

import java.util.{Timer, TimerTask}

import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}

/**
 * RPC endpoint that lets `numTasks` callers rendezvous. Every caller sends a [[RequestToSync]]
 * carrying its current epoch; once `numTasks` requests of the current epoch have been
 * collected, all of them are answered together and the epoch is advanced.
 *
 * If the requests of an epoch are not all received within `timeout` milliseconds of the first
 * one, the pending callers are failed and the epoch is advanced as well.
 *
 * @param numTasks number of sync requests that must arrive before an epoch completes
 * @param timeout  delay in milliseconds, counted from the first request of an epoch, after
 *                 which the requests collected so far are failed
 * @param rpcEnv   RPC environment this endpoint belongs to
 */
class BarrierCoordinator(
    numTasks: Int,
    timeout: Long,
    override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint {

  // Epoch whose requests are currently being collected. It is bumped after every round,
  // whether the round completed or timed out.
  // NOTE(review): a reviewer asked whether the epoch value will be logged on driver and
  // executors to help diagnose upper-level MPI programs; nothing is logged here.
  private var epoch = 0

  private val timer = new Timer("BarrierCoordinator epoch increment timer")

  // Callers of the current epoch that are still waiting for an answer.
  private val syncRequests = new scala.collection.mutable.ArrayBuffer[RpcCallContext](numTasks)

  /** Answers every pending caller and starts the next epoch once all requests have arrived. */
  private def replyIfGetAllSyncRequest(): Unit = {
    if (syncRequests.length == numTasks) {
      syncRequests.foreach(_.reply(()))
      syncRequests.clear()
      epoch += 1
    }
  }

  override def receive: PartialFunction[Any, Unit] = {
    case IncreaseEpoch(previousEpoch) =>
      // The timer of a round fires unconditionally; it only matters if that round is still
      // the current one, i.e. it did not complete in time.
      // NOTE(review): reviewers asked (a) whether the timeout should grow incrementally when
      // an epoch fails to sync, and (b) whether advancing the epoch is useful at all if the
      // stage is failed and resubmitted with a new task set after a sync failure.
      if (previousEpoch == epoch) {
        syncRequests.foreach(_.sendFailure(new RuntimeException(
          s"The coordinator cannot get all barrier sync requests within $timeout ms.")))
        syncRequests.clear()
        epoch += 1
      }
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case RequestToSync(requestEpoch) =>
      if (requestEpoch == epoch) {
        if (syncRequests.isEmpty) {
          // First request of this round: arm the timeout for the round.
          val currentEpoch = epoch
          timer.schedule(new TimerTask {
            override def run(): Unit = {
              // self can be null after this RPC endpoint is stopped.
              if (self != null) self.send(IncreaseEpoch(currentEpoch))
            }
          }, timeout)
        }

        syncRequests += context
        replyIfGetAllSyncRequest()
      } else {
        // A request for any other epoch can never be completed by this coordinator (for
        // example, it belongs to a round that already timed out). Fail it right away instead
        // of leaving the caller blocked without an answer.
        // NOTE(review): a reviewer also floated registering task-level barrier sequence and
        // hierarchy; that is not implemented here.
        context.sendFailure(new RuntimeException(
          s"Barrier sync request for epoch $requestEpoch does not match the current " +
            s"coordinator epoch $epoch."))
      }
  }

  // Cancels the timer so that no further timeout tasks run after the endpoint stops.
  override def onStop(): Unit = timer.cancel()
}

/** Messages understood by [[BarrierCoordinator]]. */
private[barrier] sealed trait BarrierCoordinatorMessage extends Serializable

/** Asks the coordinator to answer once all sync requests of `epoch` have been collected. */
private[barrier] case class RequestToSync(epoch: Int) extends BarrierCoordinatorMessage

/**
 * Sent by the coordinator's timer to itself: if `previousEpoch` is still the current epoch,
 * the pending sync requests are failed and the epoch is advanced.
 */
private[barrier] case class IncreaseEpoch(previousEpoch: Int) extends BarrierCoordinatorMessage
43 changes: 43 additions & 0 deletions core/src/main/scala/org/apache/spark/barrier/BarrierRDD.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.barrier

import scala.reflect.ClassTag

import org.apache.spark.{Partition, TaskContext}
import org.apache.spark.rdd.RDD


/**
* An RDD that supports running MPI programs.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

programme -> program

*/
class BarrierRDD[T: ClassTag](var prev: RDD[T]) extends RDD[T](prev) {

  // Always reports true: this RDD marks itself as a barrier RDD, whatever its parent says.
  override def isBarrier(): Boolean = true

  // The partitions are exactly those of the parent RDD.
  override def getPartitions: Array[Partition] = prev.partitions

  // Pure pass-through: the elements of a partition are the parent's elements for it.
  override def compute(split: Partition, context: TaskContext): Iterator[T] =
    prev.iterator(split, context)

  // Declared with an explicit `: Unit =` rather than the deprecated procedure syntax.
  override def clearDependencies(): Unit = {
    super.clearDependencies()
    // Drop the reference to the parent once the dependencies have been cleared.
    prev = null
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.barrier

import java.util.Properties

import org.apache.spark.{SparkEnv, TaskContextImpl}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.RpcUtils

/**
 * A [[TaskContextImpl]] that additionally lets the task wait at a barrier together with the
 * other tasks of its task set, by talking to the coordinator endpoint named
 * `barrier-$stageId-$stageAttemptNumber`.
 *
 * NOTE(review): a reviewer suggested naming this class `BarrierTaskContextImpl`.
 */
class BarrierTaskContext(
    override val stageId: Int,
    override val stageAttemptNumber: Int,
    override val partitionId: Int,
    override val taskAttemptId: Long,
    override val attemptNumber: Int,
    override val taskMemoryManager: TaskMemoryManager,
    localProperties: Properties,
    @transient private val metricsSystem: MetricsSystem,
    // The default value is only used in tests.
    override val taskMetrics: TaskMetrics = TaskMetrics.empty)
  // TODO make this extends TaskContext
  extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber,
    taskMemoryManager, localProperties, metricsSystem, taskMetrics)
  with Logging {

  // Reference to the coordinator endpoint for this stage attempt, resolved through the
  // current SparkEnv when the context is created.
  private val barrierCoordinator = {
    val env = SparkEnv.get
    RpcUtils.makeDriverRef(s"barrier-$stageId-$stageAttemptNumber", env.conf, env.rpcEnv)
  }

  // Number of barrier() calls that have completed successfully so far.
  private var epoch = 0

  /**
   * Returns the entries of the comma-separated `hosts` local property, trimmed, with empty
   * entries removed. The result is an empty array when the property is missing or blank.
   *
   * NOTE(review): the original comment described the entries as the executor IDs that the
   * barrier tasks are running on, while the property is named `hosts` - confirm which it is.
   */
  def hosts(): Array[String] = {
    localProperties
      .getProperty("hosts", "")
      .split(",")
      .map(_.trim())
      // "".split(",") yields Array(""), so drop empty entries instead of returning them.
      .filter(_.nonEmpty)
  }

  /**
   * Wait to sync all the barrier tasks in the same taskSet.
   *
   * Blocks on a synchronous ask to the coordinator. The local epoch is only advanced when
   * that ask returns normally; if it throws, the epoch is left unchanged.
   */
  def barrier(): Unit = synchronized {
    barrierCoordinator.askSync[Unit](RequestToSync(epoch))
    epoch += 1
  }
}
3 changes: 2 additions & 1 deletion core/src/main/scala/org/apache/spark/deploy/Client.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,9 @@ private[spark] class ClientApp extends SparkApplication {
override def start(args: Array[String], conf: SparkConf): Unit = {
val driverArgs = new ClientArguments(args)

// TODO remove this hack
if (!conf.contains("spark.rpc.askTimeout")) {
conf.set("spark.rpc.askTimeout", "10s")
conf.set("spark.rpc.askTimeout", "900s")
}
Logger.getRootLogger.setLevel(driverArgs.logLevel)

Expand Down
11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ import org.apache.spark._
import org.apache.spark.Partitioner._
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.barrier.BarrierRDD
import org.apache.spark.internal.Logging
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
Expand All @@ -43,8 +44,7 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils}

/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
Expand Down Expand Up @@ -1317,6 +1317,8 @@ abstract class RDD[T: ClassTag](
}
}

// Wraps this RDD in a BarrierRDD; the new RDD is created inside withScope.
def barrier(): BarrierRDD[T] = withScope {
  new BarrierRDD[T](this)
}

/**
* Take the first num elements of the RDD. It works by first scanning one partition, and use the
* results from that partition to estimate the number of additional partitions needed to satisfy
Expand Down Expand Up @@ -1839,6 +1841,11 @@ abstract class RDD[T: ClassTag](
// Exposes this RDD through the Java API wrapper, reusing this RDD's element ClassTag.
def toJavaRDD() : JavaRDD[T] = new JavaRDD[T](this)(elementClassTag)

/**
 * Whether this RDD is considered a barrier RDD. By default this holds when at least one of
 * the RDDs this one directly depends on reports itself as a barrier RDD; subclasses may
 * override the answer.
 */
def isBarrier(): Boolean = dependencies.exists { dependency => dependency.rdd.isBarrier() }
}


Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,4 +110,6 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
super.clearDependencies()
prev = null
}

// Always reports false instead of consulting the dependencies.
// NOTE(review): presumably because a shuffle starts a new stage, so the barrier property of
// the parent should not carry over - confirm.
override def isBarrier(): Boolean = false
}
43 changes: 41 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ class DAGScheduler(
stage.pendingPartitions += id
new ShuffleMapTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, properties, serializedTaskMetrics, Option(jobId),
Option(sc.applicationId), sc.applicationAttemptId)
Option(sc.applicationId), sc.applicationAttemptId, stage.rdd.isBarrier)
}

case stage: ResultStage =>
Expand All @@ -1072,7 +1072,8 @@ class DAGScheduler(
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptNumber,
taskBinary, part, locs, id, properties, serializedTaskMetrics,
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId)
Option(jobId), Option(sc.applicationId), sc.applicationAttemptId,
stage.rdd.isBarrier())
}
}
} catch {
Expand Down Expand Up @@ -1310,6 +1311,44 @@ class DAGScheduler(
}
}

case failure: TaskFailedReason if task.isBarrier =>
// Always fail the current stage and retry all the tasks when a barrier task fails.
val failedStage = stageIdToStage(task.stageId)
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
"due to a barrier task failed.")
val message = "Stage failed because a barrier task finished unsuccessfully. " +
s"${failure.toErrorString}"
try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
taskScheduler.cancelTasks(stageId, interruptThread = false)
} catch {
case e: UnsupportedOperationException =>
logInfo(s"Could not cancel tasks for stage $stageId", e)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Under barrier execution, will it be a problem if we can not cancel tasks?

}
markStageAsFinished(failedStage, Some(message))

failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
val shouldAbortStage =
failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
disallowStageRetryForTest

if (shouldAbortStage) {
val abortMessage = if (disallowStageRetryForTest) {
"Barrier stage will not retry stage due to testing config"
} else {
s"""$failedStage (${failedStage.name})
|has failed the maximum allowable number of
|times: $maxConsecutiveStageAttempts.
|Most recent failure reason: $message""".stripMargin.replaceAll("\n", " ")
}
abortStage(failedStage, abortMessage, None)
} else { // update failedStages and make sure a ResubmitFailedStages event is enqueued
failedStages += failedStage
logInfo(s"Resubmitting $failedStage (${failedStage.name}) due to barrier stage failure.")
messageScheduler.schedule(new Runnable {
override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
}, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}

case Resubmitted =>
logInfo("Resubmitted " + task + ", so marking it as still running")
stage match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ import org.apache.spark.rdd.RDD
* @param jobId id of the job this task belongs to
* @param appId id of the app this task belongs to
* @param appAttemptId attempt id of the app this task belongs to
*/
* @param isBarrier whether this task belongs to a barrier sync stage.
*/
private[spark] class ResultTask[T, U](
stageId: Int,
stageAttemptId: Int,
Expand All @@ -60,9 +61,10 @@ private[spark] class ResultTask[T, U](
serializedTaskMetrics: Array[Byte],
jobId: Option[Int] = None,
appId: Option[String] = None,
appAttemptId: Option[String] = None)
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
extends Task[U](stageId, stageAttemptId, partition.index, localProperties, serializedTaskMetrics,
jobId, appId, appAttemptId)
jobId, appId, appAttemptId, isBarrier)
with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import org.apache.spark.shuffle.ShuffleWriter
* @param jobId id of the job this task belongs to
* @param appId id of the app this task belongs to
* @param appAttemptId attempt id of the app this task belongs to
* @param isBarrier whether this task belongs to a barrier sync stage.
*/
private[spark] class ShuffleMapTask(
stageId: Int,
Expand All @@ -60,9 +61,10 @@ private[spark] class ShuffleMapTask(
serializedTaskMetrics: Array[Byte],
jobId: Option[Int] = None,
appId: Option[String] = None,
appAttemptId: Option[String] = None)
appAttemptId: Option[String] = None,
isBarrier: Boolean = false)
extends Task[MapStatus](stageId, stageAttemptId, partition.index, localProperties,
serializedTaskMetrics, jobId, appId, appAttemptId)
serializedTaskMetrics, jobId, appId, appAttemptId, isBarrier)
with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
Expand Down
Loading