
Commit 2c1d05b

refactor, minor tuning
1 parent b48cda5 commit 2c1d05b

File tree: 2 files changed, +71 -122 lines

  python/pyspark/rdd.py
  python/pyspark/shuffle.py


python/pyspark/rdd.py

Lines changed: 10 additions & 2 deletions
@@ -1660,14 +1660,22 @@ def _memory_limit(self):
     def groupByKey(self, numPartitions=None):
         """
         Group the values for each key in the RDD into a single sequence.
-        Hash-partitions the resulting RDD with into numPartitions partitions.
+        Hash-partitions the resulting RDD with numPartitions
+        partitions.
+
+        The values in the resulting RDD are L{ResultIterable} objects
+        that can be iterated only once. Calling `len(values)` consumes
+        the iterator, so the values can no longer be iterated after
+        `len(values)` has been called.
 
         Note: If you are grouping in order to perform an aggregation (such as a
         sum or average) over each key, using reduceByKey will provide much
         better performance.
 
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
-        >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
+        >>> sorted(x.groupByKey().mapValues(len).collect())
+        [('a', 2), ('b', 1)]
+        >>> sorted(x.groupByKey().mapValues(list).collect())
         [('a', [1, 1]), ('b', [1])]
         """
         def createCombiner(x):
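
For context, a minimal usage sketch of the new doctests (it assumes a running SparkContext named `sc`, as the doctests do); each mapValues action recomputes the groups, so every one-pass iterable is consumed at most once:

pairs = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
grouped = pairs.groupByKey()
# Materialize each group exactly once per action; per the docstring above,
# the grouped values are one-pass iterables.
print(sorted(grouped.mapValues(list).collect()))   # [('a', [1, 1]), ('b', [1])]
print(sorted(grouped.mapValues(len).collect()))    # [('a', 2), ('b', 1)]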

python/pyspark/shuffle.py

Lines changed: 61 additions & 120 deletions
@@ -33,14 +33,19 @@
 try:
     import psutil
 
+    process = None
+
     def get_used_memory():
         """ Return the used memory in MB """
-        process = psutil.Process(os.getpid())
+        global process
+        if process is None or process._pid != os.getpid():
+            process = psutil.Process(os.getpid())
         if hasattr(process, "memory_info"):
             info = process.memory_info()
         else:
             info = process.get_memory_info()
         return info.rss >> 20
+
 except ImportError:
 
     def get_used_memory():
@@ -49,6 +54,7 @@ def get_used_memory():
             for line in open('/proc/self/status'):
                 if line.startswith('VmRSS:'):
                     return int(line.split()[1]) >> 10
+
         else:
             warnings.warn("Please install psutil to have better "
                           "support with spilling")
@@ -57,6 +63,7 @@ def get_used_memory():
                 rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
                 return rss >> 20
         # TODO: support windows
+
         return 0
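
A standalone sketch of the cached-process pattern added above; this is an assumption-flagged rewrite that uses psutil's public pid attribute and memory_info(), not the private _pid and the hasattr fallback the diff keeps for older psutil versions:

import os
import psutil

_process = None

def used_memory_mb():
    """Resident set size of this process in MB, reusing one Process handle."""
    global _process
    if _process is None or _process.pid != os.getpid():
        # first call, or running in a forked worker: (re)create the handle
        _process = psutil.Process(os.getpid())
    return _process.memory_info().rss >> 20
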
@@ -146,7 +153,7 @@ def mergeCombiners(self, iterator):
             d[k] = comb(d[k], v) if k in d else v
 
     def iteritems(self):
-        """ Return the merged items ad iterator """
+        """ Return the merged items as an iterator """
         return self.data.iteritems()
 
@@ -210,18 +217,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
                  localdirs=None, scale=1, partitions=59, batch=1000):
         Merger.__init__(self, aggregator)
         self.memory_limit = memory_limit
-        # default serializer is only used for tests
-        self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        # add compression
-        if isinstance(self.serializer, BatchedSerializer):
-            if not isinstance(self.serializer.serializer, CompressedSerializer):
-                self.serializer = BatchedSerializer(
-                    CompressedSerializer(self.serializer.serializer),
-                    self.serializer.batchSize)
-        else:
-            if not isinstance(self.serializer, CompressedSerializer):
-                self.serializer = CompressedSerializer(self.serializer)
-
+        self.serializer = self._compressed_serializer(serializer)
         self.localdirs = localdirs or _get_local_dirs(str(id(self)))
         # number of partitions when spill data into disks
         self.partitions = partitions
@@ -238,6 +234,18 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
         # randomize the hash of key, id(o) is the address of o (aligned by 8)
         self._seed = id(self) + 7
 
+    def _compressed_serializer(self, serializer=None):
+        # default serializer is only used for tests
+        ser = serializer or PickleSerializer()
+        # add compression
+        if isinstance(ser, BatchedSerializer):
+            if not isinstance(ser.serializer, CompressedSerializer):
+                ser = BatchedSerializer(CompressedSerializer(ser.serializer), ser.batchSize)
+        else:
+            if not isinstance(ser, CompressedSerializer):
+                ser = BatchedSerializer(CompressedSerializer(ser), 1024)
+        return ser
+
     def _get_spill_dir(self, n):
         """ Choose one directory for spill by number n """
         return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
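
For reference, a construction sketch of what _compressed_serializer resolves to (class names taken from pyspark.serializers as imported by this module; 1024 is the default batch size named in the code above):

from pyspark.serializers import BatchedSerializer, CompressedSerializer, PickleSerializer

# Default/test case: _compressed_serializer(None) wraps pickling in compression
# and batches 1024 records per frame.
default_ser = BatchedSerializer(CompressedSerializer(PickleSerializer()), 1024)

# A caller-supplied BatchedSerializer keeps its own batch size and only gains
# the compression wrapper around its inner serializer.
custom = BatchedSerializer(PickleSerializer(), 100)
wrapped = BatchedSerializer(CompressedSerializer(custom.serializer), custom.batchSize)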
@@ -276,6 +284,9 @@ def _partition(self, key):
         return hash((key, self._seed)) % self.partitions
 
     def _object_size(self, obj):
+        """ How much memory this obj takes, assuming that all objects
+        consume a similar number of bytes
+        """
         return 1
 
     def mergeCombiners(self, iterator, limit=None):
@@ -485,18 +496,18 @@ class SameKey(object):
     This is used by GroupByKey.
 
     >>> l = zip(range(2), range(2))
-    >>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
+    >>> list(SameKey(0, 1, iter(l), GroupByKey(iter([]))))
     [1, 0]
-    >>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
+    >>> s = SameKey(0, 1, iter(l), GroupByKey(iter([])))
     >>> for i in range(2000):
     ...     s.append(i)
     >>> len(list(s))
     2002
     """
-    def __init__(self, key, values, it, groupBy):
+    def __init__(self, key, value, iterator, groupBy):
         self.key = key
-        self.values = values
-        self.it = it
+        self.values = [value]
+        self.iterator = iterator
         self.groupBy = groupBy
         self._file = None
         self._ser = None
@@ -516,27 +527,22 @@ def next(self):
             self._index = 0
 
         if self._index >= len(self.values) and self._file is not None:
-            try:
-                self.values = next(self._ser.load_stream(self._file))
-                self._index = 0
-            except StopIteration:
-                self._file.close()
-                self._file = None
+            # load next chunk of values from disk
+            self.values = next(self._ser.load_stream(self._file))
+            self._index = 0
 
         if self._index < len(self.values):
             value = self.values[self._index]
             self._index += 1
             return value
 
-        if self.it is None:
-            raise StopIteration
+        key, value = next(self.iterator)
+        if key == self.key:
+            return value
 
-        key, value = self.it.next()
-        if key != self.key:
-            self.groupBy._next_item = (key, value)
-            self.it = None
-            raise StopIteration
-        return value
+        # push them back into groupBy
+        self.groupBy.next_item = (key, value)
+        raise StopIteration
 
     def append(self, value):
         if self._index is not None:
@@ -548,13 +554,14 @@ def append(self, value):
             self._spill()
 
     def _spill(self):
+        """ dump the values to disk """
         if self._file is None:
             dirs = _get_local_dirs("objects")
             d = dirs[id(self) % len(dirs)]
             if not os.path.exists(d):
                 os.makedirs(d)
             p = os.path.join(d, str(id))
-            self._file = open(p, "w+")
+            self._file = open(p, "w+", 65536)
             self._ser = CompressedSerializer(PickleSerializer())
 
         self._ser.dump_stream([self.values], self._file)
@@ -567,32 +574,34 @@ class GroupByKey(object):
     group a sorted iterator into [(k1, it1), (k2, it2), ...]
 
     >>> k = [i/3 for i in range(6)]
-    >>> v = [[i] for i in range(6)]
+    >>> v = [i for i in range(6)]
     >>> g = GroupByKey(iter(zip(k, v)))
     >>> [(k, list(it)) for k, it in g]
     [(0, [0, 1, 2]), (1, [3, 4, 5])]
     """
-    def __init__(self, it):
-        self.it = iter(it)
-        self._next_item = None
+    def __init__(self, iterator):
+        self.iterator = iterator
+        self.next_item = None
         self.current = None
 
     def __iter__(self):
         return self
 
     def next(self):
-        if self._next_item is None:
+        if self.next_item is None:
             while True:
-                key, value = self.it.next()
+                key, value = next(self.iterator)
                 if self.current is None or key != self.current.key:
                     break
+                # the current key has not been visited.
                 self.current.append(value)
         else:
-            key, value = self._next_item
-            self._next_item = None
+            # next key was popped while visiting current key
+            key, value = self.next_item
+            self.next_item = None
 
-        self.current = SameKey(key, [value], self.it, self)
-        return key, (v for vs in self.current for v in vs)
+        self.current = SameKey(key, value, self.iterator, self)
+        return key, self.current
 
 
 class ExternalGroupBy(ExternalMerger):
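
Spilling aside, GroupByKey plus SameKey do over a disk-backed stream what itertools.groupby does in memory over a key-sorted stream; a minimal in-memory analogue:

import itertools

pairs = [(0, 0), (0, 1), (0, 2), (1, 3), (1, 4), (1, 5)]   # already sorted by key
groups = [(k, [v for _, v in g])
          for k, g in itertools.groupby(pairs, key=lambda kv: kv[0])]
print(groups)   # [(0, [0, 1, 2]), (1, [3, 4, 5])]

As with SameKey, each group iterator has to be consumed before advancing to the next key, because both walk the same underlying stream.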
@@ -624,7 +633,7 @@ def _spill(self):
 
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
-            # dataset once, write them into different files, has no
+            # data once, write them into different files, has no
             # additional memory. It only called when the memory goes
             # above limit at the first time.
 
@@ -636,12 +645,10 @@ def _spill(self):
             # sort them before dumping into disks
             self._sorted = len(self.data) < self.SORT_KEY_LIMIT
             if self._sorted:
-                ser = self._flatted_serializer()
+                self.serializer = self._flatted_serializer()
                 for k in sorted(self.data.keys()):
-                    v = self.data[k]
                     h = self._partition(k)
-                    ser.dump_stream([(k, v)], streams[h])
-                self.serializer = ser
+                    self.serializer.dump_stream([(k, self.data[k])], streams[h])
             else:
                 for k, v in self.data.iteritems():
                     h = self._partition(k)
@@ -651,6 +658,7 @@ def _spill(self):
                 s.close()
 
             self.data.clear()
+            # self.pdata is cached in `mergeValues` and `mergeCombiners`
            self.pdata.extend([{} for i in range(self.partitions)])
 
         else:
@@ -659,8 +667,9 @@ def _spill(self):
                 with open(p, "w") as f:
                     # dump items in batch
                     if self._sorted:
-                        self.serializer.dump_stream(
-                            sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0)), f)
+                        # sort by key only (stable)
+                        sorted_items = sorted(self.pdata[i].iteritems(), key=operator.itemgetter(0))
+                        self.serializer.dump_stream(sorted_items, f)
                     else:
                         self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                     self.pdata[i].clear()
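
The key=operator.itemgetter(0) detail matters: only keys are compared, so the (possibly large or unorderable) combiner values never participate in the sort, and Python's stable sort keeps equal keys in their insertion order. A tiny illustration:

import operator

items = [("b", [2]), ("a", [1, 1]), ("a", [3])]
print(sorted(items, key=operator.itemgetter(0)))
# [('a', [1, 1]), ('a', [3]), ('b', [2])]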
@@ -706,75 +715,7 @@ def load_partition(j):
         sorted_items = sorter.sorted(itertools.chain(*disk_items),
                                      key=operator.itemgetter(0))
 
-        return GroupByKey(sorted_items)
-
-
-class ExternalSorter(object):
-    """
-    ExtenalSorter will divide the elements into chunks, sort them in
-    memory and dump them into disks, finally merge them back.
-
-    The spilling will only happen when the used memory goes above
-    the limit.
-
-    >>> sorter = ExternalSorter(1)  # 1M
-    >>> import random
-    >>> l = range(1024)
-    >>> random.shuffle(l)
-    >>> sorted(l) == list(sorter.sorted(l))
-    True
-    >>> sorted(l) == list(sorter.sorted(l, key=lambda x: -x, reverse=True))
-    True
-    """
-    def __init__(self, memory_limit, serializer=None):
-        self.memory_limit = memory_limit
-        self.local_dirs = _get_local_dirs("sort")
-        self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0
-
-    def _get_path(self, n):
-        """ Choose one directory for spill by number n """
-        d = self.local_dirs[n % len(self.local_dirs)]
-        if not os.path.exists(d):
-            os.makedirs(d)
-        return os.path.join(d, str(n))
-
-    def sorted(self, iterator, key=None, reverse=False):
-        """
-        Sort the elements in iterator, do external sort when the memory
-        goes above the limit.
-        """
-        batch = 10
-        chunks, current_chunk = [], []
-        iterator = iter(iterator)
-        while True:
-            # pick elements in batch
-            chunk = list(itertools.islice(iterator, batch))
-            current_chunk.extend(chunk)
-            if len(chunk) < batch:
-                break
-
-            if get_used_memory() > self.memory_limit:
-                # sort them inplace will save memory
-                current_chunk.sort(key=key, reverse=reverse)
-                path = self._get_path(len(chunks))
-                with open(path, 'w') as f:
-                    self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
-                chunks.append(self.serializer.load_stream(open(path)))
-                current_chunk = []
-
-            elif not chunks:
-                batch = min(batch * 2, 10000)
-
-        current_chunk.sort(key=key, reverse=reverse)
-        if not chunks:
-            return current_chunk
-
-        if current_chunk:
-            chunks.append(iter(current_chunk))
-
-        return heapq.merge(chunks, key=key, reverse=reverse)
+        return ((k, itertools.chain.from_iterable(vs)) for k, vs in GroupByKey(sorted_items))
 
 
 if __name__ == "__main__":
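
The new return value is worth unpacking: GroupByKey yields (key, SameKey) pairs whose items are the spilled combiners, i.e. lists of values for groupByKey, and chain.from_iterable flattens them back into a single value stream per key. A minimal sketch of that flattening step:

import itertools

chunks_for_one_key = iter([[1, 1], [1], [1, 1, 1]])   # combiner lists for one key
print(list(itertools.chain.from_iterable(chunks_for_one_key)))
# [1, 1, 1, 1, 1, 1]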
