Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 25 additions & 89 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,12 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer

import scala.collection.mutable.ArrayBuffer

import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
Expand Down Expand Up @@ -107,11 +102,13 @@ private[spark] class BarrierCoordinator(
// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

// An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]
// Messages from each barrier task that has made a blocking runBarrier() call.
// The messages will be replied to all tasks once sync finished.
private val messages = Array.ofDim[String](numTasks)

// The blocking requestMethod called by tasks to sync up for this stage attempt
private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER
// The request method which is called inside this barrier sync. All tasks should make sure
// that they're calling the same method within the same barrier sync phase.
private var requestMethod: RequestMethod.Value = _

// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null
Expand Down Expand Up @@ -140,28 +137,18 @@ private[spark] class BarrierCoordinator(

// Process the global sync request. The barrier() call succeeds if enough requests are
// collected within the configured time; otherwise all the pending requests are failed.
def handleRequest(
requester: RpcCallContext,
request: RequestToSync
): Unit = synchronized {
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
val taskId = request.taskAttemptId
val epoch = request.barrierEpoch
val requestMethod = request.requestMethod
val partitionId = request.partitionId
val allGatherMessage = request match {
case ag: AllGatherRequestToSync => ag.allGatherMessage
case _ => ""
}

if (requesters.size == 0) {
requestMethodToSync = requestMethod
}
val curReqMethod = request.requestMethod

if (requestMethodToSync != requestMethod) {
if (requesters.isEmpty) {
requestMethod = curReqMethod
} else if (requestMethod != curReqMethod) {
requesters.foreach(
_.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
s"the current synchronized requestMethod `$requestMethodToSync`"
s"`$curReqMethod` during barrier epoch $barrierEpoch, which does not match " +
s"the current synchronized requestMethod `$requestMethod`"
))
)
cleanupBarrierStage(barrierId)
Expand All @@ -186,10 +173,11 @@ private[spark] class BarrierCoordinator(
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
allGatherMessages(partitionId) = allGatherMessage
messages(request.partitionId) = request.message
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
if (requesters.size == numTasks) {
requesters.foreach(_.reply(messages))
// Finished current barrier() call successfully, clean up ContextBarrierState and
// increase the barrier epoch.
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " +
Expand All @@ -201,25 +189,6 @@ private[spark] class BarrierCoordinator(
}
}

// Reply to every pending barrier sync request of this stage attempt once the number of
// collected requesters equals numTasks; otherwise do nothing.
//
// Returns true if the replies were sent (requesters.size == numTasks), false otherwise.
private def maybeFinishAllRequesters(
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
// The reply payload depends on the request method recorded in requestMethodToSync.
requestMethodToSync match {
case RequestMethod.BARRIER =>
// BARRIER: every requester is answered with an empty string.
requesters.foreach(_.reply(""))
case RequestMethod.ALL_GATHER =>
// ALL_GATHER: every requester is answered with the same compact JSON rendering
// of allGatherMessages.
val json: String = compact(render(allGatherMessages))
requesters.foreach(_.reply(json))
}
true
} else {
false
}
}

// Cleanup the internal state of a barrier stage attempt.
def clear(): Unit = synchronized {
// The global sync fails so the stage is expected to retry another attempt, all sync
Expand All @@ -239,11 +208,11 @@ private[spark] class BarrierCoordinator(
}

override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case request: RequestToSync =>
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _, _, _, _) =>
// Get or init the ContextBarrierState corresponding to the stage attempt.
val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
val barrierId = ContextBarrierId(stageId, stageAttemptId)
states.computeIfAbsent(barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
val barrierState = states.get(barrierId)

barrierState.handleRequest(context, request)
Expand All @@ -256,61 +225,28 @@ private[spark] class BarrierCoordinator(

// Base type of the messages sent to the BarrierCoordinator endpoint. The trait is sealed, so
// every concrete message type is declared in this file.
private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
* Common shape of the global sync request messages sent from BarrierTaskContext to the
* BarrierCoordinator. Concrete request classes extend this trait and supply the fields below.
*/
private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
// The number of global sync requests the BarrierCoordinator shall receive
def numTasks: Int
// ID of current stage
def stageId: Int
// ID of current stage attempt
def stageAttemptId: Int
// Unique ID of current task
def taskAttemptId: Long
// ID of the barrier sync call; a task may make multiple such calls
def barrierEpoch: Int
// ID of the current partition the task is assigned to
def partitionId: Int
// The BarrierTaskContext method that was called to trigger BarrierCoordinator
def requestMethod: RequestMethod.Value
}

/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch. Unlike the `allGather()` request, it
* carries no message payload.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call; a task may contain multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
*/
private[spark] case class BarrierRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value
) extends RequestToSync

/**
* A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
* A global sync request message from BarrierTaskContext. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
* @param barrierEpoch ID of a runBarrier() call; a task may contain multiple runBarrier() calls
* @param partitionId ID of the current partition the task is assigned to
* @param message Message sent from the BarrierTaskContext
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
*/
private[spark] case class AllGatherRequestToSync(
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
) extends RequestToSync
message: String,
requestMethod: RequestMethod.Value) extends BarrierCoordinatorMessage

private[spark] object RequestMethod extends Enumeration {
val BARRIER, ALL_GATHER = Value
Expand Down
59 changes: 10 additions & 49 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,13 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -67,31 +60,7 @@ class BarrierTaskContext private[spark] (
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size

// Build the sync request message matching the given requestMethod:
//  - BARRIER    => BarrierRequestToSync (allGatherMessage is not used)
//  - ALL_GATHER => AllGatherRequestToSync, which additionally carries allGatherMessage
// All other arguments are passed through unchanged to the chosen message class.
private def getRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptNumber: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
): RequestToSync = {
requestMethod match {
case RequestMethod.BARRIER =>
BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, requestMethod)
case RequestMethod.ALL_GATHER =>
AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, requestMethod, allGatherMessage)
}
}

private def runBarrier(
requestMethod: RequestMethod.Value,
allGatherMessage: String = ""
): String = {

private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
Expand All @@ -108,24 +77,24 @@ class BarrierTaskContext private[spark] (
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)

var json: String = ""

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[String](
message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, message, requestMethod),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))

