
Commit a2e658d

[SPARK-2967][SQL] Fix sort based shuffle for spark sql.
Add explicit row copies when sort based shuffle is on.

Author: Michael Armbrust <[email protected]>

Closes #2066 from marmbrus/sortShuffle and squashes the following commits:

fcd7bb2 [Michael Armbrust] Fix sort based shuffle for spark sql.
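Why the copies are needed: Spark SQL's operators reuse a single mutable row (wrapped in a MutablePair) for every record, which is safe with hash based shuffle because each record is serialized as soon as it is produced. Sort based shuffle instead buffers the objects it receives before sorting them, so without a copy every buffered record would alias the same mutable object and end up carrying the last row's values. The following standalone sketch illustrates the hazard and the fix; the MutableRow class and the values are illustrative stand-ins, not Spark internals.

object CopyBeforeBufferingSketch {
  // Stand-in for a reused mutable row; illustrative only.
  final class MutableRow(var value: Int)

  def main(args: Array[String]): Unit = {
    val row = new MutableRow(0)
    val input = 1 to 3

    // Reusing one object per record is fine when each record is written out
    // immediately, but here we buffer the references, as sort based shuffle does.
    val aliased = input.map { i => row.value = i; row }.toArray
    println(aliased.map(_.value).mkString(", "))   // prints "3, 3, 3": every slot aliases the same row

    // An explicit copy per record, mirroring the r.copy() calls added in this commit.
    val copied = input.map { i => row.value = i; new MutableRow(row.value) }.toArray
    println(copied.map(_.value).mkString(", "))    // prints "1, 2, 3": each record is independent
  }
}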
1 parent: fb60bec · commit: a2e658d

File tree

1 file changed: +23 -7 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 23 additions & 7 deletions
@@ -18,7 +18,8 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.{SparkEnv, HashPartitioner, RangePartitioner, SparkConf}
 import org.apache.spark.rdd.ShuffledRDD
 import org.apache.spark.sql.{SQLContext, Row}
 import org.apache.spark.sql.catalyst.errors.attachTree
@@ -37,6 +38,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

   def output = child.output

+  /** We must copy rows when sort based shuffle is on */
+  protected def sortBasedShuffleOn = SparkEnv.get.shuffleManager.isInstanceOf[SortShuffleManager]
+
   def execute() = attachTree(this , "execute") {
     newPartitioning match {
       case HashPartitioning(expressions, numPartitions) =>
@@ -45,8 +49,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
           @transient val hashExpressions =
             newMutableProjection(expressions, child.output)()

-          val mutablePair = new MutablePair[Row, Row]()
-          iter.map(r => mutablePair.update(hashExpressions(r), r))
+          if (sortBasedShuffleOn) {
+            iter.map(r => (hashExpressions(r), r.copy()))
+          } else {
+            val mutablePair = new MutablePair[Row, Row]()
+            iter.map(r => mutablePair.update(hashExpressions(r), r))
+          }
         }
         val part = new HashPartitioner(numPartitions)
         val shuffled = new ShuffledRDD[Row, Row, Row](rdd, part)
@@ -58,8 +66,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
         implicit val ordering = new RowOrdering(sortingExpressions, child.output)

         val rdd = child.execute().mapPartitions { iter =>
-          val mutablePair = new MutablePair[Row, Null](null, null)
-          iter.map(row => mutablePair.update(row, null))
+          if (sortBasedShuffleOn) {
+            iter.map(row => (row.copy(), null))
+          } else {
+            val mutablePair = new MutablePair[Row, Null](null, null)
+            iter.map(row => mutablePair.update(row, null))
+          }
         }
         val part = new RangePartitioner(numPartitions, rdd, ascending = true)
         val shuffled = new ShuffledRDD[Row, Null, Null](rdd, part)
@@ -69,8 +81,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una

       case SinglePartition =>
         val rdd = child.execute().mapPartitions { iter =>
-          val mutablePair = new MutablePair[Null, Row]()
-          iter.map(r => mutablePair.update(null, r))
+          if (sortBasedShuffleOn) {
+            iter.map(r => (null, r.copy()))
+          } else {
+            val mutablePair = new MutablePair[Null, Row]()
+            iter.map(r => mutablePair.update(null, r))
+          }
         }
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
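For context on when the new branch is taken: the shuffle implementation is chosen per application through the spark.shuffle.manager setting, and SparkEnv.get.shuffleManager reflects that choice at runtime, which is exactly what the new sortBasedShuffleOn check inspects. A minimal sketch of enabling sort based shuffle, assuming a Spark 1.1-era setup (the app name, local master, and toy job are illustrative):

import org.apache.spark.{SparkConf, SparkContext}

object SortShuffleConfigSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("sort-shuffle-sql")         // illustrative app name
      .setMaster("local[2]")                  // illustrative; any master works
      .set("spark.shuffle.manager", "sort")   // selects SortShuffleManager; hash based shuffle was the default at the time
    val sc = new SparkContext(conf)
    try {
      // Any shuffle now goes through the sort based path, the case in which
      // SQL's Exchange must copy rows before handing them to the shuffle.
      val counts = sc.parallelize(1 to 100).map(i => (i % 4, 1)).reduceByKey(_ + _).collect()
      println(counts.toSeq)
    } finally {
      sc.stop()
    }
  }
}

Keeping the MutablePair path for hash based shuffle means the extra per-row allocation is only paid when the copy is actually required.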
