OutputCommitCoordinator.scala
@@ -48,25 +48,28 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   private type StageId = Int
   private type PartitionId = Int
   private type TaskAttemptNumber = Int
+  private case class StageState(authorizedCommitters: Array[TaskAttemptNumber],
+      failures: mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]])
 
   private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
 
   /**
-   * Map from active stages's id => partition id => task attempt with exclusive lock on committing
-   * output for that partition.
+   * Map from active stage's id => authorized task attempts for each partition id, which hold an
+   * exclusive lock on committing task output for that partition, as well as any known failed
+   * attempts in the stage.
    *
    * Entries are added to the top-level map when stages start and are removed when they finish
    * (either successfully or unsuccessfully).
    *
    * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
    */
-  private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]()
+  private val stageStates = mutable.Map[StageId, StageState]()
 
   /**
    * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
    */
   def isEmpty: Boolean = {
-    authorizedCommittersByStage.isEmpty
+    stageStates.isEmpty
   }
 
   /**
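The hunk above bundles each stage's committer locks together with the attempts known to have failed, so the coordinator can later refuse the commit lock to an attempt that has already failed. A minimal standalone sketch of that bookkeeping, reusing the patch's names; everything outside those names is illustrative, not the Spark source:

import scala.collection.mutable

object StageStateSketch {
  type PartitionId = Int
  type TaskAttemptNumber = Int
  val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1

  // One commit-lock slot per partition, plus the attempts that already failed.
  case class StageState(
      authorizedCommitters: Array[TaskAttemptNumber],
      failures: mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]])

  def main(args: Array[String]): Unit = {
    val state = StageState(Array.fill(2)(NO_AUTHORIZED_COMMITTER), mutable.Map.empty)
    // Record that attempt 0 of partition 1 failed.
    state.failures.getOrElseUpdate(1, mutable.Set.empty) += 0
    println(state.failures(1).contains(0)) // true: attempt 0 is blacklisted
  }
}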
@@ -110,14 +113,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
       maxPartitionId: Int): Unit = {
     val arr = new Array[TaskAttemptNumber](maxPartitionId + 1)
     java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER)
+    val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]()
     synchronized {
-      authorizedCommittersByStage(stage) = arr
+      stageStates(stage) = new StageState(arr, failures)
     }
   }
 
   // Called by DAGScheduler
   private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
-    authorizedCommittersByStage.remove(stage)
+    stageStates.remove(stage)
   }
 
   // Called by DAGScheduler
@@ -126,7 +130,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
       partition: PartitionId,
       attemptNumber: TaskAttemptNumber,
       reason: TaskEndReason): Unit = synchronized {
-    val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
+    val stageState = stageStates.getOrElse(stage, {
       logDebug(s"Ignoring task completion for completed stage")
       return
     })
@@ -137,10 +141,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
         logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
           s"attempt: $attemptNumber")
       case otherReason =>
-        if (authorizedCommitters(partition) == attemptNumber) {
+        // Mark the attempt as failed, blacklisting it from future commit authorization.
+        stageState.failures.get(partition) match {
+          case Some(failures) => failures += attemptNumber
+          case None => stageState.failures(partition) = mutable.Set(attemptNumber)
+        }
+        if (stageState.authorizedCommitters(partition) == attemptNumber) {
           logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
             s"partition=$partition) failed; clearing lock")
-          authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
+          stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
         }
     }
   }
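One design note on the failure-recording match above: pattern matching on `stageState.failures.get(partition)` creates the per-partition set on first failure. `mutable.Map.getOrElseUpdate` expresses the same update in one line; a possible simplification, not what the patch uses:

// Equivalent one-line form of the Some/None match above.
stageState.failures.getOrElseUpdate(partition, mutable.Set.empty) += attemptNumber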
@@ -149,7 +158,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
     if (isDriver) {
       coordinatorRef.foreach(_ send StopCoordinator)
       coordinatorRef = None
-      authorizedCommittersByStage.clear()
+      stageStates.clear()
     }
   }
 
@@ -158,13 +167,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
       stage: StageId,
       partition: PartitionId,
       attemptNumber: TaskAttemptNumber): Boolean = synchronized {
-    authorizedCommittersByStage.get(stage) match {
-      case Some(authorizedCommitters) =>
-        authorizedCommitters(partition) match {
+    stageStates.get(stage) match {
+      case Some(state) if attemptFailed(stage, partition, attemptNumber) =>
+        logWarning(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," +
+          s" partition=$partition as task attempt $attemptNumber has already failed.")
+        false
+      case Some(state) =>
+        state.authorizedCommitters(partition) match {
           case NO_AUTHORIZED_COMMITTER =>
             logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
               s"partition=$partition")
-            authorizedCommitters(partition) = attemptNumber
+            state.authorizedCommitters(partition) = attemptNumber
             true
           case existingCommitter if existingCommitter == attemptNumber =>
             logWarning(s"Authorizing duplicate request to commit for " +
@@ -186,11 +199,22 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
             }
         }
       case None =>
-        logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" +
-          s"partition $partition to commit")
+        logDebug(s"Stage $stage has completed, so not allowing" +
+          s" attempt number $attemptNumber of partition $partition to commit")
        false
     }
   }
+
+  private def attemptFailed(
+      stage: StageId,
+      partition: PartitionId,
+      attempt: TaskAttemptNumber): Boolean = synchronized {
+    stageStates.get(stage) match {
+      case Some(state) =>
+        state.failures.get(partition).exists(_.contains(attempt))
+      case None => false
+    }
+  }
 }
 
 private[spark] object OutputCommitCoordinator {
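Taken together, the patched canCommit applies three rules in order: deny any attempt recorded as failed, hand an unclaimed partition lock to the first eligible requester, and re-authorize only the current lock holder. A condensed standalone model of that decision order, with logging, stage lifecycle, and the RPC layer elided (an illustrative sketch, not the Spark source):

import scala.collection.mutable

object CanCommitSketch {
  val NO_AUTHORIZED_COMMITTER = -1
  val authorizedCommitters = Array.fill(2)(NO_AUTHORIZED_COMMITTER)
  val failures = mutable.Map[Int, mutable.Set[Int]]()

  def canCommit(partition: Int, attempt: Int): Boolean = synchronized {
    if (failures.get(partition).exists(_.contains(attempt))) {
      false // rule 1: a known-failed attempt is never authorized
    } else if (authorizedCommitters(partition) == NO_AUTHORIZED_COMMITTER) {
      authorizedCommitters(partition) = attempt
      true // rule 2: the first eligible requester takes the lock
    } else {
      authorizedCommitters(partition) == attempt // rule 3: idempotent for the holder
    }
  }

  def main(args: Array[String]): Unit = {
    failures.getOrElseUpdate(0, mutable.Set.empty) += 0 // attempt 0 of partition 0 failed
    println(canCommit(0, 0)) // false: failed attempt denied
    println(canCommit(0, 1)) // true: attempt 1 wins the lock
    println(canCommit(0, 2)) // false: lock already held by attempt 1
  }
}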
OutputCommitCoordinatorSuite.scala
@@ -195,6 +195,17 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
     sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _,
       0 until rdd.partitions.size)
   }
+
+  test("Do not allow failed attempts to be authorized for committing") {
+    val stage: Int = 1
+    val partition: Int = 1
+    val failedAttempt: Int = 0
+    outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
+    outputCommitCoordinator.taskCompleted(
+      stage, partition, attemptNumber = failedAttempt, reason = TaskKilled)
+    assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt))
+    assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1))
+  }
 }
 
 /**
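A note on why the new test works: TaskKilled is neither Success nor TaskCommitDenied, so inside taskCompleted it falls through to the otherReason arm (the TaskCommitDenied arm is the one logging "Task was denied committing" in the hunk above), which records attempt 0 in the stage's failure set. That is why the first canCommit call is denied while attempt failedAttempt + 1 is still granted the lock.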