@@ -247,7 +247,7 @@ def mergeValues(self, iterator):
247247 """ Combine the items by creator and combiner """
248248 # speedup attribute lookup
249249 creator , comb = self .agg .createCombiner , self .agg .mergeValue
250- c , data , pdata , hfun , batch = 0 , self .data , self .pdata , self ._partition , 100
250+ c , data , pdata , hfun , batch = 0 , self .data , self .pdata , self ._partition , self . batch
251251 limit = self .memory_limit
252252
253253 for k , v in iterator :
@@ -259,36 +259,40 @@ def mergeValues(self, iterator):
             if get_used_memory() >= limit:
                 self._spill()
                 limit = self._next_limit()
+                batch /= 2
+                c = 0
             else:
-                batch = min(batch * 2, self.batch)
-                c = 0
+                batch *= 1.5
 
     def _partition(self, key):
         """ Return the partition for key """
         return hash((key, self._seed)) % self.partitions
 
+    def _object_size(self, obj):
+        return 1
+
     def mergeCombiners(self, iterator, limit=None):
         """ Merge (K,V) pair by mergeCombiner """
         if limit is None:
             limit = self.memory_limit
         # speedup attribute lookup
-        comb, hfun = self.agg.mergeCombiners, self._partition
-        c, data, pdata, batch = 0, self.data, self.pdata, 1
+        comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+        c, data, pdata, batch = 0, self.data, self.pdata, self.batch
         for k, v in iterator:
             d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else v
             if not limit:
                 continue
 
-            c += 1
+            c += objsize(v)
             if c > batch:
                 if get_used_memory() > limit:
                     self._spill()
                     limit = self._next_limit()
-                    batch /= 4
+                    batch /= 2
+                    c = 0
                 else:
-                    batch = min(batch * 2, self.batch)
-                    c = 0
+                    batch *= 1.5
 
     def _spill(self):
         """
@@ -476,18 +480,42 @@ class SameKey(object):
     >>> l = zip(range(2), range(2))
     >>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
     [1, 0]
+    >>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
+    >>> for i in range(2000):
+    ...     s.append(i)
+    >>> len(list(s))
+    2002
     """
     def __init__(self, key, values, it, groupBy):
         self.key = key
         self.values = values
         self.it = it
         self.groupBy = groupBy
-        self._index = 0
+        self._file = None
+        self._ser = None
+        self._index = None
 
     def __iter__(self):
         return self
 
     def next(self):
+        if self._index is None:
+            # beginning of iteration: flush buffered values and rewind
+            if self._file is not None:
+                if self.values:
+                    self._spill()
+                self._file.flush()
+                self._file.seek(0)
+            self._index = 0
+
+        if self._index >= len(self.values) and self._file is not None:
+            try:
+                self.values = next(self._ser.load_stream(self._file))
+                self._index = 0
+            except StopIteration:
+                self._file.close()
+                self._file = None
+
         if self._index < len(self.values):
             value = self.values[self._index]
             self._index += 1
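On the first call to `next()`, any still-buffered values are spilled, the file is flushed and rewound, and from then on batches are reloaded one serialized frame at a time as the in-memory list runs out. A minimal sketch of that read-back pattern, with plain `pickle` standing in for `CompressedSerializer(PickleSerializer())`:

```python
import pickle

def read_batches(f):
    # Assumes each spilled batch was written as its own frame; plain
    # pickle is a stand-in for Spark's compressed serializer.
    f.flush()
    f.seek(0)
    while True:
        try:
            yield pickle.load(f)      # one load == one spilled batch
        except EOFError:
            return
```

Each load advances the file position by exactly one frame, which is what lets `next()` pull back one batch at a time instead of holding all of the key's values in memory.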
@@ -503,6 +531,29 @@ def next(self):
                 raise StopIteration
         return value
 
+    def append(self, value):
+        if self._index is not None:
+            raise ValueError("Cannot append value while iterating")
+
+        self.values.append(value)
+        # dump the buffered values to disk if the key is huge
+        if len(self.values) >= 10240:
+            self._spill()
+
+    def _spill(self):
+        if self._file is None:
+            dirs = _get_local_dirs("objects")
+            d = dirs[id(self) % len(dirs)]
+            if not os.path.exists(d):
+                os.makedirs(d)
+            p = os.path.join(d, str(id(self)))
+            self._file = open(p, "w+b")
+            self._ser = CompressedSerializer(PickleSerializer())
+
+        self._ser.dump_stream([self.values], self._file)
+        self.values = []
+        gc.collect()
+
 
 class GroupByKey(object):
     """
@@ -528,16 +579,12 @@ def next(self):
                 key, value = self.it.next()
                 if self.current is None or key != self.current.key:
                     break
-                self.current.values.append(value)
-
+                self.current.append(value)
         else:
             key, value = self._next_item
             self._next_item = None
 
-        if self.current is not None:
-            self.current.it = None
         self.current = SameKey(key, [value], self.it, self)
-
         return key, (v for vs in self.current for v in vs)
 
 
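Routing values through `SameKey.append` means a single huge key now spills to disk instead of growing an in-memory list without bound, while the external contract stays the same: input clustered by key, one lazily evaluated group per key. A minimal analogue of that contract using `itertools.groupby`, assuming each streamed value is a small list as `FlattedValuesSerializer` produces:

```python
from itertools import groupby

# (key, list-of-values) pairs, already clustered by key
pairs = [("a", [1]), ("a", [2, 3]), ("b", [4])]

grouped = [(k, [v for _, vs in grp for v in vs])   # flatten the lists
           for k, grp in groupby(pairs, key=lambda kv: kv[0])]
assert grouped == [("a", [1, 2, 3]), ("b", [4])]
```

This also explains the double loop in the returned generator: when the stream carries value lists, iterating `self.current` yields lists, and `(v for vs in self.current for v in vs)` flattens them back into individual values.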
@@ -557,6 +604,9 @@ def _flatted_serializer(self):
         ser = FlattedValuesSerializer(ser, 20)
         return ser
 
+    def _object_size(self, obj):
+        return len(obj)
+
     def _spill(self):
         """
         dump already partitioned data into disks.
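The grouping merger overrides `_object_size` because here every dict value is a list of grouped values, so the base class's one-unit-per-object count would badly understate memory growth between checks. For example (illustrative numbers only):

```python
# Two "objects" with very different memory footprints:
values = [[0] * 100, [0] * 50000]
by_object = len(values)                    # 2: what counting objects sees
by_elements = sum(len(v) for v in values)  # 50100: what len-based sizing sees
```

Counting elements makes the `c > batch` check in `mergeCombiners` fire roughly in proportion to actual memory use.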
@@ -615,8 +665,9 @@ def _merged_items(self, index, limit=0):
         size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
                    for j in range(self.spills))
         # if the memory can not hold all the partition,
-        # then use sort based merge
-        if (size >> 20) > self.memory_limit / 2:
+        # then use sort-based merge. Because of compression, the data
+        # on disk is much smaller than the memory needed to hold it
+        if (size >> 20) > self.memory_limit / 10:
             return self._sorted_items(index)
 
         self.data = {}
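Spilled partitions are written with a compressed serializer, so their on-disk size understates the memory needed to load them back; lowering the hash-merge threshold from 1/2 to 1/10 of the memory limit leaves headroom for several-fold expansion on load. Illustrative arithmetic with hypothetical numbers:

```python
memory_limit = 512                        # MB, hypothetical budget
size = 80 * 1024 * 1024                   # 80 MB of compressed spill files
use_sort_merge = (size >> 20) > memory_limit / 10   # 80 > 51.2 -> True
# The old threshold (memory_limit / 2 = 256) would have kept the hash
# merge here, risking OOM if 80 MB decompresses to several hundred MB.
```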