Skip to content
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 11 additions & 4 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,15 @@ class DAGScheduler(
def getCacheLocs(rdd: RDD[_]): Seq[Seq[TaskLocation]] = cacheLocs.synchronized {
// Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs: Seq[Seq[TaskLocation]] = blockManagerMaster.getLocations(blockIds).map { bms =>
bms.map(bm => TaskLocation(bm.host, bm.executorId))
// Note: if the storage level is NONE, we don't need to get locations from block manager.
val locs: Seq[Seq[TaskLocation]] = if (rdd.getStorageLevel == StorageLevel.NONE) {
Seq.fill(rdd.partitions.size)(Nil)
} else {
val blockIds =
rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
blockManagerMaster.getLocations(blockIds).map { bms =>
bms.map(bm => TaskLocation(bm.host, bm.executorId))
}
}
cacheLocs(rdd.id) = locs
}
Expand Down Expand Up @@ -386,7 +392,8 @@ class DAGScheduler(
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
if (getCacheLocs(rdd).contains(Nil)) {
val rddHasUncachedPartitions = getCacheLocs(rdd).contains(Nil)
if (rddHasUncachedPartitions) {
for (dep <- rdd.dependencies) {
dep match {
case shufDep: ShuffleDependency[_, _, _] =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ class DAGSchedulerSuite
}

test("cache location preferences w/ dependency") {
val baseRdd = new MyRDD(sc, 1, Nil)
val baseRdd = new MyRDD(sc, 1, Nil).cache()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To clarify for other reviewers: I think we need these cache() calls so that these other tests don't fail now that the cache-location lookup is skipped for RDDs whose storage level is NONE.

val finalRdd = new MyRDD(sc, 1, List(new OneToOneDependency(baseRdd)))
cacheLocations(baseRdd.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
Expand All @@ -331,7 +331,7 @@ class DAGSchedulerSuite
}

test("regression test for getCacheLocs") {
val rdd = new MyRDD(sc, 3, Nil)
val rdd = new MyRDD(sc, 3, Nil).cache()
cacheLocations(rdd.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
cacheLocations(rdd.id -> 1) =
Expand All @@ -342,6 +342,29 @@ class DAGSchedulerSuite
assert(locs === Seq(Seq("hostA", "hostB"), Seq("hostB", "hostC"), Seq("hostC", "hostD")))
}

/**
* +---+ shuffle +---+ +---+ +---+
* | A |<--------| B |<---| C |<---| D |
* +---+ +---+ +---+ +---+
* Here, D has a one-to-one dependency on C. C is derived from A by performing a shuffle
* and then a map. If we're trying to determine which ancestor stages need to be computed in
* order to compute D, we need to figure out whether the shuffle A -> B should be performed.
* If the RDD C, which has only one ancestor via a narrow dependency, is cached, then we won't
* need to compute A, even if A has some unavailable output partitions. The same goes for B:
* if B is 100% cached, then we can avoid the shuffle on A.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Josh's comment was an awesome description of how the dependencies should be computed, but isn't quite appropriate here as the comment for the test. What about something like:

This test ensures that if a particular RDD is cached, RDDs earlier in the dependency chain are not computed. It constructs the following chain of dependencies:
+---+ shuffle +---+ +---+ +---+
| A |<--------| B |<---| C |<---| D |
+---+ +---+ +---+ +---+
Here, B is derived from A by performing a shuffle, C has a one-to-one dependency on B, and D similarly has a one-to-one dependency on C. If none of the RDDs were cached, this set of RDDs would result in a two-stage job: one ShuffleMapStage, and a ResultStage that reads the shuffled data from RDD A. This test ensures that if C is cached, the scheduler doesn't perform a shuffle, and instead computes the result using a single ResultStage that reads C's cached data.

*/
test("SPARK-7826: getMissingParentStages should consider all ancestor RDDs' cache statuses") {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you remove "SPARK-7826" from the name of this test, since the test isn't checking for the bug described by SPARK-7826? It's great to add this test in the PR -- but having the JIRA name in the test name is something we usually only do when the test is for the issue described by that JIRA.

val rddA = new MyRDD(sc, 1, Nil)
val rddB = new MyRDD(sc, 1, List(new ShuffleDependency(rddA, null)))
val rddC = new MyRDD(sc, 1, List(new OneToOneDependency(rddB))).cache()
val rddD = new MyRDD(sc, 1, List(new OneToOneDependency(rddC)))
cacheLocations(rddC.id -> 0) =
Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
submit(rddD, Array(0))
assert(scheduler.runningStages.size === 1)
assert(scheduler.runningStages.head.id === 1)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you actually change this to:
assert(scheduler.runningStages.head.isInstanceOf[ResultStage])?

And then add a comment saying something like "Make sure that the scheduler is running the final result stage. Because C is cached, the shuffle map stage to compute A does not need to be run."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I think this is more intuitive; otherwise, it's hard for someone looking at this to understand why the ID should be 1. This also makes the test more agnostic to unrelated scheduler internals, such as a change to the way we assign IDs to stages.)

}

test("avoid exponential blowup when getting preferred locs list") {
// Build up a complex dependency graph with repeated zip operations, without preferred locations
var rdd: RDD[_] = new MyRDD(sc, 1, Nil)
Expand Down Expand Up @@ -678,9 +701,9 @@ class DAGSchedulerSuite
}

test("cached post-shuffle") {
val shuffleOneRdd = new MyRDD(sc, 2, Nil)
val shuffleOneRdd = new MyRDD(sc, 2, Nil).cache()
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne))
val shuffleTwoRdd = new MyRDD(sc, 2, List(shuffleDepOne)).cache()
val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
val finalRdd = new MyRDD(sc, 1, List(shuffleDepTwo))
submit(finalRdd, Array(0))
Expand Down