@@ -59,12 +59,63 @@ case class Exchange(
 
   override def output: Seq[Attribute] = child.output
 
-  /** We must copy rows when sort based shuffle is on */
-  protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+  private val sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
 
   private val bypassMergeThreshold =
     child.sqlContext.sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
 
+  private val serializeMapOutputs =
+    child.sqlContext.sparkContext.conf.getBoolean("spark.shuffle.sort.serializeMapOutputs", true)
+
+  /**
+   * Determines whether records must be defensively copied before being sent to the shuffle.
+   * Several of Spark's shuffle components will buffer deserialized Java objects in memory. The
+   * shuffle code assumes that objects are immutable and hence does not perform its own defensive
+   * copying. In Spark SQL, however, operators' iterators return the same mutable `Row` object. In
+   * order to properly shuffle the output of these operators, we need to perform our own copying
+   * prior to sending records to the shuffle. This copying is expensive, so we try to avoid it
+   * whenever possible. This method encapsulates the logic for choosing when to copy.
+   *
+   * In the long run, we might want to push this logic into core's shuffle APIs so that we don't
+   * have to rely on knowledge of core internals here in SQL.
+   *
+   * See SPARK-2967, SPARK-4479, and SPARK-7375 for more discussion of this issue.
+   *
+   * @param numPartitions the number of output partitions produced by the shuffle
+   * @param serializer the serializer that will be used to write rows
+   * @return true if rows should be copied before being shuffled, false otherwise
+   */
+  private def needToCopyObjectsBeforeShuffle(
+      numPartitions: Int,
+      serializer: Serializer): Boolean = {
+    if (newOrdering.nonEmpty) {
+      // If a new ordering is required, then records will be sorted with Spark's `ExternalSorter`,
+      // which requires a defensive copy.
+      true
+    } else if (sortBasedShuffleOn) {
+      // Spark's sort-based shuffle also uses `ExternalSorter` to buffer records in memory.
+      // However, there are two special cases where we can avoid the copy, described below:
+      if (numPartitions <= bypassMergeThreshold) {
+        // If the number of output partitions is sufficiently small, then Spark will fall back to
+        // the old hash-based shuffle write path which doesn't buffer deserialized records.
+        // Note that we'll have to remove this case if we fix SPARK-6026 and remove this bypass.
+        false
+      } else if (serializeMapOutputs && serializer.supportsRelocationOfSerializedObjects) {
+        // SPARK-4550 extended sort-based shuffle to serialize individual records prior to sorting
+        // them. This optimization is guarded by a feature-flag and is only applied in cases where
+        // the shuffle dependency does not specify an ordering and the record serializer has
+        // certain properties. If this optimization is enabled, we can safely avoid the copy.
+        false
+      } else {
+        // None of the special cases held, so we must copy.
+        true
+      }
+    } else {
+      // We're using hash-based shuffle, so we don't need to copy.
+      false
+    }
+  }
+
   private val keyOrdering = {
     if (newOrdering.nonEmpty) {
       val key = newPartitioning.keyExpressions
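The new doc comment hinges on SQL operators returning the same mutable `Row` from their iterators. Here is a minimal, self-contained sketch of that failure mode (plain Scala, no Spark dependency; `MutableRow` is an invented stand-in for Spark SQL's mutable rows):

```scala
// Invented stand-in for Spark SQL's mutable Row: one object, overwritten per record.
class MutableRow(var value: Int) {
  def copy(): MutableRow = new MutableRow(value)
}

object DefensiveCopyDemo extends App {
  val shared = new MutableRow(0)
  def rows: Iterator[MutableRow] = (1 to 3).iterator.map { i => shared.value = i; shared }

  // Buffering the raw references (as ExternalSorter does with deserialized objects)
  // keeps three pointers to one object: prints List(3, 3, 3).
  println(rows.toList.map(_.value))

  // Defensive copies preserve each record as it streams past: prints List(1, 2, 3).
  println(rows.map(_.copy()).toList.map(_.value))
}
```

Any shuffle component that buffers references sees only the last value written, which is why the copy has to happen before records reach the shuffle.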
@@ -81,7 +132,7 @@ case class Exchange(
 
   @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf
 
-  def serializer(
+  private def getSerializer(
       keySchema: Array[DataType],
       valueSchema: Array[DataType],
       numPartitions: Int): Serializer = {
@@ -123,17 +174,11 @@ case class Exchange(
   override def execute(): RDD[Row] = attachTree(this, "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
-        // TODO: Eliminate redundant expressions in grouping key and value.
-        // This is a workaround for SPARK-4479. When:
-        //  1. sort based shuffle is on, and
-        //  2. the partition number is under the merge threshold, and
-        //  3. no ordering is required
-        // we can avoid the defensive copies to improve performance. In the long run, we probably
-        // want to include information in shuffle dependencies to indicate whether elements in the
-        // source RDD should be copied.
-        val willMergeSort = sortBasedShuffleOn && numPartitions > bypassMergeThreshold
-
-        val rdd = if (willMergeSort || newOrdering.nonEmpty) {
+        val keySchema = expressions.map(_.dataType).toArray
+        val valueSchema = child.output.map(_.dataType).toArray
+        val serializer = getSerializer(keySchema, valueSchema, numPartitions)
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
          child.execute().mapPartitions { iter =>
            val hashExpressions = newMutableProjection(expressions, child.output)()
            iter.map(r => (hashExpressions(r).copy(), r.copy()))
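For a concrete view of the branch logic this call site now relies on, here is the same decision table distilled into a standalone function (a sketch: parameter names are invented, and the real method reads these values from `SparkEnv` and the `SparkConf` rather than taking them as arguments):

```scala
object CopyDecisionSketch extends App {
  // Sketch of needToCopyObjectsBeforeShuffle's decision table, with every input
  // passed explicitly instead of read from SparkEnv / SparkConf.
  def needsCopy(
      hasNewOrdering: Boolean,            // newOrdering.nonEmpty
      sortBasedShuffle: Boolean,          // shuffleManager is a SortShuffleManager
      numPartitions: Int,
      bypassMergeThreshold: Int,          // spark.shuffle.sort.bypassMergeThreshold
      serializeMapOutputs: Boolean,       // spark.shuffle.sort.serializeMapOutputs
      serializerSupportsRelocation: Boolean): Boolean = {
    if (hasNewOrdering) {
      true // ExternalSorter will buffer the rows while sorting
    } else if (sortBasedShuffle) {
      if (numPartitions <= bypassMergeThreshold) {
        false // bypass path writes records straight to per-partition files
      } else if (serializeMapOutputs && serializerSupportsRelocation) {
        false // SPARK-4550 path sorts serialized bytes, not live objects
      } else {
        true
      }
    } else {
      false // hash-based shuffle does not buffer deserialized records
    }
  }

  // A 300-partition sort-based shuffle with a non-relocatable serializer must copy...
  assert(needsCopy(false, true, 300, 200, true, false))
  // ...but a relocatable serializer lets the copy be skipped entirely.
  assert(!needsCopy(false, true, 300, 200, true, true))
}
```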
@@ -152,14 +197,14 @@ case class Exchange(
         } else {
           new ShuffledRDD[Row, Row, Row](rdd, part)
         }
-        val keySchema = expressions.map(_.dataType).toArray
-        val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
-
+        shuffled.setSerializer(serializer)
         shuffled.map(_._2)
 
       case RangePartitioning(sortingExpressions, numPartitions) =>
-        val rdd = if (sortBasedShuffleOn || newOrdering.nonEmpty) {
+        val keySchema = child.output.map(_.dataType).toArray
+        val serializer = getSerializer(keySchema, null, numPartitions)
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions, serializer)) {
          child.execute().mapPartitions { iter => iter.map(row => (row.copy(), null))}
         } else {
           child.execute().mapPartitions { iter =>
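The no-copy branch that this hunk cuts off avoids per-record allocation by reusing a single mutable pair (`org.apache.spark.util.MutablePair` in the real file). A sketch of that pattern with an invented `ReusablePair`, showing why it is only safe when the shuffle consumes each record before pulling the next:

```scala
object ReusedPairDemo extends App {
  // One pair per partition, overwritten for each record (MutablePair-style).
  class ReusablePair[K, V](var key: K, var value: V) {
    def update(k: K, v: V): this.type = { key = k; value = v; this }
  }

  val pair = new ReusablePair[String, Null](null, null)
  def pairs: Iterator[ReusablePair[String, Null]] =
    Iterator("a", "b", "c").map(r => pair.update(r, null))

  // Streamed one record at a time, every key is observed: prints a, b, c.
  pairs.foreach(p => println(p.key))

  // Buffered first, only the last write survives: prints c, c, c. This is
  // exactly the hazard needToCopyObjectsBeforeShuffle guards against.
  pairs.toList.foreach(p => println(p.key))
}
```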
@@ -178,17 +223,14 @@ case class Exchange(
         } else {
           new ShuffledRDD[Row, Null, Null](rdd, part)
         }
-        val keySchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
-
+        shuffled.setSerializer(serializer)
         shuffled.map(_._1)
 
       case SinglePartition =>
-        // SPARK-4479: Can't turn off defensive copy as what we do for `HashPartitioning`, since
-        // operators like `TakeOrdered` may require an ordering within the partition, and currently
-        // `SinglePartition` doesn't include ordering information.
-        // TODO Add `SingleOrderedPartition` for operators like `TakeOrdered`
-        val rdd = if (sortBasedShuffleOn) {
+        val valueSchema = child.output.map(_.dataType).toArray
+        val serializer = getSerializer(null, valueSchema, 1)
+
+        val rdd = if (needToCopyObjectsBeforeShuffle(numPartitions = 1, serializer)) {
          child.execute().mapPartitions { iter => iter.map(r => (null, r.copy())) }
         } else {
           child.execute().mapPartitions { iter =>
@@ -198,8 +240,7 @@ case class Exchange(
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
-        val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(null, valueSchema, 1))
+        shuffled.setSerializer(serializer)
         shuffled.map(_._2)
 
       case _ => sys.error(s"Exchange not implemented for $newPartitioning")