Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,15 @@ class DAGScheduler(
def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
// Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms =>
bms.map(bm => TaskLocation(bm.host, bm.executorId))
// Note: if the storage level is NONE, we don't need to get locations from block manager.
val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) {
Seq.fill(rdd.partitions.size)(Nil)
} else {
val blockIds =
rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
blockManagerMaster.getLocations(blockIds).map { bms =>
bms.map(bm => TaskLocation(bm.host, bm.executorId))
}
}
cacheLocs(rdd.id) = locs
}
Expand Down Expand Up @@ -386,7 +392,8 @@ class DAGScheduler(
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
if (getCacheLocs(rdd).contains(Nil)) {
val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
if (rddHasUncachedPartitions) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class DAGSchedulerSuite
}

test("cache location preferences w/ dependency") {
val baseRdd = new MyRDD(sc, 1, Nil)
val baseRdd = new MyRDD(sc, 1, Nil).cache()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify for other reviewers: I think we need these cache() calls because getCacheLocs now skips the cached-location lookup for RDDs whose storage level is NONE; without them, these other tests would fail.

val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd)))
cacheLocations(baseRdd.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
Expand All @@ -331,7 +331,7 @@ class DAGSchedulerSuite
}

test("regression test for getCacheLocs") {
val rdd = new MyRDD(sc, 3, Nil)
val rdd = new MyRDD(sc, 3, Nil).cache()
cacheLocations(rdd.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
cacheLocations(rdd.id -> 1) =
Expand All @@ -342,6 +342,33 @@ class DAGSchedulerSuite
assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD")))
}

/**
* This test ensures that if a particular RDD is cached, RDDs earlier in the dependency chain
* are not computed. It constructs the following chain of dependencies:
* +---+ shuffle +---+    +---+    +---+
* | A |<--------| B |<---| C |<---| D |
* +---+         +---+    +---+    +---+
* Here, B is derived from A by performing a shuffle, C has a one-to-one dependency on B,
* and D similarly has a one-to-one dependency on C. If none of the RDDs were cached, this
* set of RDDs would result in a two stage job: one ShuffleMapStage, and a ResultStage that
* reads the shuffled data from RDD A. This test ensures that if C is cached, the scheduler
* doesn't perform a shuffle, and instead computes the result using a single ResultStage
* that reads C's cached data.
*/
test("getMissingParentStages should consider all ancestor RDDs' cache statuses") {
  // Lineage under test: A <-(shuffle)- B <- C <- D, where only C is marked as cached.
  val sourceRdd = new MyRDD(sc, 1, Nil)
  val shuffleOnSource = new ShuffleDependency(sourceRdd, null)
  val shuffledRdd = new MyRDD(sc, 1, List(shuffleOnSource))
  val cachedRdd = new MyRDD(sc, 1, List(new OneToOneDependency(shuffledRdd))).cache()
  val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(cachedRdd)))

  // Register locations for the single partition of the cached RDD.
  val cachedPartitionHosts = Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
  cacheLocations(cachedRdd.id -> 0) = cachedPartitionHosts

  submit(finalRdd, Array(0))

  val running = scheduler.runningStages
  assert(running.size === 1)
  // The only running stage must be the final result stage: since C is cached, the
  // shuffle map stage that would compute A does not need to be run.
  assert(running.head.isInstanceOf[ResultStage])
}

test("avoid exponential blowup when getting preferred locs list") {
// Build up a complex dependency graph with repeated zip operations, without preferred locations
var rdd: RDD[_] = new MyRDD(sc, 1, Nil)
Expand Down Expand Up @@ -678,9 +705,9 @@ class DAGSchedulerSuite
}

test("cached post-shuffle") {
val shuffleOneRdd = new MyRDD(sc, 2, Nil)
val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache()
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne))
val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache()
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
submit(finalRdd, Array(0))
Expand Down