Skip to content

Commit 1f69f93

Browse files
committed
fix tests
1 parent 0d3395f commit 1f69f93

File tree

2 files changed

+16
-17
lines changed

2 files changed

+16
-17
lines changed

python/pyspark/shuffle.py

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -203,16 +203,16 @@ class ExternalMerger(Merger):
203203
>>> agg = SimpleAggregator(lambda x, y: x + y)
204204
>>> merger = ExternalMerger(agg, 10)
205205
>>> N = 10000
206-
>>> merger.mergeValues(zip(xrange(N), xrange(N)) * 10)
206+
>>> merger.mergeValues(zip(xrange(N), xrange(N)))
207207
>>> assert merger.spills > 0
208208
>>> sum(v for k,v in merger.iteritems())
209-
499950000
209+
49995000
210210
211211
>>> merger = ExternalMerger(agg, 10)
212-
>>> merger.mergeCombiners(zip(xrange(N), xrange(N)) * 10)
212+
>>> merger.mergeCombiners(zip(xrange(N), xrange(N)))
213213
>>> assert merger.spills > 0
214214
>>> sum(v for k,v in merger.iteritems())
215-
499950000
215+
49995000
216216
"""
217217

218218
# the max total partitions created recursively
@@ -376,15 +376,13 @@ def _external_items(self):
376376
if any(self.pdata):
377377
self._spill()
378378
self.pdata = []
379-
hard_limit = self._next_limit()
380379

381380
try:
382381
for i in range(self.partitions):
383382
for v in self._merged_items(i):
384383
yield v
385384
self.data.clear()
386385
gc.collect()
387-
hard_limit = self._next_limit()
388386

389387
# remove the merged partition
390388
for j in range(self.spills):
@@ -393,8 +391,9 @@ def _external_items(self):
393391
finally:
394392
self._cleanup()
395393

396-
def _merged_items(self, index, limit=0):
394+
def _merged_items(self, index):
397395
self.data = {}
396+
limit = self._next_limit()
398397
for j in range(self.spills):
399398
path = self._get_spill_dir(j)
400399
p = os.path.join(path, str(index))
@@ -411,11 +410,6 @@ def _merged_items(self, index, limit=0):
411410

412411
return self.data.iteritems()
413412

414-
def _cleanup(self):
415-
""" Clean up all the files in disks """
416-
for d in self.localdirs:
417-
shutil.rmtree(d, True)
418-
419413
def _recursive_merged_items(self, index):
420414
"""
421415
merge the partitioned items and return them as an iterator
@@ -440,6 +434,11 @@ def _recursive_merged_items(self, index):
440434

441435
return m._external_items()
442436

437+
def _cleanup(self):
438+
""" Clean up all the files in disks """
439+
for d in self.localdirs:
440+
shutil.rmtree(d, True)
441+
443442

444443
class ExternalSorter(object):
445444
"""
@@ -572,7 +571,7 @@ def __iter__(self):
572571
for v in self.values:
573572
yield v
574573

575-
if not self.groupBy or not self.groupBy.next_item:
574+
if not self.groupBy or self.groupBy.next_item:
576575
# different key was already found by previous accessing
577576
return
578577

@@ -777,14 +776,14 @@ def _spill(self):
777776
gc.collect() # release the memory as much as possible
778777
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
779778

780-
def _merged_items(self, index, limit=0):
779+
def _merged_items(self, index):
781780
size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
782781
for j in range(self.spills))
783782
# if the memory cannot hold the whole partition,
784783
# then use sort based merge. Because of compression,
785784
# the data on disk will be much smaller than the memory needed to load it
786785
if (size >> 20) >= self.memory_limit / 10:
787-
return self._sorted_items(index)
786+
return self._merge_sorted_items(index)
788787

789788
self.data = {}
790789
for j in range(self.spills):
@@ -794,7 +793,7 @@ def _merged_items(self, index, limit=0):
794793
self.mergeCombiners(self.serializer.load_stream(open(p)), 0)
795794
return self.data.iteritems()
796795

797-
def _sorted_items(self, index):
796+
def _merge_sorted_items(self, index):
798797
""" load a partition from disk, then sort and group by key """
799798
def load_partition(j):
800799
path = self._get_spill_dir(j)

python/pyspark/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
class TestMerger(unittest.TestCase):
7171

7272
def setUp(self):
73-
self.N = 1 << 16
73+
self.N = 1 << 12
7474
self.l = [i for i in xrange(self.N)]
7575
self.data = zip(self.l, self.l)
7676
self.agg = Aggregator(lambda x: [x],

0 commit comments

Comments
 (0)