Commit efa23df

refactor, add spark.shuffle.sort=False

1 parent 250be4e  commit efa23df

1 file changed: python/pyspark/rdd.py
Lines changed: 34 additions & 17 deletions
@@ -652,14 +652,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()
 
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer
 
         def sortPartition(iterator):
-            if spill:
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
-            return sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
+            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+            return sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
 
         if numPartitions == 1:
             if self.getNumPartitions() > 1:
@@ -1505,10 +1504,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
             numPartitions = self._defaultReducePartitions()
 
         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1562,7 +1559,10 @@ def createZero():
         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
 
     def _can_spill(self):
-        return (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _sort_based(self):
+        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"
 
     def _memory_limit(self):
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
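
The three helpers above only read string settings from the SparkConf. As a rough, self-contained sketch (plain Python, not PySpark itself; the dict stands in for self.ctx._conf, and parse_memory only approximates pyspark.rdd._parse_memory), the two flags and the worker memory limit resolve like this:

    # Self-contained sketch: how the boolean flags and the memory limit are interpreted.
    conf = {
        "spark.shuffle.spill": "true",          # default "True": allow spilling to disk
        "spark.shuffle.sort": "false",          # new flag, default "False": hash-based path
        "spark.python.worker.memory": "512m",   # soft memory limit per Python worker
    }

    def parse_memory(s):
        # Approximation of pyspark.rdd._parse_memory: "512m" -> 512 (MB), "2g" -> 2048.
        units = {"k": 1.0 / 1024, "m": 1, "g": 1024, "t": 1 << 20}
        if s[-1].lower() not in units:
            raise ValueError("invalid format: " + s)
        return int(float(s[:-1]) * units[s[-1].lower()])

    can_spill = conf.get("spark.shuffle.spill", "True").lower() == "true"
    sort_based = conf.get("spark.shuffle.sort", "False").lower() == "true"
    memory_limit = parse_memory(conf.get("spark.python.worker.memory", "512m"))
    print(can_spill, sort_based, memory_limit)  # True False 512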
@@ -1577,6 +1577,14 @@ def groupByKey(self, numPartitions=None):
         sum or average) over each key, using reduceByKey will provide much
         better performance.
 
+        By default it uses hash-based aggregation, which can spill items to
+        disk when memory cannot hold all of them, but it still has to keep
+        all the values for a single key in memory.
+
+        When spark.shuffle.sort is True, it switches to a sort-based
+        approach, which can handle a single key with many values in a small
+        amount of memory, but is slower than the hash-based approach.
+
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
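
As the new docstring describes, the sort-based path is opt-in via spark.shuffle.sort. A hedged usage sketch (assuming a PySpark build that contains this commit; the flag names come from the diff, everything else is standard PySpark API):

    from pyspark import SparkConf, SparkContext

    conf = (SparkConf()
            .setMaster("local[2]")
            .setAppName("sort-based-groupByKey")
            .set("spark.shuffle.spill", "true")    # sort-based grouping needs spilling enabled
            .set("spark.shuffle.sort", "true"))    # opt in to the sort-based approach
    sc = SparkContext(conf=conf)

    rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
    print(sorted((k, list(v)) for k, v in rdd.groupByKey().collect()))
    # [('a', [1, 1]), ('b', [1])]
    sc.stop()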
@@ -1592,9 +1600,13 @@ def mergeCombiners(a, b):
             a.extend(b)
             return a
 
-        serializer = self._jrdd_deserializer
         spill = self._can_spill()
+        sort_based = self._sort_based()
+        if sort_based and not spill:
+            raise ValueError("cannot use sort based groupByKey when"
+                             " spark.shuffle.spill is false")
         memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
 
         def combineLocally(iterator):
@@ -1608,16 +1620,21 @@ def combineLocally(iterator):
         shuffled = locally_combined.partitionBy(numPartitions)
 
         def groupByKey(it):
-            if spill:
+            if sort_based:
                 # Flatten the combined values, so it will not consume huge
                 # memory during merging sort.
-                serializer = FlattedValuesSerializer(
+                ser = FlattedValuesSerializer(
                     BatchedSerializer(PickleSerializer(), 1024), 10)
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
+                sorter = ExternalSorter(memory * 0.9, ser)
+                it = sorter.sorted(it, key=operator.itemgetter(0))
+                return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it))
 
-                it = sorted(it, key=operator.itemgetter(0))
-                for k, v in GroupByKey(it):
-                    yield k, ResultIterable(v)
+            else:
+                # this is faster than sort based
+                merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                    if spill else InMemoryMerger(agg)
+                merger.mergeCombiners(it)
+                return merger.iteritems()
 
         return shuffled.mapPartitions(groupByKey)
 
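To illustrate the trade-off noted in the "this is faster than sort based" comment, here is a minimal pure-Python sketch (toy stand-ins, not the ExternalMerger/ExternalSorter classes) of the two grouping strategies the new groupByKey body chooses between:

    import operator
    from itertools import groupby

    pairs = [("a", [1]), ("b", [2]), ("a", [3]), ("a", [4])]   # (key, combined values)

    def hash_group(items):
        # Hash-based: one pass, but every value for a key is held in memory together.
        merged = {}
        for k, vs in items:
            merged.setdefault(k, []).extend(vs)
        return sorted(merged.items())

    def sort_group(items):
        # Sort-based: sort by key (ExternalSorter would spill here), then group neighbours.
        items = sorted(items, key=operator.itemgetter(0))
        return [(k, [v for _, vs in grp for v in vs])
                for k, grp in groupby(items, key=operator.itemgetter(0))]

    print(hash_group(pairs))  # [('a', [1, 3, 4]), ('b', [2])]
    print(sort_group(pairs))  # [('a', [1, 3, 4]), ('b', [2])]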