 try:
     import psutil
 
+    process = None
+
     def get_used_memory():
         """ Return the used memory in MB """
-        process = psutil.Process(os.getpid())
+        global process
+        if process is None or process._pid != os.getpid():
+            process = psutil.Process(os.getpid())
         if hasattr(process, "memory_info"):
             info = process.memory_info()
         else:
             info = process.get_memory_info()
         return info.rss >> 20
+
 except ImportError:
 
     def get_used_memory():
@@ -49,6 +54,7 @@ def get_used_memory():
             for line in open('/proc/self/status'):
                 if line.startswith('VmRSS:'):
                     return int(line.split()[1]) >> 10
+
         else:
             warnings.warn("Please install psutil to have better "
                           "support with spilling")
@@ -57,6 +63,7 @@ def get_used_memory():
                 rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                 return rss >> 20
             # TODO: support windows
+
         return 0
@@ -146,7 +153,7 @@ def mergeCombiners(self, iterator):
             d[k] = comb(d[k], v) if k in d else v
 
     def iteritems(self):
-        """ Return the merged items ad iterator """
+        """ Return the merged items as iterator """
         return self.data.iteritems()
@@ -210,18 +217,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
                  localdirs=None, scale=1, partitions=59, batch=1000):
         Merger.__init__(self, aggregator)
         self.memory_limit = memory_limit
-        # default serializer is only used for tests
-        self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        # add compression
-        if isinstance(self.serializer, BatchedSerializer):
-            if not isinstance(self.serializer.serializer, CompressedSerializer):
-                self.serializer = BatchedSerializer(
-                    CompressedSerializer(self.serializer.serializer),
-                    self.serializer.batchSize)
-        else:
-            if not isinstance(self.serializer, CompressedSerializer):
-                self.serializer = CompressedSerializer(self.serializer)
-
+        self.serializer = self._compressed_serializer(serializer)
         self.localdirs = localdirs or _get_local_dirs(str(id(self)))
         # number of partitions when spill data into disks
         self.partitions = partitions
@@ -238,6 +234,18 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
         # randomize the hash of key, id(o) is the address of o (aligned by 8)
         self._seed = id(self) + 7
 
+    def _compressed_serializer(self, serializer=None):
+        # default serializer is only used for tests
+        ser = serializer or PickleSerializer()
+        # add compression
+        if isinstance(ser, BatchedSerializer):
+            if not isinstance(ser.serializer, CompressedSerializer):
+                ser = BatchedSerializer(CompressedSerializer(ser.serializer), ser.batchSize)
+        else:
+            if not isinstance(ser, CompressedSerializer):
+                ser = BatchedSerializer(CompressedSerializer(ser), 1024)
+        return ser
+
     def _get_spill_dir(self, n):
         """ Choose one directory for spill by number n """
         return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
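The extracted `_compressed_serializer` guarantees the on-disk stream is always batched and compressed, and it wraps only once, so an already-wrapped serializer passes through unchanged. A standalone sketch of that wrap-once idiom with hypothetical stand-in classes (the real `BatchedSerializer`/`CompressedSerializer` live in `pyspark.serializers`):

    class Compressed(object):
        """ hypothetical stand-in for CompressedSerializer """
        def __init__(self, serializer):
            self.serializer = serializer

    class Batched(object):
        """ hypothetical stand-in for BatchedSerializer """
        def __init__(self, serializer, batch_size=1024):
            self.serializer = serializer
            self.batch_size = batch_size

    def compressed(ser):
        # wrap only once: re-applying this function is a no-op
        if isinstance(ser, Batched):
            if not isinstance(ser.serializer, Compressed):
                return Batched(Compressed(ser.serializer), ser.batch_size)
            return ser
        if not isinstance(ser, Compressed):
            return Batched(Compressed(ser))
        return ser

Because the function checks before wrapping, calling it on its own output changes nothing, which is what makes it safe to apply in `__init__` regardless of what serializer the caller passed.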
@@ -276,6 +284,9 @@ def _partition(self, key):
         return hash((key, self._seed)) % self.partitions
 
     def _object_size(self, obj):
+        """ How much memory this obj takes; assume all objects
+        consume a similar amount of memory
+        """
         return 1
 
     def mergeCombiners(self, iterator, limit=None):
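`_partition` hashes `(key, self._seed)` rather than the bare key; per the `_seed` comment above, the per-instance seed randomizes bucketing, so different merger instances split the same keys differently (useful when a spilled partition must be re-partitioned again). A tiny illustration of the idea (names are hypothetical):

    def make_partitioner(partitions, seed):
        # mixing a seed into the hash decorrelates bucketing across instances
        return lambda key: hash((key, seed)) % partitions

    p1 = make_partitioner(59, seed=7)
    p2 = make_partitioner(59, seed=10007)
    keys = ["a", "b", "c", "d"]
    print([p1(k) for k in keys])  # bucket choices under seed 7
    print([p2(k) for k in keys])  # usually a different spread under another seed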
@@ -485,18 +496,18 @@ class SameKey(object):
     This is used by GroupByKey.
 
     >>> l = zip(range(2), range(2))
-    >>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
+    >>> list(SameKey(0, 1, iter(l), GroupByKey(iter([]))))
     [1, 0]
-    >>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
+    >>> s = SameKey(0, 1, iter(l), GroupByKey(iter([])))
     >>> for i in range(2000):
     ...     s.append(i)
     >>> len(list(s))
     2002
     """
-    def __init__(self, key, values, it, groupBy):
+    def __init__(self, key, value, iterator, groupBy):
         self.key = key
-        self.values = values
-        self.it = it
+        self.values = [value]
+        self.iterator = iterator
         self.groupBy = groupBy
         self._file = None
         self._ser = None
@@ -516,27 +527,22 @@ def next(self):
             self._index = 0
 
         if self._index >= len(self.values) and self._file is not None:
-            try:
-                self.values = next(self._ser.load_stream(self._file))
-                self._index = 0
-            except StopIteration:
-                self._file.close()
-                self._file = None
+            # load next chunk of values from disk
+            self.values = next(self._ser.load_stream(self._file))
+            self._index = 0
 
         if self._index < len(self.values):
             value = self.values[self._index]
             self._index += 1
             return value
 
-        if self.it is None:
-            raise StopIteration
+        key, value = next(self.iterator)
+        if key == self.key:
+            return value
 
-        key, value = self.it.next()
-        if key != self.key:
-            self.groupBy._next_item = (key, value)
-            self.it = None
-            raise StopIteration
-        return value
+        # push them back into groupBy
+        self.groupBy.next_item = (key, value)
+        raise StopIteration
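The rewritten `next` above assumes groups are consumed in order: when the shared iterator yields a different key, the pair is parked in `groupBy.next_item` for the next group. A self-contained model of that pushback idiom over a key-sorted stream (simplified: no spilling, and it assumes each group is fully drained before the next is requested):

    def grouped(sorted_pairs):
        """ yield (key, values-iterator) groups from a key-sorted stream """
        it = iter(sorted_pairs)
        pushed = []  # one-slot pushback buffer, like GroupByKey.next_item

        def values(key, first):
            yield first
            for k, v in it:
                if k != key:
                    pushed.append((k, v))  # first pair of the next group
                    return
                yield v

        head = next(it, None)
        while head is not None:
            k, v = head
            yield k, values(k, v)
            head = pushed.pop() if pushed else None

    print([(k, list(vs)) for k, vs in grouped([(0, 'a'), (0, 'b'), (1, 'c')])])
    # [(0, ['a', 'b']), (1, ['c'])]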
     def append(self, value):
         if self._index is not None:
@@ -548,13 +554,14 @@ def append(self, value):
             self._spill()
 
     def _spill(self):
+        """ dump the values to disk """
         if self._file is None:
             dirs = _get_local_dirs("objects")
             d = dirs[id(self) % len(dirs)]
             if not os.path.exists(d):
                 os.makedirs(d)
             p = os.path.join(d, str(id))
-            self._file = open(p, "w+")
+            self._file = open(p, "w+", 65536)
             self._ser = CompressedSerializer(PickleSerializer())
 
         self._ser.dump_stream([self.values], self._file)
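`_spill` appends the whole `self.values` buffer as one chunk per `dump_stream` call, and `next` later replays the file chunk by chunk; the added `65536` argument just gives the file a 64KB write buffer. A minimal sketch of the chunked spill-and-replay idea, using plain pickle in place of `CompressedSerializer(PickleSerializer())`:

    import os
    import pickle
    import tempfile

    def spill_chunk(path, chunk):
        # each list of values becomes one pickle frame appended to the file
        with open(path, "ab", 65536) as f:
            pickle.dump(chunk, f, 2)

    def replay(path):
        # yield values chunk by chunk, never holding the whole file in memory
        with open(path, "rb") as f:
            while True:
                try:
                    chunk = pickle.load(f)
                except EOFError:
                    return
                for value in chunk:
                    yield value

    path = os.path.join(tempfile.mkdtemp(), "values")
    spill_chunk(path, [1, 2, 3])
    spill_chunk(path, [4, 5])
    print(list(replay(path)))  # [1, 2, 3, 4, 5]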
@@ -567,32 +574,34 @@ class GroupByKey(object):
     group a sorted iterator into [(k1, it1), (k2, it2), ...]
 
     >>> k = [i/3 for i in range(6)]
-    >>> v = [[i] for i in range(6)]
+    >>> v = [i for i in range(6)]
     >>> g = GroupByKey(iter(zip(k, v)))
     >>> [(k, list(it)) for k, it in g]
     [(0, [0, 1, 2]), (1, [3, 4, 5])]
     """
-    def __init__(self, it):
-        self.it = iter(it)
-        self._next_item = None
+    def __init__(self, iterator):
+        self.iterator = iterator
+        self.next_item = None
         self.current = None
 
     def __iter__(self):
         return self
 
     def next(self):
-        if self._next_item is None:
+        if self.next_item is None:
             while True:
-                key, value = self.it.next()
+                key, value = next(self.iterator)
                 if self.current is None or key != self.current.key:
                     break
+                # the current key has not been fully consumed yet
                 self.current.append(value)
         else:
-            key, value = self._next_item
-            self._next_item = None
+            # next key was popped while visiting current key
+            key, value = self.next_item
+            self.next_item = None
 
-        self.current = SameKey(key, [value], self.it, self)
-        return key, (v for vs in self.current for v in vs)
+        self.current = SameKey(key, value, self.iterator, self)
+        return key, self.current
 
 
 class ExternalGroupBy(ExternalMerger):
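Returning `self.current` (the `SameKey` object) instead of a nested flattening generator keeps each group a real iterator that the grouping machinery can also `append` to and spill. For fully in-memory, key-sorted input the observable behaviour has the same shape `itertools.groupby` produces:

    import itertools

    pairs = [(0, 0), (0, 1), (0, 2), (1, 3), (1, 4), (1, 5)]
    grouped = [(k, [v for _, v in kvs])
               for k, kvs in itertools.groupby(pairs, key=lambda kv: kv[0])]
    print(grouped)  # [(0, [0, 1, 2]), (1, [3, 4, 5])]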
@@ -624,7 +633,7 @@ def _spill(self):
 
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
-            # dataset once, write them into different files, has no
+            # data once, write them into different files, has no
             # additional memory. It only called when the memory goes
             # above limit at the first time.
 
@@ -636,12 +645,10 @@ def _spill(self):
             # sort them before dumping into disks
             self._sorted = len(self.data) < self.SORT_KEY_LIMIT
             if self._sorted:
-                ser = self._flatted_serializer()
+                self.serializer = self._flatted_serializer()
                 for k in sorted(self.data.keys()):
-                    v = self.data[k]
                     h = self._partition(k)
-                    ser.dump_stream([(k, v)], streams[h])
-                self.serializer = ser
+                    self.serializer.dump_stream([(k, self.data[k])], streams[h])
             else:
                 for k, v in self.data.iteritems():
                     h = self._partition(k)
@@ -651,6 +658,7 @@ def _spill(self):
                 s.close()
 
             self.data.clear()
+            # self.pdata is cached in `mergeValues` and `mergeCombiners`
             self.pdata.extend([{} for i in range(self.partitions)])
 
         else:
@@ -659,8 +667,9 @@ def _spill(self):
             with open(p, "w") as f:
                 # dump items in batch
                 if self._sorted:
-                    self.serializer.dump_stream(
-                        sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)), f)
+                    # sort by key only (stable)
+                    sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+                    self.serializer.dump_stream(sorted_items, f)
                 else:
                     self.serializer.dump_stream(self.pdata[i].iteritems(), f)
             self.pdata[i].clear()
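Writing each spill file pre-sorted by key is what later allows the spilled runs to be stream-merged instead of re-sorted: a k-way merge of sorted runs only ever holds one head item per run in memory. A small sketch with in-memory lists standing in for spill files:

    import heapq

    # three "spill files", each already sorted by key
    runs = [
        [("a", 1), ("c", 3)],
        [("a", 9), ("b", 2)],
        [("b", 7), ("d", 4)],
    ]

    # tuples compare by key first, so no explicit key function is needed here
    print(list(heapq.merge(*runs)))
    # [('a', 1), ('a', 9), ('b', 2), ('b', 7), ('c', 3), ('d', 4)]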
@@ -706,75 +715,7 @@ def load_partition(j):
         sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                      key=operator.itemgetter(0))
 
-        return GroupByKey(sorted_items)
-
-
-class ExternalSorter(object):
-    """
-    ExtenalSorter will divide the elements into chunks, sort them in
-    memory and dump them into disks, finally merge them back.
-
-    The spilling will only happen when the used memory goes above
-    the limit.
-
-    >>> sorter = ExternalSorter(1)  # 1M
-    >>> import random
-    >>> l = range(1024)
-    >>> random.shuffle(l)
-    >>> sorted(l) == list(sorter.sorted(l))
-    True
-    >>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True))
-    True
-    """
-    def __init__(self, memory_limit, serializer=None):
-        self.memory_limit = memory_limit
-        self.local_dirs = _get_local_dirs("sort")
-        self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0
-
-    def _get_path(self, n):
-        """ Choose one directory for spill by number n """
-        d = self.local_dirs[n % len(self.local_dirs)]
-        if not os.path.exists(d):
-            os.makedirs(d)
-        return os.path.join(d, str(n))
-
-    def sorted(self, iterator, key=None, reverse=False):
-        """
-        Sort the elements in iterator, do external sort when the memory
-        goes above the limit.
-        """
-        batch = 10
-        chunks, current_chunk = [], []
-        iterator = iter(iterator)
-        while True:
-            # pick elements in batch
-            chunk = list(itertools.islice(iterator, batch))
-            current_chunk.extend(chunk)
-            if len(chunk) < batch:
-                break
-
-            if get_used_memory() > self.memory_limit:
-                # sort them inplace will save memory
-                current_chunk.sort(key=key, reverse=reverse)
-                path = self._get_path(len(chunks))
-                with open(path, 'w') as f:
-                    self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
-                chunks.append(self.serializer.load_stream(open(path)))
-                current_chunk = []
-
-            elif not chunks:
-                batch = min(batch * 2, 10000)
-
-        current_chunk.sort(key=key, reverse=reverse)
-        if not chunks:
-            return current_chunk
-
-        if current_chunk:
-            chunks.append(iter(current_chunk))
-
-        return heapq.merge(chunks, key=key, reverse=reverse)
+        return ((k, itertools.chain.from_iterable(vs)) for k, vs in GroupByKey(sorted_items))
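One subtlety behind the new return line: because values were buffered and spilled in batches, `GroupByKey` hands back each key's values as a sequence of chunks (lists), so they are flattened lazily with `itertools.chain.from_iterable`. A small illustration of that final flattening step:

    import itertools

    # shape produced by GroupByKey here: key -> iterator of value chunks
    grouped = [(0, iter([[0, 1], [2]])), (1, iter([[3, 4, 5]]))]

    flat = ((k, itertools.chain.from_iterable(vs)) for k, vs in grouped)
    print([(k, list(vs)) for k, vs in flat])
    # [(0, [0, 1, 2]), (1, [3, 4, 5])]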
 
 
 if __name__ == "__main__":