1818package org .apache .spark .sql .execution
1919
2020import org .apache .spark .annotation .DeveloperApi
21- import org .apache .spark .{HashPartitioner , RangePartitioner , SparkConf }
21+ import org .apache .spark .shuffle .sort .SortShuffleManager
22+ import org .apache .spark .{SparkEnv , HashPartitioner , RangePartitioner , SparkConf }
2223import org .apache .spark .rdd .ShuffledRDD
2324import org .apache .spark .sql .{SQLContext , Row }
2425import org .apache .spark .sql .catalyst .errors .attachTree
@@ -37,6 +38,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
3738
3839 def output = child.output
3940
41+ /** True when the active shuffle manager (SparkEnv.get.shuffleManager) is a SortShuffleManager; execute() must then emit copied rows in fresh tuples instead of updating one reused MutablePair (presumably because sort-based shuffle holds on to records -- TODO confirm). */
42+ protected def sortBasedShuffleOn = SparkEnv .get.shuffleManager.isInstanceOf [SortShuffleManager ]
43+
4044 def execute () = attachTree(this , " execute" ) {
4145 newPartitioning match {
4246 case HashPartitioning (expressions, numPartitions) =>
@@ -45,8 +49,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
4549 @ transient val hashExpressions =
4650 newMutableProjection(expressions, child.output)()
4751
48- val mutablePair = new MutablePair [Row , Row ]()
49- iter.map(r => mutablePair.update(hashExpressions(r), r))
52+ if (sortBasedShuffleOn) {
53+ iter.map(r => (hashExpressions(r), r.copy()))
54+ } else {
55+ val mutablePair = new MutablePair [Row , Row ]()
56+ iter.map(r => mutablePair.update(hashExpressions(r), r))
57+ }
5058 }
5159 val part = new HashPartitioner (numPartitions)
5260 val shuffled = new ShuffledRDD [Row , Row , Row ](rdd, part)
@@ -58,8 +66,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
5866 implicit val ordering = new RowOrdering (sortingExpressions, child.output)
5967
6068 val rdd = child.execute().mapPartitions { iter =>
61- val mutablePair = new MutablePair [Row , Null ](null , null )
62- iter.map(row => mutablePair.update(row, null ))
69+ if (sortBasedShuffleOn) {
70+ iter.map(row => (row.copy(), null ))
71+ } else {
72+ val mutablePair = new MutablePair [Row , Null ](null , null )
73+ iter.map(row => mutablePair.update(row, null ))
74+ }
6375 }
6476 val part = new RangePartitioner (numPartitions, rdd, ascending = true )
6577 val shuffled = new ShuffledRDD [Row , Null , Null ](rdd, part)
@@ -69,8 +81,12 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
6981
7082 case SinglePartition =>
7183 val rdd = child.execute().mapPartitions { iter =>
72- val mutablePair = new MutablePair [Null , Row ]()
73- iter.map(r => mutablePair.update(null , r))
84+ if (sortBasedShuffleOn) {
85+ iter.map(r => (null , r.copy()))
86+ } else {
87+ val mutablePair = new MutablePair [Null , Row ]()
88+ iter.map(r => mutablePair.update(null , r))
89+ }
7490 }
7591 val partitioner = new HashPartitioner (1 )
7692 val shuffled = new ShuffledRDD [Null , Row , Row ](rdd, partitioner)
0 commit comments