// messages which consist of all barrier tasks' messages
var messages: Array[String] = null
// Wait for the RPC future to complete, but stop waiting every 1 second to check
// whether the current Spark task has been killed. If it has, throw a
// `TaskKilledException`; otherwise keep waiting until the RPC completes.
try {
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
messages = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
Expand All @@ -144,6 +113,7 @@ class BarrierTaskContext private[spark] (
"global sync successfully, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
messages
} catch {
case e: SparkException =>
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
Expand All @@ -155,7 +125,6 @@ class BarrierTaskContext private[spark] (
timerTask.cancel()
timer.purge()
}
json
}

/**
Expand Down Expand Up @@ -200,10 +169,7 @@ class BarrierTaskContext private[spark] (
*/
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
runBarrier(RequestMethod.BARRIER)
()
}
def barrier(): Unit = runBarrier("", RequestMethod.BARRIER)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: extra line

/**
* :: Experimental ::
Expand All @@ -217,12 +183,7 @@ class BarrierTaskContext private[spark] (
*/
@Experimental
@Since("3.0.0")
def allGather(message: String): Array[String] = {
val json = runBarrier(RequestMethod.ALL_GATHER, message)
val jsonArray = parse(json)
implicit val formats = DefaultFormats
jsonArray.extract[Array[String]]
}
def allGather(message: String): Array[String] = runBarrier(message, RequestMethod.ALL_GATHER)

/**
* :: Experimental ::
Expand Down
17 changes: 5 additions & 12 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -414,22 +414,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
)
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
var result: String = ""
requestMethod match {
val messages = requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
context.asInstanceOf[BarrierTaskContext].barrier()
result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
val messages: Array[String] = context.asInstanceOf[BarrierTaskContext].allGather(
message
)
result = compact(render(JArray(
messages.map(
(message) => JString(message)
).toList
)))
context.asInstanceOf[BarrierTaskContext].allGather(message)
}
writeUTF(result, out)
out.writeInt(messages.length)
messages.foreach(writeUTF(_, out))
} catch {
case e: SparkException =>
writeUTF(e.getMessage, out)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.scheduler

import java.io.File

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.apache.spark._
Expand Down
15 changes: 6 additions & 9 deletions python/pyspark/taskcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import json

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
from pyspark.serializers import read_int, write_int, write_with_length, UTF8Deserializer


class TaskContext(object):
Expand Down Expand Up @@ -133,7 +133,10 @@ def _load_from_socket(port, auth_secret, function, all_gather_message=None):
sockfile.flush()

# Collect result.
res = UTF8Deserializer().loads(sockfile)
len = read_int(sockfile)
res = []
for i in range(len):
res.append(UTF8Deserializer().loads(sockfile))

# Release resources.
sockfile.close()
Expand Down Expand Up @@ -232,13 +235,7 @@ def allGather(self, message=""):
raise Exception("Not supported to call barrier() before initialize " +
"BarrierTaskContext.")
else:
gathered_items = _load_from_socket(
self._port,
self._secret,
ALL_GATHER_FUNCTION,
message,
)
return [e for e in json.loads(gathered_items)]
return _load_from_socket(self._port, self._secret, ALL_GATHER_FUNCTION, message)

def getTaskInfos(self):
"""
Expand Down