
Commit dcca830

Add API to support configuring preferred resource per rdd
1 parent 8141d55 commit dcca830

16 files changed, with 346 additions and 3 deletions

core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala

Lines changed: 8 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
 import scala.reflect.ClassTag
 
 import org.apache.spark._
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.util.Utils
 
 private[spark]
@@ -70,6 +71,13 @@ class CartesianRDD[T: ClassTag, U: ClassTag](
     (rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)).distinct
   }
 
+  private[spark] override def getPreferredResources(split: Partition): PreferredResources = {
+    val currSplit = split.asInstanceOf[CartesianPartition]
+    rdd1.getPreferredResources(currSplit.s1)
+      .mergeOther(rdd2.getPreferredResources(currSplit.s2))
+      .mergeOther(super.getPreferredResources(split))
+  }
+
   override def compute(split: Partition, context: TaskContext): Iterator[(T, U)] = {
     val currSplit = split.asInstanceOf[CartesianPartition]
     for (x <- rdd1.iterator(currSplit.s1, context);
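
The mergeOther calls above combine the resource preferences of both parent RDDs with any preference attached to this RDD itself. The PreferredResources class is added under org.apache.spark.rdd.resource elsewhere in this commit and does not appear on this page, so the following is only an assumed mental model of the merge; the map-of-amounts representation and the max-wins policy are illustrative guesses, not the committed API.

// Hypothetical sketch of PreferredResources, for intuition only.
// The real class ships in org.apache.spark.rdd.resource (not shown on this page).
case class PreferredResourcesSketch(amounts: Map[String, Int]) {
  // Assumed merge policy: union the resource names; where both sides
  // request the same resource, keep the larger amount.
  def mergeOther(other: PreferredResourcesSketch): PreferredResourcesSketch = {
    val names = amounts.keySet ++ other.amounts.keySet
    PreferredResourcesSketch(names.map { n =>
      n -> math.max(amounts.getOrElse(n, 0), other.amounts.getOrElse(n, 0))
    }.toMap)
  }
}

object PreferredResourcesSketch {
  val EMPTY: PreferredResourcesSketch = PreferredResourcesSketch(Map.empty)
}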

core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala

Lines changed: 10 additions & 0 deletions
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark._
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.util.Utils
 import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap}
@@ -126,6 +127,15 @@ class CoGroupedRDD[K: ClassTag](
     array
   }
 
+  private[spark] override def getPreferredResources(split: Partition): PreferredResources = {
+    split.asInstanceOf[CoGroupPartition].narrowDeps.map {
+      case Some(s) => s.rdd.getPreferredResources(s.split)
+      case None => PreferredResources.EMPTY
+    }
+      .reduce { (s1, s2) => s1.mergeOther(s2) }
+      .mergeOther(super.getPreferredResources(split))
+  }
+
   override val partitioner: Some[Partitioner] = Some(part)
 
   override def compute(s: Partition, context: TaskContext): Iterator[(K, Array[Iterable[_]])] = {

core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala

Lines changed: 11 additions & 0 deletions
@@ -25,6 +25,7 @@ import scala.language.existentials
 import scala.reflect.ClassTag
 
 import org.apache.spark._
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.util.Utils
 
 /**
@@ -122,6 +123,16 @@ private[spark] class CoalescedRDD[T: ClassTag](
   override def getPreferredLocations(partition: Partition): Seq[String] = {
     partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq
   }
+
+  private[spark] override def getPreferredResources(split: Partition): PreferredResources = {
+    val parentResources = split.asInstanceOf[CoalescedRDDPartition].parents.map { p =>
+      prev.getPreferredResources(p)
+    }
+
+    parentResources
+      .reduce { (r1, r2) => r1.mergeOther(r2) }
+      .mergeOther(super.getPreferredResources(split))
+  }
 }
 
 /**
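
One caveat worth noting about the pattern above (it recurs in several files in this commit): Scala's reduce throws on an empty collection, so this implementation assumes every CoalescedRDDPartition has at least one parent partition. A defensive variant, not part of the commit, could fold from the empty preference instead:

// Alternative sketch, not in the commit: foldLeft tolerates an empty
// parents list, whereas reduce would throw UnsupportedOperationException.
parentResources.foldLeft(PreferredResources.EMPTY) { (r1, r2) => r1.mergeOther(r2) }
  .mergeOther(super.getPreferredResources(split))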

core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala

Lines changed: 6 additions & 1 deletion
@@ -21,6 +21,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{NarrowDependency, Partition, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.resource.PreferredResources
 
 private[spark] class PartitionPruningRDDPartition(idx: Int, val parentSplit: Partition)
   extends Partition {
@@ -66,8 +67,12 @@ class PartitionPruningRDD[T: ClassTag](
 
   override protected def getPartitions: Array[Partition] =
     dependencies.head.asInstanceOf[PruneDependency[T]].partitions
-}
 
+  private[spark] override def getPreferredResources(split: Partition): PreferredResources = {
+    prev.getPreferredResources(split.asInstanceOf[PartitionPruningRDDPartition].parentSplit)
+      .mergeOther(super.getPreferredResources(split))
+  }
+}
 
 @DeveloperApi
 object PartitionPruningRDD {

core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala

Lines changed: 10 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
 import scala.reflect.ClassTag
 
 import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.util.Utils
 
 /**
@@ -94,6 +95,15 @@ class PartitionerAwareUnionRDD[T: ClassTag](
     location.toSeq
   }
 
+  private[spark] override def getPreferredResources(split: Partition): PreferredResources = {
+    val partition = split.asInstanceOf[PartitionerAwareUnionRDDPartition]
+    partition.rdds.zip(partition.parents).map { case (rdd, p) =>
+      rdd.getPreferredResources(p)
+    }
+      .reduce { (s1, s2) => s1.mergeOther(s2) }
+      .mergeOther(super.getPreferredResources(split))
+  }
+
   override def compute(s: Partition, context: TaskContext): Iterator[T] = {
     val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
     rdds.zip(parentPartitions).iterator.flatMap {

core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala

Lines changed: 6 additions & 0 deletions
@@ -22,6 +22,7 @@ import java.util.Random
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, TaskContext}
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.RandomSampler
 
@@ -67,4 +68,9 @@ private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
     thisSampler.setSeed(split.seed)
     thisSampler.sample(firstParent[T].iterator(split.prev, context))
   }
+
+  private[spark] override def getPreferredResources(split: Partition): PreferredResources = {
+    prev.getPreferredResources(split.asInstanceOf[PartitionwiseSampledRDDPartition].prev)
+      .mergeOther(super.getPreferredResources(split))
+  }
 }

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 24 additions & 0 deletions
@@ -40,6 +40,7 @@ import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
+import org.apache.spark.rdd.resource.{PreferredResources, ResourcesPreferenceBuilder}
 import org.apache.spark.storage.{RDDBlockId, StorageLevel}
 import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
@@ -136,6 +137,14 @@ abstract class RDD[T: ClassTag](
    */
   protected def getPreferredLocations(split: Partition): Seq[String] = Nil
 
+  private[spark] def getPreferredResources(split: Partition): PreferredResources = {
+    dependencies.headOption.flatMap {
+      case d: OneToOneDependency[_] =>
+        Some(d.rdd.getPreferredResources(split).mergeOther(this.preferredResources))
+      case _ => None
+    }.getOrElse(preferredResources)
+  }
+
   /** Optionally overridden by subclasses to specify how they are partitioned. */
   @transient val partitioner: Option[Partitioner] = None
 
@@ -158,6 +167,21 @@ abstract class RDD[T: ClassTag](
     this
   }
 
+  @transient protected var preferredResources: PreferredResources = PreferredResources.EMPTY
+
+  private[spark] def setPreferredResources(info: PreferredResources): Unit = {
+    this.preferredResources = info
+  }
+
+  def withResources(): ResourcesPreferenceBuilder[T] = {
+    new ResourcesPreferenceBuilder(this)
+  }
+
+  def clearResources(): this.type = {
+    preferredResources = PreferredResources.EMPTY
+    this
+  }
+
   /**
    * Mark this RDD for persisting using the specified level.
    *
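
Taken together, the RDD.scala hunks define the user-facing surface of the feature: withResources() returns a ResourcesPreferenceBuilder (defined elsewhere in this commit; its methods are not shown on this page, and presumably it writes the finished object back through setPreferredResources), clearResources() resets to PreferredResources.EMPTY, and the default getPreferredResources propagates preferences only across a leading OneToOneDependency, so narrow one-to-one children inherit a parent's preferences merged with their own. A minimal, self-contained restatement of that propagation rule, using illustrative stand-in types rather than Spark's classes:

// Standalone sketch of the default propagation rule above, for intuition
// only: SketchRDD, Res, and the Dep hierarchy are stand-ins, not Spark API.
case class Res(tags: Set[String]) {
  def mergeOther(other: Res): Res = Res(tags ++ other.tags) // assumed union semantics
}
object Res { val EMPTY: Res = Res(Set.empty) }

sealed trait Dep { def rdd: SketchRDD }
case class OneToOne(rdd: SketchRDD) extends Dep // narrow, one-to-one edge
case class Wide(rdd: SketchRDD) extends Dep     // anything else, e.g. a shuffle

class SketchRDD(val preferred: Res, val deps: Seq[Dep]) {
  // Mirrors the committed default: if the first dependency is one-to-one,
  // recurse into the parent and merge its preferences with this RDD's own;
  // otherwise return only this RDD's own preferences.
  def getPreferredResources: Res = deps.headOption.flatMap {
    case d: OneToOne => Some(d.rdd.getPreferredResources.mergeOther(preferred))
    case _ => None
  }.getOrElse(preferred)
}

// Example: a map-like child with no preference of its own inherits "gpu"
// from its one-to-one parent, but a shuffle child does not.
// val parent = new SketchRDD(Res(Set("gpu")), Nil)
// new SketchRDD(Res.EMPTY, Seq(OneToOne(parent))).getPreferredResources // Res(Set("gpu"))
// new SketchRDD(Res.EMPTY, Seq(Wide(parent))).getPreferredResources     // Res(Set())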

core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala

Lines changed: 10 additions & 0 deletions
@@ -30,6 +30,7 @@ import org.apache.spark.Partitioner
 import org.apache.spark.ShuffleDependency
 import org.apache.spark.SparkEnv
 import org.apache.spark.TaskContext
+import org.apache.spark.rdd.resource.PreferredResources
 
 /**
  * An optimized version of cogroup for set difference/subtraction.
@@ -84,6 +85,15 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
     array
   }
 
+  override private[spark] def getPreferredResources(split: Partition): PreferredResources = {
+    split.asInstanceOf[CoGroupPartition].narrowDeps.map {
+      case Some(s) => s.rdd.getPreferredResources(s.split)
+      case None => PreferredResources.EMPTY
+    }
+      .reduce { (s1, s2) => s1.mergeOther(s2) }
+      .mergeOther(super.getPreferredResources(split))
+  }
+
   override val partitioner = Some(part)
 
   override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {

core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala

Lines changed: 8 additions & 1 deletion
@@ -26,6 +26,7 @@ import scala.reflect.ClassTag
 
 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.util.Utils
 
 /**
@@ -39,7 +40,7 @@ import org.apache.spark.util.Utils
  */
 private[spark] class UnionPartition[T: ClassTag](
     idx: Int,
-    @transient private val rdd: RDD[T],
+    @transient val rdd: RDD[T],
     val parentRddIndex: Int,
     @transient private val parentRddPartitionIndex: Int)
   extends Partition {
@@ -112,4 +113,10 @@ class UnionRDD[T: ClassTag](
     super.clearDependencies()
     rdds = null
   }
+
+  override private[spark] def getPreferredResources(split: Partition): PreferredResources = {
+    val partition = split.asInstanceOf[UnionPartition[_]]
+    partition.rdd.getPreferredResources(partition.parentPartition)
+      .mergeOther(super.getPreferredResources(split))
+  }
 }

core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala

Lines changed: 11 additions & 1 deletion
@@ -22,11 +22,12 @@ import java.io.{IOException, ObjectOutputStream}
 import scala.reflect.ClassTag
 
 import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.rdd.resource.PreferredResources
 import org.apache.spark.util.Utils
 
 private[spark] class ZippedPartitionsPartition(
     idx: Int,
-    @transient private val rdds: Seq[RDD[_]],
+    @transient val rdds: Seq[RDD[_]],
     @transient val preferredLocations: Seq[String])
   extends Partition {
 
@@ -74,6 +75,15 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
     super.clearDependencies()
     rdds = null
   }
+
+  override private[spark] def getPreferredResources(split: Partition): PreferredResources = {
+    val partition = split.asInstanceOf[ZippedPartitionsPartition]
+    partition.rdds.zip(partition.partitions).map { case (rdd, p) =>
+      rdd.getPreferredResources(p)
+    }
+      .reduce { (s1, s2) => s1.mergeOther(s2) }
+      .mergeOther(super.getPreferredResources(split))
+  }
 }
 
 private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
