diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index e4587c96eae1..9f6ab877ee98 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -42,3 +42,19 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
     prev = null
   }
 }
+
+/**
+ * An RDD that applies the provided function to every partition of the parent RDD, while
+ * reporting a caller-supplied set of preferred locations for each partition instead of
+ * inheriting the parent's locality preferences.
+ *
+ * @param p maps a partition index to the hosts preferred for that partition
+ */
+private[spark] final class PreserveLocationsRDD[U: ClassTag, T: ClassTag](
+    prev: RDD[T],
+    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
+    preservesPartitioning: Boolean = false, p: (Int) => Seq[String])
+  extends MapPartitionsRDD[U, T](prev, f, preservesPartitioning) {
+
+  override def getPreferredLocations(split: Partition): Seq[String] = p(split.index)
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 34d32aacfb62..5e5733be9a65 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -821,6 +821,26 @@ abstract class RDD[T: ClassTag](
       preservesPartitioning)
   }
 
+  /**
+   * Return a new RDD by applying a function to each partition of this RDD, while tracking the
+   * index of the original partition and overriding the preferred locations of each output
+   * partition with those returned by `p`.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner,
+   * which should be `false` unless this is a pair RDD and the input function doesn't modify
+   * the keys.
+   */
+  def mapPartitionsWithIndexPreserveLocations[U: ClassTag](
+      f: (Int, Iterator[T]) => Iterator[U],
+      p: (Int) => Seq[String],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    val cleanedF = sc.clean(f)
+    new PreserveLocationsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => cleanedF(index, iter),
+      preservesPartitioning, p)
+  }
+
   /**
    * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
    * second element in each RDD, etc. Assumes that the two RDDs have the *same number of