@@ -78,6 +78,11 @@ def _get_local_dirs(sub):
7878 return [os .path .join (d , "python" , str (os .getpid ()), sub ) for d in dirs ]
7979
8080
81+ # global stats
82+ MemoryBytesSpilled = 0L
83+ DiskBytesSpilled = 0L
84+
85+
8186class Aggregator (object ):
8287
8388 """
@@ -318,10 +323,12 @@ def _spill(self):
318323
319324 It will dump the data in batch for better performance.
320325 """
326+ global MemoryBytesSpilled , DiskBytesSpilled
321327 path = self ._get_spill_dir (self .spills )
322328 if not os .path .exists (path ):
323329 os .makedirs (path )
324330
331+ used_memory = get_used_memory ()
325332 if not self .pdata :
326333 # The data has not been partitioned, it will iterator the
327334 # dataset once, write them into different files, has no
@@ -339,6 +346,7 @@ def _spill(self):
339346 self .serializer .dump_stream ([(k , v )], streams [h ])
340347
341348 for s in streams :
349+ DiskBytesSpilled += s .tell ()
342350 s .close ()
343351
344352 self .data .clear ()
@@ -351,9 +359,11 @@ def _spill(self):
351359 # dump items in batch
352360 self .serializer .dump_stream (self .pdata [i ].iteritems (), f )
353361 self .pdata [i ].clear ()
362+ DiskBytesSpilled += os .path .getsize (p )
354363
355364 self .spills += 1
356365 gc .collect () # release the memory as much as possible
366+ MemoryBytesSpilled += (used_memory - get_used_memory ()) << 20
357367
358368 def iteritems (self ):
359369 """ Return all merged items as iterator """
@@ -454,7 +464,6 @@ def __init__(self, memory_limit, serializer=None):
454464 self .local_dirs = _get_local_dirs ("sort" )
455465 self .serializer = serializer or BatchedSerializer (
456466 CompressedSerializer (PickleSerializer ()), 1024 )
457- self ._spilled_bytes = 0
458467
459468 def _get_path (self , n ):
460469 """ Choose one directory for spill by number n """
@@ -468,6 +477,7 @@ def sorted(self, iterator, key=None, reverse=False):
468477 Sort the elements in iterator, do external sort when the memory
469478 goes above the limit.
470479 """
480+ global MemoryBytesSpilled , DiskBytesSpilled
471481 batch = 10
472482 chunks , current_chunk = [], []
473483 iterator = iter (iterator )
@@ -478,17 +488,19 @@ def sorted(self, iterator, key=None, reverse=False):
478488 if len (chunk ) < batch :
479489 break
480490
481- if get_used_memory () > self .memory_limit :
491+ used_memory = get_used_memory ()
492+ if used_memory > self .memory_limit :
482493 # sort them inplace will save memory
483494 current_chunk .sort (key = key , reverse = reverse )
484495 path = self ._get_path (len (chunks ))
485496 with open (path , 'w' ) as f :
486497 self .serializer .dump_stream (current_chunk , f )
487- self ._spilled_bytes += os .path .getsize (path )
488498 chunks .append (self .serializer .load_stream (open (path )))
489499 os .unlink (path ) # data will be deleted after close
490500 current_chunk = []
491501 gc .collect ()
502+ MemoryBytesSpilled += (used_memory - get_used_memory ()) << 20
503+ DiskBytesSpilled += os .path .getsize (path )
492504
493505 elif not chunks :
494506 batch = min (batch * 2 , 10000 )
@@ -569,6 +581,7 @@ def append(self, value):
569581
570582 def _spill (self ):
571583 """ dump the values into disk """
584+ global MemoryBytesSpilled , DiskBytesSpilled
572585 if self ._file is None :
573586 dirs = _get_local_dirs ("objects" )
574587 d = dirs [id (self ) % len (dirs )]
@@ -578,9 +591,13 @@ def _spill(self):
578591 self ._file = open (p , "w+" , 65536 )
579592 self ._ser = CompressedSerializer (PickleSerializer ())
580593
594+ used_memory = get_used_memory ()
595+ pos = self ._file .tell ()
581596 self ._ser .dump_stream ([self .values ], self ._file )
597+ DiskBytesSpilled += self ._file .tell () - pos
582598 self .values = []
583599 gc .collect ()
600+ MemoryBytesSpilled += (used_memory - get_used_memory ()) << 20
584601
585602
586603class GroupByKey (object ):
@@ -641,10 +658,12 @@ def _spill(self):
641658 """
642659 dump already partitioned data into disks.
643660 """
661+ global MemoryBytesSpilled , DiskBytesSpilled
644662 path = self ._get_spill_dir (self .spills )
645663 if not os .path .exists (path ):
646664 os .makedirs (path )
647665
666+ used_memory = get_used_memory ()
648667 if not self .pdata :
649668 # The data has not been partitioned, it will iterator the
650669 # data once, write them into different files, has no
@@ -669,6 +688,7 @@ def _spill(self):
669688 self .serializer .dump_stream ([(k , v )], streams [h ])
670689
671690 for s in streams :
691+ DiskBytesSpilled += s .tell ()
672692 s .close ()
673693
674694 self .data .clear ()
@@ -687,9 +707,11 @@ def _spill(self):
687707 else :
688708 self .serializer .dump_stream (self .pdata [i ].iteritems (), f )
689709 self .pdata [i ].clear ()
710+ DiskBytesSpilled += os .path .getsize (p )
690711
691712 self .spills += 1
692713 gc .collect () # release the memory as much as possible
714+ MemoryBytesSpilled += (used_memory - get_used_memory ()) << 20
693715
694716 def _merged_items (self , index , limit = 0 ):
695717 size = sum (os .path .getsize (os .path .join (self ._get_spill_dir (j ), str (index )))