Commit 658e1af

[SPARK-5360] Eliminate duplicate objects in serialized CoGroupedRDD
1 parent 1e340c3

File tree

2 files changed: +42 additions, -29 deletions


core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala

Lines changed: 24 additions & 17 deletions

@@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
 import org.apache.spark.util.Utils
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleHandle
-
-private[spark] sealed trait CoGroupSplitDep extends Serializable
 
+/** The references to rdd and splitIndex are transient because redundant information is stored
+ * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
+ * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
+ * task closure. */
 private[spark] case class NarrowCoGroupSplitDep(
-    rdd: RDD[_],
-    splitIndex: Int,
+    @transient rdd: RDD[_],
+    @transient splitIndex: Int,
     var split: Partition
-  ) extends CoGroupSplitDep {
+  ) extends Serializable {
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
@@ -47,9 +48,14 @@ private[spark] case class NarrowCoGroupSplitDep(
   }
 }
 
-private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
-
-private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+/**
+ * Stores information about the narrow dependencies used by a CoGroupedRdd. narrowDeps maps to
+ * the dependencies variable in the parent RDD: for each one-to-one dependency in dependencies,
+ * narrowDeps has a NarrowCoGroupSplitDep (describing the partition for that dependency) at the
+ * corresponding index.
+ */
+private[spark] class CoGroupPartition(
+    idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
   extends Partition with Serializable {
   override val index: Int = idx
   override def hashCode(): Int = idx
@@ -105,9 +111,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         // Assume each RDD contributed a single dependency, and get it
         dependencies(j) match {
           case s: ShuffleDependency[_, _, _] =>
-            new ShuffleCoGroupSplitDep(s.shuffleHandle)
+            None
           case _ =>
-            new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+            Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
         }
       }.toArray)
     }
@@ -120,20 +126,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     val sparkConf = SparkEnv.get.conf
     val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true)
     val split = s.asInstanceOf[CoGroupPartition]
-    val numRdds = split.deps.length
+    val numRdds = dependencies.length
 
     // A list of (rdd iterator, dependency number) pairs
    val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
-    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
+      case oneToOneDependency: OneToOneDependency[Product2[K, Any]] =>
+        val dependencyPartition = split.narrowDeps(depNum).get.split
         // Read them from the parent
-        val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+        val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
         rddIterators += ((it, depNum))
 
-      case ShuffleCoGroupSplitDep(handle) =>
+      case shuffleDependency: ShuffleDependency[_, _, _] =>
         // Read map outputs of shuffle
         val it = SparkEnv.get.shuffleManager
-          .getReader(handle, split.index, split.index + 1, context)
+          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
           .read()
         rddIterators += ((it, depNum))
     }
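
A rough standalone illustration of why the @transient markers above matter, assuming plain Java serialization for the partition object (BigParent, SplitDepSketch, and TransientDemo below are made-up stand-ins, not Spark classes): fields marked @transient are skipped when the object is written, so the heavy parent reference no longer travels with every serialized partition.

import java.io.{ByteArrayOutputStream, ObjectOutputStream}

// A stand-in for a large object (such as a parent RDD) that would otherwise
// be captured twice: once in the task closure and once per partition.
class BigParent(val payload: Array[Byte]) extends Serializable

// Mirrors the pattern used by NarrowCoGroupSplitDep: the heavy reference is
// @transient, so Java serialization skips it and only the lightweight state
// travels with the partition.
case class SplitDepSketch(@transient parent: BigParent, var split: String)

object TransientDemo {
  private def serializedSize(obj: AnyRef): Int = {
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(obj)
    out.close()
    bytes.size()
  }

  def main(args: Array[String]): Unit = {
    val parent = new BigParent(new Array[Byte](1 << 20)) // roughly 1 MB payload
    val dep = SplitDepSketch(parent, "split-0")
    println(s"parent alone:        ${serializedSize(parent)} bytes")
    println(s"dep with @transient: ${serializedSize(dep)} bytes")
    // The dep serializes to a small number of bytes because `parent` is
    // skipped; without @transient it would carry the full payload again.
  }
}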

core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala

Lines changed: 18 additions & 12 deletions

@@ -81,9 +81,9 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
       array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
         dependencies(j) match {
           case s: ShuffleDependency[_, _, _] =>
-            new ShuffleCoGroupSplitDep(s.shuffleHandle)
+            None
           case _ =>
-            new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+            Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
         }
       }.toArray)
     }
@@ -105,20 +105,26 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
           seq
         }
       }
-    def integrate(dep: CoGroupSplitDep, op: Product2[K, V] => Unit): Unit = dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
-        rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
+    def integrate(depNum: Int, op: Product2[K, V] => Unit) = {
+      dependencies(depNum) match {
+        case oneToOneDependency: OneToOneDependency[_] =>
+          val dependencyPartition = partition.narrowDeps(depNum).get.split
+          oneToOneDependency.rdd.iterator(dependencyPartition, context)
+            .asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
 
-      case ShuffleCoGroupSplitDep(handle) =>
-        val iter = SparkEnv.get.shuffleManager
-          .getReader(handle, partition.index, partition.index + 1, context)
-          .read()
-        iter.foreach(op)
+        case shuffleDependency: ShuffleDependency[_, _, _] =>
+          val iter = SparkEnv.get.shuffleManager
+            .getReader(
+              shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)
+            .read()
+          iter.foreach(op)
+      }
     }
+
     // the first dep is rdd1; add all values to the map
-    integrate(partition.deps(0), t => getSeq(t._1) += t._2)
+    integrate(0, t => getSeq(t._1) += t._2)
     // the second dep is rdd2; remove all of its keys
-    integrate(partition.deps(1), t => map.remove(t._1))
+    integrate(1, t => map.remove(t._1))
     map.iterator.map { t => t._2.iterator.map { (t._1, _) } }.flatten
   }
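
To make the indexing contract both files now rely on concrete, here is a minimal standalone sketch (Dep, OneToOneDep, ShuffleDep, NarrowSplit, and NarrowDepsSketch are hypothetical simplified types, not Spark's API): narrowDeps(i) holds Some(...) exactly when dependencies(i) is a one-to-one dependency and None when it is a shuffle dependency, so the same index can be used against either sequence.

// Simplified stand-ins for Spark's dependency and partition types.
sealed trait Dep
case class OneToOneDep(parentId: Int) extends Dep
case class ShuffleDep(shuffleId: Int) extends Dep

case class NarrowSplit(parentId: Int, splitIndex: Int)

object NarrowDepsSketch {
  // Builds narrowDeps so that its indices line up with `dependencies`,
  // mirroring getPartitions in CoGroupedRDD and SubtractedRDD.
  def buildNarrowDeps(dependencies: Seq[Dep], splitIndex: Int): Array[Option[NarrowSplit]] =
    dependencies.map {
      case OneToOneDep(parentId) => Some(NarrowSplit(parentId, splitIndex))
      case _: ShuffleDep         => None
    }.toArray

  def main(args: Array[String]): Unit = {
    val dependencies = Seq(OneToOneDep(parentId = 1), ShuffleDep(shuffleId = 7))
    val narrowDeps = buildNarrowDeps(dependencies, splitIndex = 3)

    // compute() iterates over `dependencies` and only consults narrowDeps
    // in the narrow case, as in the rewritten match above.
    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
      case _: OneToOneDep =>
        println(s"dep $depNum: read parent partition ${narrowDeps(depNum).get}")
      case ShuffleDep(id) =>
        println(s"dep $depNum: read shuffle output for shuffle $id")
    }
  }
}

Storing only an Option per dependency keeps the serialized partition small while still letting the compute path recover the parent partition for narrow dependencies.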