@@ -438,11 +438,23 @@ class ExternalSorter(object):
438438
439439 The spilling will only happen when the used memory goes above
440440 the limit.
441+
442+
443+ >>> sorter = ExternalSorter(1) # 1M
444+ >>> import random
445+ >>> l = list(range(1024))
446+ >>> random.shuffle(l)
447+ >>> sorted(l) == list(sorter.sorted(l))
448+ True
449+ >>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True))
450+ True
441451 """
442452 def __init__ (self , memory_limit , serializer = None ):
# Set up the sorter's state.
#   memory_limit: stored unchanged on the instance; the class docstring says
#                 spilling only happens once used memory goes above this limit.
#   serializer:   optional; when falsy, a default batched pickle serializer
#                 (batch size 1024) is built instead.
443453 self .memory_limit = memory_limit
# Directories obtained from _get_local_dirs("sort"); presumably the candidates
# that _get_path() picks a spill location from -- confirm against that helper.
444454 self .local_dirs = _get_local_dirs ("sort" )
# Removed side of the diff: the default serializer pickled batches uncompressed.
445- self .serializer = serializer or BatchedSerializer (PickleSerializer (), 1024 )
# Added side of the diff: the default now wraps PickleSerializer in
# CompressedSerializer. A caller-supplied serializer is still used as given.
455+ self .serializer = serializer or BatchedSerializer (
456+ CompressedSerializer (PickleSerializer ()), 1024 )
# Running byte total of spill files; sorted() adds os.path.getsize(path)
# to it after each chunk is dumped.
457+ self ._spilled_bytes = 0
446458
447459 def _get_path (self , n ):
448460 """ Choose one directory for spill by number n """
@@ -472,6 +484,7 @@ def sorted(self, iterator, key=None, reverse=False):
472484 path = self ._get_path (len (chunks ))
473485 with open (path , 'w' ) as f :
474486 self .serializer .dump_stream (current_chunk , f )
487+ self ._spilled_bytes += os .path .getsize (path )
475488 chunks .append (self .serializer .load_stream (open (path )))
476489 os .unlink (path ) # data will be deleted after close
477490 current_chunk = []
@@ -486,6 +499,7 @@ def sorted(self, iterator, key=None, reverse=False):
486499
487500 if current_chunk :
488501 chunks .append (iter (current_chunk ))
502+
489503 return heapq .merge (chunks , key = key , reverse = reverse )
490504
491505
0 commit comments