@@ -29,15 +29,16 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
 import org.apache.spark.util.Utils
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.ShuffleHandle
-
-private[spark] sealed trait CoGroupSplitDep extends Serializable
 
+/** The references to rdd and splitIndex are transient because redundant information is stored
+ * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from
+ * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the
+ * task closure. */
 private[spark] case class NarrowCoGroupSplitDep(
-    rdd: RDD[_],
-    splitIndex: Int,
+    @transient rdd: RDD[_],
+    @transient splitIndex: Int,
     var split: Partition
-  ) extends CoGroupSplitDep {
+  ) extends Serializable {
 
   @throws(classOf[IOException])
   private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
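Why the @transient annotations above matter: the task closure already carries the CoGroupedRDD (and, through it, rdd and splitIndex), and the partition object is serialized separately, so any non-transient copy of those fields would be written a second time. Below is a minimal, self-contained sketch of the mechanism under plain Java serialization; BigLineage, PartitionLike, and TransientDemo are hypothetical stand-ins, not Spark classes.

import java.io._

// Hypothetical stand-in for state that is already reachable elsewhere,
// the way rdd and splitIndex are reachable through the CoGroupedRDD.
case class BigLineage(data: Seq[Int]) extends Serializable

class PartitionLike(
    @transient val lineage: BigLineage,  // skipped when this object is written
    val split: Int)                      // serialized normally
  extends Serializable

object TransientDemo extends App {
  val bytes = new ByteArrayOutputStream()
  new ObjectOutputStream(bytes).writeObject(new PartitionLike(BigLineage(1 to 5), 3))

  val restored = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    .readObject().asInstanceOf[PartitionLike]

  println(restored.lineage)  // null: the transient field was not serialized
  println(restored.split)    // 3
}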
@@ -47,9 +48,14 @@ private[spark] case class NarrowCoGroupSplitDep(
   }
 }
 
-private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
-
-private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
+/**
+ * Stores information about the narrow dependencies used by a CoGroupedRDD. narrowDeps maps to
+ * the dependencies variable in the parent RDD: for each one-to-one dependency in dependencies,
+ * narrowDeps has a NarrowCoGroupSplitDep (describing the partition for that dependency) at the
+ * corresponding index.
+ */
+private[spark] class CoGroupPartition(
+    idx: Int, val narrowDeps: Array[Option[NarrowCoGroupSplitDep]])
   extends Partition with Serializable {
   override val index: Int = idx
   override def hashCode(): Int = idx
@@ -105,9 +111,9 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       // Assume each RDD contributed a single dependency, and get it
       dependencies(j) match {
         case s: ShuffleDependency[_, _, _] =>
-          new ShuffleCoGroupSplitDep(s.shuffleHandle)
+          None
         case _ =>
-          new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
+          Some(new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i)))
       }
     }.toArray)
   }
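The net effect of the two hunks above: narrowDeps lines up index-for-index with the RDD's dependencies sequence, holding Some(NarrowCoGroupSplitDep) for each narrow dependency and None wherever a shuffle supplies the data, which is what lets compute() look the parent partition up by dependency number. A minimal sketch of that invariant, using hypothetical stand-in types (Dep, Shuffle, Narrow, SplitDep) rather than the real Spark classes:

// Hypothetical stand-ins for ShuffleDependency / OneToOneDependency.
sealed trait Dep
case object Shuffle extends Dep
case class Narrow(parent: String) extends Dep

case class SplitDep(parent: String, partitionIndex: Int)

// Mirrors the getPartitions logic above: one Option slot per dependency, by index.
def narrowDepsFor(deps: Seq[Dep], i: Int): Array[Option[SplitDep]] =
  deps.map {
    case Shuffle        => None                       // data will arrive via the shuffle
    case Narrow(parent) => Some(SplitDep(parent, i))  // read the parent partition directly
  }.toArray

val slots = narrowDepsFor(Seq(Narrow("a"), Shuffle, Narrow("b")), i = 2)

// slots(depNum) is defined exactly when the dependency at depNum is narrow,
// which is why compute() can safely call narrowDeps(depNum).get in that case.
assert(slots(0).isDefined && slots(1).isEmpty && slots(2).isDefined)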
@@ -120,20 +126,21 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     val sparkConf = SparkEnv.get.conf
     val externalSorting = sparkConf.getBoolean("spark.shuffle.spill", true)
     val split = s.asInstanceOf[CoGroupPartition]
-    val numRdds = split.deps.length
+    val numRdds = dependencies.length
 
     // A list of (rdd iterator, dependency number) pairs
     val rddIterators = new ArrayBuffer[(Iterator[Product2[K, Any]], Int)]
-    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
-      case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
+    for ((dep, depNum) <- dependencies.zipWithIndex) dep match {
+      case oneToOneDependency: OneToOneDependency[Product2[K, Any]] =>
+        val dependencyPartition = split.narrowDeps(depNum).get.split
         // Read them from the parent
-        val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
+        val it = oneToOneDependency.rdd.iterator(dependencyPartition, context)
         rddIterators += ((it, depNum))
 
-      case ShuffleCoGroupSplitDep(handle) =>
+      case shuffleDependency: ShuffleDependency[_, _, _] =>
         // Read map outputs of shuffle
         val it = SparkEnv.get.shuffleManager
-          .getReader(handle, split.index, split.index + 1, context)
+          .getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
           .read()
         rddIterators += ((it, depNum))
     }
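For context, both branches of the rewritten match are reachable from the public cogroup API: a parent that shares the CoGroupedRDD's partitioner becomes a OneToOneDependency, and any other parent becomes a ShuffleDependency. A rough usage sketch, assuming a local SparkContext (and, on older Spark versions, import org.apache.spark.SparkContext._ for the pair-RDD implicits):

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local").setAppName("cogroup-demo"))
val part = new HashPartitioner(4)

val a = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(part)
val b = sc.parallelize(Seq(1 -> "x", 3 -> "y")).partitionBy(part)
val c = sc.parallelize(Seq(2 -> "z"))  // no partitioner

// a and b share the partitioner, so both parents are one-to-one:
// compute() takes the first case and reads parent partitions directly.
a.cogroup(b).collect().foreach(println)

// c is unpartitioned, so it contributes a ShuffleDependency: compute()
// takes the second case for it and reads map outputs via the shuffle manager.
a.cogroup(c).collect().foreach(println)

sc.stop()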