Commit 6e27cb6

Colin Patrick Mccabe authored and pwendell committed
SPARK-1767: Prefer HDFS-cached replicas when scheduling data-local tasks
This change reorders the replicas returned by HadoopRDD#getPreferredLocations so that replicas cached by HDFS are at the start of the list. This requires Hadoop 2.5 or higher; previous versions of Hadoop do not expose the information needed to determine whether a replica is cached.

Author: Colin Patrick Mccabe <[email protected]>

Closes #1486 from cmccabe/SPARK-1767 and squashes the following commits:

338d4f8 [Colin Patrick Mccabe] SPARK-1767: Prefer HDFS-cached replicas when scheduling data-local tasks
1 parent bbdf1de commit 6e27cb6
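
For context, a minimal sketch (not part of this commit) of how the reordering surfaces through the public RDD API. The HDFS path and application name are hypothetical, and the "hdfs_cache_" tags only appear when running against Hadoop 2.5+:

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat
    import org.apache.spark.{SparkConf, SparkContext}

    object PreferredLocationsDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("pref-locs").setMaster("local"))
        // hadoopFile returns a HadoopRDD, whose getPreferredLocations this commit reorders.
        val rdd = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs:///tmp/example.txt")
        rdd.partitions.foreach { p =>
          // Cached replicas are listed first as "hdfs_cache_<host>"; plain hostnames follow.
          println(p.index + " -> " + rdd.preferredLocations(p).mkString(", "))
        }
        sc.stop()
      }
    }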

File tree

8 files changed (+162, -17 lines):

  core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
  core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
  core/src/main/scala/org/apache/spark/rdd/RDD.scala
  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
  core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
  core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
  core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
  project/MimaExcludes.scala

core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala

Lines changed: 56 additions & 4 deletions
@@ -23,6 +23,7 @@ import java.io.EOFException
 
 import scala.collection.immutable.Map
 import scala.reflect.ClassTag
+import scala.collection.mutable.ListBuffer
 
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.mapred.FileSplit
@@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.{DataReadMethod, InputMetrics}
 import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
 import org.apache.spark.util.{NextIterator, Utils}
+import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
 
 
 /**
@@ -249,9 +251,21 @@ class HadoopRDD[K, V](
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    // TODO: Filtering out "localhost" in case of file:// URLs
-    val hadoopSplit = split.asInstanceOf[HadoopPartition]
-    hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+    val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
+    val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+      case Some(c) =>
+        try {
+          val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
+          val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
+          Some(HadoopRDD.convertSplitLocationInfo(infos))
+        } catch {
+          case e: Exception =>
+            logDebug("Failed to use InputSplitWithLocations.", e)
+            None
+        }
+      case None => None
+    }
+    locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
   }
 
   override def checkpoint() {
@@ -261,7 +275,7 @@ class HadoopRDD[K, V](
   def getConf: Configuration = getJobConf()
 }
 
-private[spark] object HadoopRDD {
+private[spark] object HadoopRDD extends Logging {
   /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
   val CONFIGURATION_INSTANTIATION_LOCK = new Object()
 
@@ -309,4 +323,42 @@ private[spark] object HadoopRDD {
       f(inputSplit, firstParent[T].iterator(split, context))
     }
   }
+
+  private[spark] class SplitInfoReflections {
+    val inputSplitWithLocationInfo =
+      Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+    val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
+    val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+    val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
+    val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+    val isInMemory = splitLocationInfo.getMethod("isInMemory")
+    val getLocation = splitLocationInfo.getMethod("getLocation")
+  }
+
+  private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
+    Some(new SplitInfoReflections)
+  } catch {
+    case e: Exception =>
+      logDebug("SplitLocationInfo and other new Hadoop classes are " +
+        "unavailable. Using the older Hadoop location info code.", e)
+      None
+  }
+
+  private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
+    val out = ListBuffer[String]()
+    infos.foreach { loc => {
+      val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get.
+        getLocation.invoke(loc).asInstanceOf[String]
+      if (locationStr != "localhost") {
+        if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory.
+            invoke(loc).asInstanceOf[Boolean]) {
+          logDebug("Partition " + locationStr + " is cached by Hadoop.")
+          out += new HDFSCacheTaskLocation(locationStr).toString
+        } else {
+          out += new HostTaskLocation(locationStr).toString
+        }
+      }
+    }}
+    out.seq
+  }
 }
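
The SplitInfoReflections block above is what lets one Spark build run against both old and new Hadoop. Below is a standalone sketch (not from the patch) of that probe-once-then-fall-back reflection pattern, reusing the class and method names the patch looks up; the wrapper object and its method are hypothetical:

    import java.lang.reflect.Method

    object SplitLocationProbe {
      // Resolve the optional Hadoop 2.5+ API once; None means "old Hadoop, use the fallback path".
      private val getLocationInfo: Option[Method] =
        try {
          val cls = Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
          Some(cls.getMethod("getLocationInfo"))
        } catch {
          case _: ClassNotFoundException | _: NoSuchMethodException => None
        }

      // Location info for a split when the API exists, otherwise None.
      def locationInfo(split: AnyRef): Option[Array[AnyRef]] =
        getLocationInfo.flatMap { m =>
          try Some(m.invoke(split).asInstanceOf[Array[AnyRef]])
          catch { case _: Exception => None }
        }
    }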

core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala

Lines changed: 15 additions & 3 deletions
@@ -173,9 +173,21 @@ class NewHadoopRDD[K, V](
     new NewHadoopMapPartitionsWithSplitRDD(this, f, preservesPartitioning)
   }
 
-  override def getPreferredLocations(split: Partition): Seq[String] = {
-    val theSplit = split.asInstanceOf[NewHadoopPartition]
-    theSplit.serializableHadoopSplit.value.getLocations.filter(_ != "localhost")
+  override def getPreferredLocations(hsplit: Partition): Seq[String] = {
+    val split = hsplit.asInstanceOf[NewHadoopPartition].serializableHadoopSplit.value
+    val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+      case Some(c) =>
+        try {
+          val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]]
+          Some(HadoopRDD.convertSplitLocationInfo(infos))
+        } catch {
+          case e : Exception =>
+            logDebug("Failed to use InputSplit#getLocationInfo.", e)
+            None
+        }
+      case None => None
+    }
+    locs.getOrElse(split.getLocations.filter(_ != "localhost"))
   }
 
   def getConf: Configuration = confBroadcast.value.value

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
@@ -208,7 +208,7 @@ abstract class RDD[T: ClassTag](
   }
 
   /**
-   * Get the preferred locations of a partition (as hostnames), taking into account whether the
+   * Get the preferred locations of a partition, taking into account whether the
    * RDD is checkpointed.
    */
   final def preferredLocations(split: Partition): Seq[String] = {

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -1303,7 +1303,7 @@ class DAGScheduler(
     // If the RDD has some placement preferences (as is the case for input RDDs), get those
     val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
     if (!rddPrefs.isEmpty) {
-      return rddPrefs.map(host => TaskLocation(host))
+      return rddPrefs.map(TaskLocation(_))
     }
     // If the RDD has narrow dependencies, pick the first partition of the first narrow dep
     // that has any placement preferences. Ideally we would choose based on transfer sizes,

core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala

Lines changed: 43 additions & 5 deletions
@@ -22,13 +22,51 @@ package org.apache.spark.scheduler
  * In the latter case, we will prefer to launch the task on that executorID, but our next level
  * of preference will be executors on the same host if this is not possible.
  */
-private[spark]
-class TaskLocation private (val host: String, val executorId: Option[String]) extends Serializable {
-  override def toString: String = "TaskLocation(" + host + ", " + executorId + ")"
+private[spark] sealed trait TaskLocation {
+  def host: String
+}
+
+/**
+ * A location that includes both a host and an executor id on that host.
+ */
+private [spark] case class ExecutorCacheTaskLocation(override val host: String,
+    val executorId: String) extends TaskLocation {
+}
+
+/**
+ * A location on a host.
+ */
+private [spark] case class HostTaskLocation(override val host: String) extends TaskLocation {
+  override def toString = host
+}
+
+/**
+ * A location on a host that is cached by HDFS.
+ */
+private [spark] case class HDFSCacheTaskLocation(override val host: String)
+    extends TaskLocation {
+  override def toString = TaskLocation.inMemoryLocationTag + host
 }
 
 private[spark] object TaskLocation {
-  def apply(host: String, executorId: String) = new TaskLocation(host, Some(executorId))
+  // We identify hosts on which the block is cached with this prefix. Because this prefix contains
+  // underscores, which are not legal characters in hostnames, there should be no potential for
+  // confusion. See RFC 952 and RFC 1123 for information about the format of hostnames.
+  val inMemoryLocationTag = "hdfs_cache_"
+
+  def apply(host: String, executorId: String) = new ExecutorCacheTaskLocation(host, executorId)
 
-  def apply(host: String) = new TaskLocation(host, None)
+  /**
+   * Create a TaskLocation from a string returned by getPreferredLocations.
+   * These strings have the form [hostname] or hdfs_cache_[hostname], depending on whether the
+   * location is cached.
+   */
+  def apply(str: String) = {
+    val hstr = str.stripPrefix(inMemoryLocationTag)
+    if (hstr.equals(str)) {
+      new HostTaskLocation(str)
+    } else {
+      new HostTaskLocation(hstr)
+    }
+  }
 }
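
A short usage sketch of the new location-string contract (not part of the patch; because these classes are private[spark], it assumes compilation inside the org.apache.spark.scheduler package):

    package org.apache.spark.scheduler

    object TaskLocationRoundTrip {
      def main(args: Array[String]): Unit = {
        // Encoding: an HDFS-cached replica on host1 is reported as a tagged string.
        val encoded = HDFSCacheTaskLocation("host1").toString
        assert(encoded == "hdfs_cache_host1")

        // Decoding: TaskLocation(...) strips the tag, so either form resolves to host1.
        assert(TaskLocation(encoded).host == "host1")
        assert(TaskLocation("host1").host == "host1")
      }
    }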

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 22 additions & 3 deletions
@@ -181,8 +181,24 @@ private[spark] class TaskSetManager(
     }
 
     for (loc <- tasks(index).preferredLocations) {
-      for (execId <- loc.executorId) {
-        addTo(pendingTasksForExecutor.getOrElseUpdate(execId, new ArrayBuffer))
+      loc match {
+        case e: ExecutorCacheTaskLocation =>
+          addTo(pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer))
+        case e: HDFSCacheTaskLocation => {
+          val exe = sched.getExecutorsAliveOnHost(loc.host)
+          exe match {
+            case Some(set) => {
+              for (e <- set) {
+                addTo(pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer))
+              }
+              logInfo(s"Pending task $index has a cached location at ${e.host} " +
+                ", where there are executors " + set.mkString(","))
+            }
+            case None => logDebug(s"Pending task $index has a cached location at ${e.host} " +
+              ", but there are no executors alive there.")
+          }
+        }
+        case _ => Unit
       }
       addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
       for (rack <- sched.getRackForHost(loc.host)) {
@@ -283,7 +299,10 @@ private[spark] class TaskSetManager(
       // on multiple nodes when we replicate cached blocks, as in Spark Streaming
      for (index <- speculatableTasks if canRunOnHost(index)) {
        val prefs = tasks(index).preferredLocations
-       val executors = prefs.flatMap(_.executorId)
+       val executors = prefs.flatMap(_ match {
+         case e: ExecutorCacheTaskLocation => Some(e.executorId)
+         case _ => None
+       });
        if (executors.contains(execId)) {
          speculatableTasks -= index
          return Some((index, TaskLocality.PROCESS_LOCAL))
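
To make the scheduling effect concrete, here is a standalone sketch (not from the patch, using plain collections and a hypothetical host-to-executors map) of the expansion performed above: a task preferring an HDFS-cached host is also queued per executor for every live executor on that host, which is what allows a PROCESS_LOCAL launch:

    object CachedLocationExpansion {
      // Executors on which a task preferring this cached host can run PROCESS_LOCAL.
      def executorsForCachedHost(
          host: String,
          executorsAliveOnHost: Map[String, Set[String]]): Set[String] =
        executorsAliveOnHost.getOrElse(host, Set.empty)

      def main(args: Array[String]): Unit = {
        val alive = Map("host3" -> Set("execC1", "execC2"))
        println(executorsForCachedHost("host3", alive)) // Set(execC1, execC2)
        // A cached host with no live executors falls back to host- and rack-level lists only.
        println(executorsForCachedHost("host9", alive)) // Set()
      }
    }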

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 22 additions & 0 deletions
@@ -642,6 +642,28 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.resourceOffer("execC", "host3", ANY) !== None)
   }
 
+  test("Test that locations with HDFSCacheTaskLocation are treated as PROCESS_LOCAL.") {
+    // Regression test for SPARK-2931
+    sc = new SparkContext("local", "test")
+    val sched = new FakeTaskScheduler(sc,
+      ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
+    val taskSet = FakeTask.createTaskSet(3,
+      Seq(HostTaskLocation("host1")),
+      Seq(HostTaskLocation("host2")),
+      Seq(HDFSCacheTaskLocation("host3")))
+    val clock = new FakeClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+    sched.removeExecutor("execA")
+    manager.executorAdded()
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+    sched.removeExecutor("execB")
+    manager.executorAdded()
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, ANY)))
+    sched.removeExecutor("execC")
+    manager.executorAdded()
+    assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+  }
 
   def createTaskResult(id: Int): DirectTaskResult[Int] = {
     val valueSer = SparkEnv.get.serializer.newInstance()

project/MimaExcludes.scala

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ object MimaExcludes {
             MimaBuild.excludeSparkClass("mllib.linalg.Matrix") ++
             MimaBuild.excludeSparkClass("mllib.linalg.Vector") ++
             Seq(
+              ProblemFilters.exclude[IncompatibleTemplateDefProblem](
+                "org.apache.spark.scheduler.TaskLocation"),
               // Added normL1 and normL2 to trait MultivariateStatisticalSummary
               ProblemFilters.exclude[MissingMethodProblem](
                 "org.apache.spark.mllib.stat.MultivariateStatisticalSummary.normL1"),
