@@ -23,6 +23,7 @@ import java.io.EOFException
 
 import scala.collection.immutable.Map
 import scala.reflect.ClassTag
+import scala.collection.mutable.ListBuffer
 
 import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.mapred.FileSplit
@@ -43,6 +44,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.executor.{DataReadMethod, InputMetrics}
 import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
 import org.apache.spark.util.{NextIterator, Utils}
+import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
 
 
 /**
@@ -249,9 +251,24 @@ class HadoopRDD[K, V](
   }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
-    // TODO: Filtering out "localhost" in case of file:// URLs
-    val hadoopSplit = split.asInstanceOf[HadoopPartition]
-    hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost")
+    val hsplit = split.asInstanceOf[HadoopPartition].inputSplit.value
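+    // Newer Hadoop versions expose memory-cached replica info through
+    // InputSplitWithLocationInfo. It is looked up reflectively so Spark still
+    // runs against older Hadoop releases; on any failure, fall back below.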
+    val locs: Option[Seq[String]] = HadoopRDD.SPLIT_INFO_REFLECTIONS match {
+      case Some(c) =>
+        try {
+          val lsplit = c.inputSplitWithLocationInfo.cast(hsplit)
+          val infos = c.getLocationInfo.invoke(lsplit).asInstanceOf[Array[AnyRef]]
+          Some(HadoopRDD.convertSplitLocationInfo(infos))
+        } catch {
+          case e: Exception =>
+            logDebug("Failed to use InputSplitWithLocations.", e)
+            None
+        }
+      case None => None
+    }
+    locs.getOrElse(hsplit.getLocations.filter(_ != "localhost"))
   }
 
   override def checkpoint() {
@@ -261,7 +278,7 @@ class HadoopRDD[K, V](
   def getConf: Configuration = getJobConf()
 }
 
-private[spark] object HadoopRDD {
+private[spark] object HadoopRDD extends Logging {
   /** Constructing Configuration objects is not threadsafe, use this lock to serialize. */
   val CONFIGURATION_INSTANTIATION_LOCK = new Object()
 
@@ -309,4 +326,48 @@ private[spark] object HadoopRDD {
       f(inputSplit, firstParent[T].iterator(split, context))
     }
   }
+
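+  // One-time reflective handles to the newer Hadoop location-info API; constructing
+  // this class throws if those classes or methods are missing from the classpath.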
+  private[spark] class SplitInfoReflections {
+    val inputSplitWithLocationInfo =
+      Class.forName("org.apache.hadoop.mapred.InputSplitWithLocationInfo")
+    val getLocationInfo = inputSplitWithLocationInfo.getMethod("getLocationInfo")
+    val newInputSplit = Class.forName("org.apache.hadoop.mapreduce.InputSplit")
+    val newGetLocationInfo = newInputSplit.getMethod("getLocationInfo")
+    val splitLocationInfo = Class.forName("org.apache.hadoop.mapred.SplitLocationInfo")
+    val isInMemory = splitLocationInfo.getMethod("isInMemory")
+    val getLocation = splitLocationInfo.getMethod("getLocation")
+  }
+
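+  // None when the Hadoop version on the classpath predates the location-info API,
+  // in which case callers fall back to the plain InputSplit.getLocations path.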
+  private[spark] val SPLIT_INFO_REFLECTIONS: Option[SplitInfoReflections] = try {
+    Some(new SplitInfoReflections)
+  } catch {
+    case e: Exception =>
+      logDebug("SplitLocationInfo and other new Hadoop classes are " +
+        "unavailable. Using the older Hadoop location info code.", e)
+      None
+  }
+
+  private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = {
+    val out = ListBuffer[String]()
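+    // Convert each Hadoop SplitLocationInfo into a Spark task-location string,
+    // tagging hosts that hold the block in HDFS cache memory as HDFSCacheTaskLocation.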
+    infos.foreach { loc =>
+      val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get
+        .getLocation.invoke(loc).asInstanceOf[String]
+      if (locationStr != "localhost") {
+        if (HadoopRDD.SPLIT_INFO_REFLECTIONS.get.isInMemory
+            .invoke(loc).asInstanceOf[Boolean]) {
+          logDebug("Partition " + locationStr + " is cached by Hadoop.")
+          out += new HDFSCacheTaskLocation(locationStr).toString
+        } else {
+          out += new HostTaskLocation(locationStr).toString
+        }
+      }
+    }
+    out.seq
+  }
 }