Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1057,8 +1057,12 @@ class DAGScheduler(
if (tasks.size > 0) {
logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " +
s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})")
def ordFunc(x: Task[_], y: Task[_]): Boolean = {
inputSizeFromShuffledRDD(stageIdToStage(x.stageId).rdd, x.partitionId) >
inputSizeFromShuffledRDD(stageIdToStage(y.stageId).rdd, y.partitionId)
}
taskScheduler.submitTasks(new TaskSet(
tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
tasks.sortWith(ordFunc).toArray, stage.id, stage.latestInfo.attemptId, jobId, properties))
stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should mark
Expand All @@ -1080,6 +1084,29 @@ class DAGScheduler(
}
}

/**
 * Estimates how many bytes of shuffle input partition `pId` of `rdd` will read.
 *
 * Walks the dependency graph of `rdd`: a [[ShuffleDependency]] contributes the
 * map-output bytes recorded for this partition id, while a [[NarrowDependency]]
 * is followed recursively into each parent partition and their sizes are summed.
 * A partition that is already cached (per `getCacheLocs`) is assumed to be served
 * from the block manager and contributes 0 bytes.
 *
 * NOTE(review): a shuffle's bytes are only counted when `rdd` has no partitioner
 * or shares the dependency's partitioner — presumably because otherwise this
 * RDD's partition ids do not line up with the map-output partition ids; confirm.
 *
 * @param rdd the RDD whose input size is being estimated
 * @param pId the partition index within `rdd`
 * @return estimated shuffle input size in bytes (0 if the partition is cached)
 */
private[scheduler] def inputSizeFromShuffledRDD(rdd: RDD[_], pId: Int): Long = {
  if (getCacheLocs(rdd)(pId) != Nil) {
    // Cached partition: it will not re-read its shuffle inputs.
    0L
  } else {
    var totalBytes = 0L
    rdd.dependencies.foreach {
      case dep: ShuffleDependency[_, _, _] =>
        // Only count the shuffle when partition ids line up with the map output:
        // either this RDD has no partitioner, or it uses the dependency's one.
        if (rdd.partitioner.isEmpty || rdd.partitioner == Some(dep.partitioner)) {
          totalBytes += mapOutputTracker.getStatistics(dep).bytesByPartitionId(pId)
        }
      case dep: NarrowDependency[_] =>
        // Narrow dependency: sum the sizes of every parent partition this
        // partition reads from, recursing through further narrow ancestry.
        dep.getParents(pId).foreach { parentSplit =>
          totalBytes += inputSizeFromShuffledRDD(dep.rdd, parentSplit)
        }
    }
    totalBytes
  }
}

/**
* Merge local values from a task into the corresponding accumulators previously registered
* here on the driver.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2277,6 +2277,43 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
(Success, 1)))
}

test("Tasks should be in descending order by input size from ShuffledRDD.") {
  // Two map-side RDDs, each shuffled with the same partitioner into rddC,
  // so each of rddC's two reduce partitions reads from both shuffles.
  val partitioner = new HashPartitioner(2)
  val rddA = new MyRDD(sc, 2, Nil)
  val shuffleDepA = new ShuffleDependency(rddA, partitioner)
  val rddB = new MyRDD(sc, 2, Nil)
  val shuffleDepB = new ShuffleDependency(rddB, partitioner)
  val rddC = new MyRDD(sc, 2, List(shuffleDepA, shuffleDepB), tracker = mapOutputTracker)
  submit(rddC, Array(0, 1))

  // MapStatus stores sizes compressed (lossily); round-trip the raw sizes so the
  // expected values below match exactly what the scheduler reads back.
  def compressAndDecompress(sizes: Array[Long]): Array[Long] = {
    sizes.map(size => MapStatus.decompressSize(MapStatus.compressSize(size)))
  }

  // Complete both map stages, making reduce partition 1 far larger than partition 0.
  assert(taskSets(0).stageId === 0 && taskSets(0).stageAttemptId === 0)
  complete(taskSets(0), Seq(
    (Success, MapStatus(makeBlockManagerId("hostA"), compressAndDecompress(Array(10, 1000)))),
    (Success, MapStatus(makeBlockManagerId("hostA"), compressAndDecompress(Array(100, 10000))))))
  assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0)
  complete(taskSets(1), Seq(
    (Success, MapStatus(makeBlockManagerId("hostB"), compressAndDecompress(Array(20, 2000)))),
    (Success, MapStatus(makeBlockManagerId("hostB"), compressAndDecompress(Array(200, 20000))))))

  assert(taskSets(2).stageId === 2 && taskSets(2).stageAttemptId === 0)

  // Tasks input size from shuffled RDD should be correct: the per-partition sum
  // of both shuffles' map-output sizes.
  assert(scheduler.inputSizeFromShuffledRDD(rddC, 0) ===
    compressAndDecompress(Array(10, 100, 20, 200)).sum)
  assert(scheduler.inputSizeFromShuffledRDD(rddC, 1) ===
    compressAndDecompress(Array(1000, 10000, 2000, 20000)).sum)

  // Tasks should be in descending order by input size from ShuffledRDD:
  // the larger partition (1) is scheduled before the smaller one (0).
  assert(taskSets(2).tasks(0).partitionId === 1 && taskSets(2).tasks(1).partitionId === 0)
}

/**
* Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
* Note that this checks only the host and not the executor ID.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
}
}


test("TaskSet with no preferences") {
sc = new SparkContext("local", "test")
sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
Expand Down