Skip to content

Commit f147634

Browse files
committed
SPARK-2978. Transformation with MR shuffle semantics
1 parent eddfedd commit f147634

File tree

4 files changed

+58
-1
lines changed

4 files changed

+58
-1
lines changed

core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ package org.apache.spark.rdd
1919

2020
import scala.reflect.ClassTag
2121

22-
import org.apache.spark.{Logging, RangePartitioner}
22+
import org.apache.spark.{Partitioner, Logging, RangePartitioner}
2323
import org.apache.spark.annotation.DeveloperApi
2424

2525
/**
@@ -64,4 +64,15 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
6464
new ShuffledRDD[K, V, V](self, part)
6565
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
6666
}
67+
68+
/**
69+
* Repartition the RDD according to the given partitioner and, within each resulting partition,
70+
* sort records by their keys.
71+
*/
72+
def repartitionAndSortWithinPartition(partitioner: Partitioner, ascending: Boolean = true)
73+
: RDD[(K, V)] = {
74+
new ShuffledRDD[K, V, V](self, partitioner)
75+
.setKeyOrdering(if (ascending) ordering else ordering.reverse)
76+
}
77+
6778
}

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,20 @@ class RDDSuite extends FunSuite with SharedSparkContext {
682682
assert(data.sortBy(parse, true, 2)(NameOrdering, classTag[Person]).collect() === nameOrdered)
683683
}
684684

685+
test("repartitionAndSortWithinPartitions") {
686+
val data = sc.parallelize(Seq((0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)), 2)
687+
688+
val partitioner = new Partitioner {
689+
def numPartitions: Int = 2
690+
def getPartition(key: Any): Int = key.asInstanceOf[Int] % 2
691+
}
692+
693+
val repartitioned = data.repartitionAndSortWithinPartition(partitioner)
694+
val partitions = repartitioned.glom().collect()
695+
assert(partitions(0) === Seq((0, 5), (0, 8), (2, 6)))
696+
assert(partitions(1) === Seq((1, 3), (3, 8), (3, 8)))
697+
}
698+
685699
test("intersection") {
686700
val all = sc.parallelize(1 to 10)
687701
val evens = sc.parallelize(2 to 10 by 2)

python/pyspark/rdd.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,6 +520,30 @@ def __add__(self, other):
520520
raise TypeError
521521
return self.union(other)
522522

523+
def repartitionAndSortWithinPartition(self, ascending=True, numPartitions=None,
524+
partitionFunc=portable_hash, keyfunc=lambda x: x):
525+
"""
526+
Repartition the RDD into numPartitions partitions using partitionFunc and, within each resulting partition,
527+
sort records by their keys.
528+
529+
>>> rdd = sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)])
530+
>>> rdd2 = rdd.repartitionAndSortWithinPartition(True, 2, lambda x: x % 2)
531+
>>> rdd2.glom().collect()
532+
[[(0, 5), (0, 8), (2, 6)], [(1, 3), (3, 8), (3, 8)]]
533+
"""
534+
if numPartitions is None:
535+
numPartitions = self._defaultReducePartitions()
536+
537+
spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == "true")
538+
memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
539+
serializer = self._jrdd_deserializer
540+
541+
def sortPartition(iterator):
542+
sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
543+
return iter(sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending)))
544+
545+
return self.partitionBy(numPartitions, partitionFunc).mapPartitions(sortPartition, True)
546+
523547
def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
524548
"""
525549
Sorts this RDD, which is assumed to consist of (key, value) pairs.

python/pyspark/tests.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -545,6 +545,14 @@ def test_histogram(self):
545545
self.assertEquals(([1, "b"], [5]), rdd.histogram(1))
546546
self.assertRaises(TypeError, lambda: rdd.histogram(2))
547547

548+
def test_repartitionAndSortWithinPartition(self):
549+
rdd = self.sc.parallelize([(0, 5), (3, 8), (2, 6), (0, 8), (3, 8), (1, 3)], 2)
550+
551+
repartitioned = rdd.repartitionAndSortWithinPartition(True, 2, lambda key: key % 2)
552+
partitions = repartitioned.glom().collect()
553+
self.assertEquals(partitions[0], [(0, 5), (0, 8), (2, 6)])
554+
self.assertEquals(partitions[1], [(1, 3), (3, 8), (3, 8)])
555+
548556

549557
class TestSQL(PySparkTestCase):
550558

0 commit comments

Comments
 (0)