@@ -652,14 +652,13 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
         if numPartitions is None:
             numPartitions = self._defaultReducePartitions()

-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
-        memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         serializer = self._jrdd_deserializer

         def sortPartition(iterator):
-            if spill:
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
-            return sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))
+            sort = ExternalSorter(memory * 0.9, serializer).sorted if spill else sorted
+            return sort(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))

         if numPartitions == 1:
             if self.getNumPartitions() > 1:
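
The two helpers introduced here read standard PySpark settings: spark.shuffle.spill decides whether sortByKey may spill to disk, and spark.python.worker.memory caps how much memory each Python worker may use before spilling. A usage sketch (assuming a fresh SparkContext; the keys and values are made up):

```python
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.shuffle.spill", "true")            # allow spilling to disk
        .set("spark.python.worker.memory", "512m"))    # per-worker memory budget
sc = SparkContext(conf=conf)

rdd = sc.parallelize([("b", 2), ("a", 1), ("c", 3)])
print(rdd.sortByKey().collect())
# [('a', 1), ('b', 2), ('c', 3)]
```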
@@ -1505,10 +1504,8 @@ def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
             numPartitions = self._defaultReducePartitions()

         serializer = self.ctx.serializer
-        spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower()
-                 == 'true')
-        memory = _parse_memory(self.ctx._conf.get(
-            "spark.python.worker.memory", "512m"))
+        spill = self._can_spill()
+        memory = self._memory_limit()
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

         def combineLocally(iterator):
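
For context, the Aggregator built above wraps the three user functions passed to combineByKey. A usage sketch (assuming the SparkContext sc from the previous example) that builds per-key (sum, count) pairs:

```python
pairs = sc.parallelize([("a", 1), ("a", 3), ("b", 2)])

sum_counts = pairs.combineByKey(
    lambda v: (v, 1),                               # createCombiner
    lambda c, v: (c[0] + v, c[1] + 1),              # mergeValue
    lambda c1, c2: (c1[0] + c2[0], c1[1] + c2[1]))  # mergeCombiners

print(sorted(sum_counts.collect()))
# [('a', (4, 2)), ('b', (2, 1))]
```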
@@ -1562,7 +1559,10 @@ def createZero():
         return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)

     def _can_spill(self):
-        return (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
+        return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"
+
+    def _sort_based(self):
+        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"

     def _memory_limit(self):
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
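
_memory_limit() returns the worker memory budget in megabytes via _parse_memory. A rough, simplified sketch of that conversion (not the exact implementation in rdd.py) to show the unit that ExternalSorter and ExternalMerger receive:

```python
def parse_memory(s):
    # "512m" -> 512, "2g" -> 2048; values are megabytes
    units = {"k": 1.0 / 1024, "m": 1, "g": 1024, "t": 1 << 20}
    suffix = s[-1].lower()
    if suffix not in units:
        raise ValueError("invalid memory string: " + s)
    return int(float(s[:-1]) * units[suffix])

print(parse_memory("512m"))  # 512
print(parse_memory("2g"))    # 2048
```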
@@ -1577,6 +1577,14 @@ def groupByKey(self, numPartitions=None):
         sum or average) over each key, using reduceByKey will provide much
         better performance.

+        By default, groupByKey uses hash-based aggregation. It can spill items
+        to disk when they do not all fit in memory, but it still needs to hold
+        all of the values for a single key in memory at once.
+
+        When spark.shuffle.sort is true, it switches to a sort-based approach,
+        which can group a single key with a large number of values in a small
+        amount of memory, but is slower than the hash-based approach.
+
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
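
A usage sketch for the behaviour the new docstring describes. spark.shuffle.sort is the flag added by this change (see _sort_based above); enabling it lets a single hot key with many values be grouped without keeping all of those values in memory. The data, sizes, and settings below are illustrative only and assume a fresh application:

```python
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.shuffle.spill", "true")   # required by the sort-based path
        .set("spark.shuffle.sort", "true")    # switch groupByKey to sort-based grouping
        .set("spark.python.worker.memory", "256m"))
sc = SparkContext(conf=conf)

hot = sc.parallelize([("hot", i) for i in xrange(1000000)] + [("cold", 1)])
sizes = hot.groupByKey().mapValues(lambda vs: sum(1 for _ in vs))
print(sorted(sizes.collect()))
# [('cold', 1), ('hot', 1000000)]
```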
@@ -1592,9 +1600,13 @@ def mergeCombiners(a, b):
             a.extend(b)
             return a

-        serializer = self._jrdd_deserializer
         spill = self._can_spill()
+        sort_based = self._sort_based()
+        if sort_based and not spill:
+            raise ValueError("cannot use sort-based groupByKey when"
+                             " spark.shuffle.spill is false")
         memory = self._memory_limit()
+        serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

         def combineLocally(iterator):
@@ -1608,16 +1620,21 @@ def combineLocally(iterator):
         shuffled = locally_combined.partitionBy(numPartitions)

         def groupByKey(it):
-            if spill:
+            if sort_based:
                 # Flatten the combined values, so it will not consume huge
                 # memory during merging sort.
-                serializer = FlattedValuesSerializer(
+                ser = FlattedValuesSerializer(
                     BatchedSerializer(PickleSerializer(), 1024), 10)
-                sorted = ExternalSorter(memory * 0.9, serializer).sorted
+                sorter = ExternalSorter(memory * 0.9, ser)
+                it = sorter.sorted(it, key=operator.itemgetter(0))
+                return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it))

-            it = sorted(it, key=operator.itemgetter(0))
-            for k, v in GroupByKey(it):
-                yield k, ResultIterable(v)
+            else:
+                # this is faster than the sort-based path
+                merger = ExternalMerger(agg, memory * 0.9, serializer) \
+                    if spill else InMemoryMerger(agg)
+                merger.mergeCombiners(it)
+                return merger.iteritems()

         return shuffled.mapPartitions(groupByKey)

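A conceptual sketch of what the sort-based branch does, using only the standard library (the real code uses ExternalSorter and GroupByKey, which can spill to disk): once the (key, value) stream is sorted by key, grouping is just collecting consecutive runs of equal keys, so only one key's values are materialized at a time.

```python
from itertools import groupby
from operator import itemgetter

def sort_based_group(pairs):
    # a real external sorter would spill sorted chunks to disk and merge them;
    # the built-in sort stands in here just to show the grouping step
    for key, run in groupby(sorted(pairs, key=itemgetter(0)), key=itemgetter(0)):
        yield key, [v for _, v in run]

print(list(sort_based_group([("a", 1), ("b", 2), ("a", 3)])))
# [('a', [1, 3]), ('b', [2])]
```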