 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
-    get_used_memory, ExternalSorter
+    get_used_memory, ExternalSorter, ExternalGroupBy

 from py4j.java_collections import ListConverter, MapConverter

@@ -201,71 +201,6 @@ def _replaceRoot(self, value):
         self._sink(1)


-class SameKey(object):
-    """
-    take the first few items which has the same expected key
-
-    This is used by GroupByKey.
-    """
-    def __init__(self, key, values, it, groupBy):
-        self.key = key
-        self.values = values
-        self.it = it
-        self.groupBy = groupBy
-        self._index = 0
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        if self._index >= len(self.values):
-            if self.it is None:
-                raise StopIteration
-
-            key, values = self.it.next()
-            if key != self.key:
-                self.groupBy._next_item = (key, values)
-                raise StopIteration
-            self.values = values
-            self._index = 0
-
-        self._index += 1
-        return self.values[self._index - 1]
-
-
-class GroupByKey(object):
-    """
-    group a sorted iterator into [(k1, it1), (k2, it2), ...]
-    """
-    def __init__(self, it):
-        self.it = iter(it)
-        self._next_item = None
-        self.current = None
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        if self._next_item is None:
-            while True:
-                key, values = self.it.next()
-                if self.current is None:
-                    break
-                if key != self.current.key:
-                    break
-                self.current.values.extend(values)
-
-        else:
-            key, values = self._next_item
-            self._next_item = None
-
-        if self.current is not None:
-            self.current.it = None
-        self.current = SameKey(key, values, self.it, self)
-
-        return key, self.current
-
-
 def _parse_memory(s):
     """
     Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
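
The SameKey/GroupByKey pair deleted above streamed over a key-sorted iterator
of (key, values) chunks, handing out one key's values at a time. A rough
stand-in for that behaviour, sketched with itertools.groupby rather than the
Spark implementation, assuming the input is already sorted by key:

    from itertools import groupby
    from operator import itemgetter

    def group_sorted(chunks):
        # chunks: key-sorted iterable of (key, list_of_values) pairs.
        # Yields (key, iterator over all values for that key) without
        # materializing more than one key's run of chunks at a time.
        for key, run in groupby(chunks, key=itemgetter(0)):
            yield key, (v for _, values in run for v in values)

    # As with the removed SameKey, each group must be consumed before
    # advancing to the next key, since all groups share one underlying
    # iterator.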
@@ -1561,9 +1496,6 @@ def createZero():
     def _can_spill(self):
         return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"

-    def _sort_based(self):
-        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"
-
     def _memory_limit(self):
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))

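
For reference, _memory_limit feeds the "512m"-style setting through
_parse_memory (defined in the previous hunk), which converts a JVM-style size
string into a number of megabytes. A minimal sketch consistent with its
docstring; the function body lies outside the hunk context, so the details
here are assumptions:

    def _parse_memory(s):
        # "1g" -> 1024, "200m" -> 200; result is a count of megabytes
        units = {'g': 1024, 'm': 1, 't': 1 << 20, 'k': 1.0 / 1024}
        if s[-1].lower() not in units:
            raise ValueError("invalid format: " + s)
        return int(float(s[:-1]) * units[s[-1].lower()])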
@@ -1577,14 +1509,6 @@ def groupByKey(self, numPartitions=None):
         sum or average) over each key, using reduceByKey will provide much
         better performance.

-        By default, it will use hash based aggregation, it can spill the items
-        into disks when the memory can not hold all the items, but it still
-        need to hold all the values for single key in memory.
-
-        When spark.shuffle.sort is True, it will switch to sort based approach,
-        then it can support single key with large number of values under small
-        amount of memory. But it is slower than hash based approach.
-
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
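
The docstring's advice is worth spelling out: for per-key aggregation,
reduceByKey pre-combines values on the map side instead of shipping every
value across the shuffle. A small comparison, assuming a live SparkContext sc
as in the doctest:

    pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])

    # groupByKey ships all values for each key to one task, then we reduce:
    sums1 = pairs.groupByKey().mapValues(sum).collectAsMap()

    # reduceByKey merges values map-side, so far less data is shuffled:
    sums2 = pairs.reduceByKey(lambda a, b: a + b).collectAsMap()

    assert sums1 == sums2 == {"a": 2, "b": 1}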
@@ -1601,42 +1525,26 @@ def mergeCombiners(a, b):
             return a

         spill = self._can_spill()
-        sort_based = self._sort_based()
-        if sort_based and not spill:
-            raise ValueError("can not use sort based group when"
-                             " spark.executor.spill is false")
         memory = self._memory_limit()
         serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

-        def combineLocally(iterator):
+        def combine(iterator):
             merger = ExternalMerger(agg, memory * 0.9, serializer) \
                 if spill else InMemoryMerger(agg)
             merger.mergeValues(iterator)
             return merger.iteritems()

-        # combine them before shuffle could reduce the comparison later
-        locally_combined = self.mapPartitions(combineLocally)
+        locally_combined = self.mapPartitions(combine)
         shuffled = locally_combined.partitionBy(numPartitions)

         def groupByKey(it):
-            if sort_based:
-                # Flatten the combined values, so it will not consume huge
-                # memory during merging sort.
-                ser = FlattedValuesSerializer(
-                    BatchedSerializer(PickleSerializer(), 1024), 10)
-                sorter = ExternalSorter(memory * 0.9, ser)
-                it = sorter.sorted(it, key=operator.itemgetter(0))
-                return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it))
-
-            else:
-                # this is faster than sort based
-                merger = ExternalMerger(agg, memory * 0.9, serializer) \
-                    if spill else InMemoryMerger(agg)
-                merger.mergeCombiners(it)
-                return merger.iteritems()
+            merger = ExternalGroupBy(agg, memory, serializer) \
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(it)
+            return merger.iteritems()

-        return shuffled.mapPartitions(groupByKey)
+        return shuffled.mapPartitions(groupByKey).mapValues(ResultIterable)

     # TODO: add tests
     def flatMapValues(self, f):
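
Taken together, the new groupByKey keeps the classic combine-locally /
partitionBy / merge shape, with ExternalGroupBy (spilling) or InMemoryMerger
as the reduce-side merger and ResultIterable wrapping each grouped list. The
Aggregator passed to both sides is built from three functions; their bodies
are mostly outside the hunk context, so this reconstruction is an assumption
based on the visible "return a":

    def createCombiner(v):
        # the first value seen for a key starts a list
        return [v]

    def mergeValue(xs, v):
        # fold another value into an existing list (map side)
        xs.append(v)
        return xs

    def mergeCombiners(a, b):
        # merge two partial lists for the same key (reduce side)
        a.extend(b)
        return a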