@@ -218,19 +218,19 @@ def __iter__(self):
218218 return self
219219
220220 def next (self ):
221- if self ._index < len (self .values ):
222- self ._index += 1
223- return self . values [ self . _index - 1 ]
221+ if self ._index >= len (self .values ):
222+ if self .it is None :
223+ raise StopIteration
224224
225- if self .it is None :
226- raise StopIteration
225+ key , values = self .it .next ()
226+ if key != self .key :
227+ self .groupBy ._next_item = (key , values )
228+ raise StopIteration
229+ self .values = values
230+ self ._index = 0
227231
228- key , value = self .it .next ()
229- if key == self .key :
230- return value
231-
232- self .groupBy ._next_item = (key , value )
233- raise StopIteration
232+ self ._index += 1
233+ return self .values [self ._index - 1 ]
234234
235235
236236class GroupByKey (object ):
@@ -248,20 +248,20 @@ def __iter__(self):
248248 def next (self ):
249249 if self ._next_item is None :
250250 while True :
251- key , value = self .it .next ()
251+ key , values = self .it .next ()
252252 if self .current is None :
253253 break
254254 if key != self .current .key :
255255 break
256- self .current .values .append ( value )
256+ self .current .values .extend ( values )
257257
258258 else :
259- key , value = self ._next_item
259+ key , values = self ._next_item
260260 self ._next_item = None
261261
262262 if self .current is not None :
263263 self .current .it = None
264- self .current = SameKey (key , [ value ] , self .it , self )
264+ self .current = SameKey (key , values , self .it , self )
265265
266266 return key , self .current
267267
@@ -1581,9 +1581,30 @@ def groupByKey(self, numPartitions=None):
15811581 >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
15821582 [('a', [1, 1]), ('b', [1])]
15831583 """
1584+ def createCombiner (x ):
1585+ return [x ]
1586+
1587+ def mergeValue (xs , x ):
1588+ xs .append (x )
1589+ return xs
1590+
1591+ def mergeCombiners (a , b ):
1592+ a .extend (b )
1593+ return a
1594+
15841595 serializer = self .ctx .serializer
15851596 spill = self ._can_spill ()
15861597 memory = self ._memory_limit ()
1598+ agg = Aggregator (createCombiner , mergeValue , mergeCombiners )
1599+
1600+ def combineLocally (iterator ):
1601+ merger = ExternalMerger (agg , memory * 0.9 , serializer ) \
1602+ if spill else InMemoryMerger (agg )
1603+ merger .mergeValues (iterator )
1604+ return merger .iteritems ()
1605+
1606+ locally_combined = self .mapPartitions (combineLocally )
1607+ shuffled = locally_combined .partitionBy (numPartitions )
15871608
15881609 def groupByKey (it ):
15891610 if spill :
@@ -1592,8 +1613,6 @@ def groupByKey(it):
15921613 for k , v in GroupByKey (it ):
15931614 yield k , ResultIterable (v )
15941615
1595- # TODO: combine before shuffle ?
1596- shuffled = self .partitionBy (numPartitions )
15971616 return shuffled .mapPartitions (groupByKey )
15981617
15991618 # TODO: add tests
0 commit comments