Skip to content

Commit c92a281

Browse files
committed
Move sort into shuffle implementations
1 parent 2d25e34 commit c92a281

File tree

4 files changed: +28 additions, -12 deletions

core/src/main/scala/org/apache/spark/Dependency.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,12 @@ class ShuffleDependency[K, V, C](
6262
val serializer: Option[Serializer] = None,
6363
val keyOrdering: Option[Ordering[K]] = None,
6464
val aggregator: Option[Aggregator[K, V, C]] = None,
65-
val mapSideCombine: Boolean = false)
65+
val mapSideCombine: Boolean = false,
66+
val ascending: Boolean = true)
6667
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
6768

69+
def isKeySorted = keyOrdering.isDefined
70+
6871
val shuffleId: Int = rdd.context.newShuffleId()
6972

7073
val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(

core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,8 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
5757
*/
5858
def sortByKey(ascending: Boolean = true, numPartitions: Int = self.partitions.size): RDD[P] = {
5959
val part = new RangePartitioner(numPartitions, self, ascending)
60-
val shuffled = new ShuffledRDD[K, V, V, P](self, part).setKeyOrdering(ordering)
61-
shuffled.mapPartitions(iter => {
62-
val buf = iter.toArray
63-
if (ascending) {
64-
buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
65-
} else {
66-
buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
67-
}
68-
}, preservesPartitioning = true)
60+
new ShuffledRDD[K, V, V, P](self, part)
61+
.setKeyOrdering(ordering)
62+
.setAscendingFlag(ascending)
6963
}
7064
}

core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
5151

5252
private var mapSideCombine: Boolean = false
5353

54+
private var ascending: Boolean = true
55+
5456
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
5557
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C, P] = {
5658
this.serializer = Option(serializer)
@@ -63,6 +65,11 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
6365
this
6466
}
6567

68+
def setAscendingFlag(ascending: Boolean): ShuffledRDD[K, V, C, P] = {
69+
this.ascending = ascending
70+
this
71+
}
72+
6673
/** Set aggregator for RDD's shuffle. */
6774
def setAggregator(aggregator: Aggregator[K, V, C]): ShuffledRDD[K, V, C, P] = {
6875
this.aggregator = Option(aggregator)
@@ -76,7 +83,8 @@ class ShuffledRDD[K, V, C, P <: Product2[K, C] : ClassTag](
7683
}
7784

7885
override def getDependencies: Seq[Dependency[_]] = {
79-
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
86+
List(new ShuffleDependency(prev, part, serializer,
87+
keyOrdering, aggregator, mapSideCombine, ascending))
8088
}
8189

8290
override val partitioner = Some(part)

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class HashShuffleReader[K, C](
3838
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
3939
Serializer.getSerializer(dep.serializer))
4040

41-
if (dep.aggregator.isDefined) {
41+
val aggregatedIter = if (dep.aggregator.isDefined) {
4242
if (dep.mapSideCombine) {
4343
new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
4444
} else {
@@ -49,6 +49,17 @@ class HashShuffleReader[K, C](
4949
} else {
5050
iter
5151
}
52+
53+
dep.keyOrdering.map { ordering =>
54+
val buf = aggregatedIter.toArray
55+
if (dep.ascending) {
56+
buf.sortWith((x, y) => ordering.lt(x._1, y._1)).iterator
57+
} else {
58+
buf.sortWith((x, y) => ordering.gt(x._1, y._1)).iterator
59+
}
60+
}.getOrElse {
61+
aggregatedIter
62+
}
5263
}
5364

5465
/** Close this reader */

Comments (0)