Skip to content

Commit e78c15c

Browse files
author
Davies Liu
committed
address comments
1 parent 0b0fde8 commit e78c15c

File tree

2 files changed

+26
-38
lines changed

2 files changed

+26
-38
lines changed

python/pyspark/shuffle.py

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -533,14 +533,14 @@ class ExternalList(object):
533533
>>> l.append(10)
534534
>>> len(l)
535535
101
536-
>>> for i in range(10240):
536+
>>> for i in range(20240):
537537
... l.append(i)
538538
>>> len(l)
539-
10341
539+
20341
540540
>>> import pickle
541541
>>> l2 = pickle.loads(pickle.dumps(l))
542542
>>> len(l2)
543-
10341
543+
20341
544544
>>> list(l2)[100]
545545
10
546546
"""
@@ -577,9 +577,8 @@ def __iter__(self):
577577
# read all items from disks first
578578
with os.fdopen(os.dup(self._file.fileno()), 'r') as f:
579579
f.seek(0)
580-
for values in self._ser.load_stream(f):
581-
for v in values:
582-
yield v
580+
for v in self._ser.load_stream(f):
581+
yield v
583582

584583
for v in self.values:
585584
yield v
@@ -601,7 +600,7 @@ def _open_file(self):
601600
os.makedirs(d)
602601
p = os.path.join(d, str(id))
603602
self._file = open(p, "w+", 65536)
604-
self._ser = CompressedSerializer(PickleSerializer())
603+
self._ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)
605604
os.unlink(p)
606605

607606
def _spill(self):
@@ -612,7 +611,7 @@ def _spill(self):
612611

613612
used_memory = get_used_memory()
614613
pos = self._file.tell()
615-
self._ser.dump_stream([self.values], self._file)
614+
self._ser.dump_stream(self.values, self._file)
616615
self.values = []
617616
gc.collect()
618617
DiskBytesSpilled += self._file.tell() - pos
@@ -622,7 +621,17 @@ def _spill(self):
622621
class ExternalListOfList(ExternalList):
623622
"""
624623
An external list for list.
624+
625+
>>> l = ExternalListOfList([[i, i] for i in range(100)])
626+
>>> len(l)
627+
200
628+
>>> l.append(range(10))
629+
>>> len(l)
630+
210
631+
>>> len(list(l))
632+
210
625633
"""
634+
626635
def __init__(self, values):
627636
ExternalList.__init__(self, values)
628637
self.count = sum(len(i) for i in values)
@@ -632,20 +641,23 @@ def append(self, value):
632641
# already counted 1 in ExternalList.append
633642
self.count += len(value) - 1
634643

644+
def __iter__(self):
645+
for values in ExternalList.__iter__(self):
646+
for v in values:
647+
yield v
648+
635649

636650
class GroupByKey(object):
637651
"""
638652
Group a sorted iterator as [(k1, it1), (k2, it2), ...]
639653
640654
>>> k = [i/3 for i in range(6)]
641-
>>> v = [i for i in range(6)]
655+
>>> v = [[i] for i in range(6)]
642656
>>> g = GroupByKey(iter(zip(k, v)))
643657
>>> [(k, list(it)) for k, it in g]
644658
[(0, [0, 1, 2]), (1, [3, 4, 5])]
645659
"""
646660

647-
external_class = ExternalList
648-
649661
def __init__(self, iterator):
650662
self.iterator = iterator
651663
self.next_item = None
@@ -655,7 +667,7 @@ def __iter__(self):
655667

656668
def next(self):
657669
key, value = self.next_item if self.next_item else next(self.iterator)
658-
values = self.external_class([value])
670+
values = ExternalListOfList([value])
659671
try:
660672
while True:
661673
k, v = next(self.iterator)
@@ -668,30 +680,6 @@ def next(self):
668680
return key, values
669681

670682

671-
class GroupListsByKey(GroupByKey):
672-
"""
673-
Group a sorted iterator of list as [(k1, it1), (k2, it2), ...]
674-
"""
675-
external_class = ExternalListOfList
676-
677-
678-
class ChainedIterable(object):
679-
"""
680-
Picklable chained iterator, similar to itertools.chain.from_iterable()
681-
"""
682-
def __init__(self, iterators):
683-
self.iterators = iterators
684-
685-
def __len__(self):
686-
try:
687-
return len(self.iterators)
688-
except TypeError:
689-
return sum(len(i) for i in self.iterators)
690-
691-
def __iter__(self):
692-
return itertools.chain.from_iterable(self.iterators)
693-
694-
695683
class ExternalGroupBy(ExternalMerger):
696684

697685
"""
@@ -835,7 +823,7 @@ def load_partition(j):
835823
sorter = ExternalSorter(self.memory_limit, ser)
836824
sorted_items = sorter.sorted(itertools.chain(*disk_items),
837825
key=operator.itemgetter(0))
838-
return ((k, ChainedIterable(vs)) for k, vs in GroupListsByKey(sorted_items))
826+
return ((k, vs) for k, vs in GroupByKey(sorted_items))
839827

840828

841829
if __name__ == "__main__":

python/pyspark/tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,7 @@ def test_external_group_by_key(self):
745745
filtered.values().map(lambda x: (len(x), len(list(x)))).collect())
746746
result = filtered.collect()[0][1]
747747
self.assertEqual(N/3, len(result))
748-
self.assertTrue(isinstance(result.data, shuffle.ChainedIterable))
748+
self.assertTrue(isinstance(result.data, shuffle.ExternalList))
749749

750750
def test_sort_on_empty_rdd(self):
751751
self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())

0 commit comments

Comments
 (0)