Skip to content

Commit 0a7ef63

Browse files
daviespwendell
authored andcommitted
[SPARK-3141] [PySpark] fix sortByKey() with take()
Fix sortByKey() with take() The function `f` used in mapPartitions should always return an iterator. Author: Davies Liu <[email protected]> Closes #2045 from davies/fix_sortbykey and squashes the following commits: 1160f59 [Davies Liu] fix sortByKey() with take()
1 parent 8a74e4b commit 0a7ef63

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

python/pyspark/rdd.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
575575
# noqa
576576
577577
>>> tmp = [('a', 1), ('b', 2), ('1', 3), ('d', 4), ('2', 5)]
578+
>>> sc.parallelize(tmp).sortByKey().first()
579+
('1', 3)
578580
>>> sc.parallelize(tmp).sortByKey(True, 1).collect()
579581
[('1', 3), ('2', 5), ('a', 1), ('b', 2), ('d', 4)]
580582
>>> sc.parallelize(tmp).sortByKey(True, 2).collect()
@@ -587,14 +589,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
587589
if numPartitions is None:
588590
numPartitions = self._defaultReducePartitions()
589591

592+
def sortPartition(iterator):
593+
return iter(sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=not ascending))
594+
590595
if numPartitions == 1:
591596
if self.getNumPartitions() > 1:
592597
self = self.coalesce(1)
593-
594-
def sort(iterator):
595-
return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
596-
597-
return self.mapPartitions(sort)
598+
return self.mapPartitions(sortPartition)
598599

599600
# first compute the boundary of each part via sampling: we want to partition
600601
# the key-space into bins such that the bins have roughly the same
@@ -610,17 +611,14 @@ def sort(iterator):
610611
bounds = [samples[len(samples) * (i + 1) / numPartitions]
611612
for i in range(0, numPartitions - 1)]
612613

613-
def rangePartitionFunc(k):
614+
def rangePartitioner(k):
614615
p = bisect.bisect_left(bounds, keyfunc(k))
615616
if ascending:
616617
return p
617618
else:
618619
return numPartitions - 1 - p
619620

620-
def mapFunc(iterator):
621-
return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))
622-
623-
return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
621+
return self.partitionBy(numPartitions, rangePartitioner).mapPartitions(sortPartition, True)
624622

625623
def sortBy(self, keyfunc, ascending=True, numPartitions=None):
626624
"""

0 commit comments

Comments
 (0)