Commit acd8e1b

fix memory when groupByKey().count()
1 parent 905b233 commit acd8e1b
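
For context, the commit message refers to jobs where groupByKey() is followed by count(): only the number of keys is needed, yet each key may accumulate a very large value list. A minimal sketch of such a workload (the SparkContext setup and the 4-key skew are illustrative assumptions, not taken from the commit):

    from pyspark import SparkContext

    # Hypothetical reproduction: many values collapse onto a few keys,
    # so every grouped value list is huge even though count() only
    # needs the number of distinct keys.
    sc = SparkContext("local[2]", "groupByKey-count-repro")
    pairs = sc.parallelize(range(2000000)).map(lambda x: (x % 4, x))
    print(pairs.groupByKey().count())  # prints 4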

File tree: 1 file changed (+68 / -17 lines)

python/pyspark/shuffle.py

Lines changed: 68 additions & 17 deletions
@@ -247,7 +247,7 @@ def mergeValues(self, iterator):
         """ Combine the items by creator and combiner """
         # speedup attribute lookup
         creator, comb = self.agg.createCombiner, self.agg.mergeValue
-        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, 100
+        c, data, pdata, hfun, batch = 0, self.data, self.pdata, self._partition, self.batch
         limit = self.memory_limit
 
         for k, v in iterator:
@@ -259,36 +259,40 @@ def mergeValues(self, iterator):
                 if get_used_memory() >= limit:
                     self._spill()
                     limit = self._next_limit()
+                    batch /= 2
+                    c = 0
                 else:
-                    batch = min(batch * 2, self.batch)
-                    c = 0
+                    batch *= 1.5
 
     def _partition(self, key):
         """ Return the partition for key """
         return hash((key, self._seed)) % self.partitions
 
+    def _object_size(self, obj):
+        return 1
+
     def mergeCombiners(self, iterator, limit=None):
         """ Merge (K,V) pair by mergeCombiner """
         if limit is None:
             limit = self.memory_limit
         # speedup attribute lookup
-        comb, hfun = self.agg.mergeCombiners, self._partition
-        c, data, pdata, batch = 0, self.data, self.pdata, 1
+        comb, hfun, objsize = self.agg.mergeCombiners, self._partition, self._object_size
+        c, data, pdata, batch = 0, self.data, self.pdata, self.batch
         for k, v in iterator:
             d = pdata[hfun(k)] if pdata else data
             d[k] = comb(d[k], v) if k in d else v
             if not limit:
                 continue
 
-            c += 1
+            c += objsize(v)
             if c > batch:
                 if get_used_memory() > limit:
                     self._spill()
                     limit = self._next_limit()
-                    batch /= 4
+                    batch /= 2
+                    c = 0
                 else:
-                    batch = min(batch * 2, self.batch)
-                    c = 0
+                    batch *= 1.5
 
     def _spill(self):
         """
@@ -476,18 +480,42 @@ class SameKey(object):
     >>> l = zip(range(2), range(2))
     >>> list(SameKey(0, [1], iter(l), GroupByKey(iter([]))))
     [1, 0]
+    >>> s = SameKey(0, [1], iter(l), GroupByKey(iter([])))
+    >>> for i in range(2000):
+    ...     s.append(i)
+    >>> len(list(s))
+    2002
     """
     def __init__(self, key, values, it, groupBy):
         self.key = key
         self.values = values
         self.it = it
         self.groupBy = groupBy
-        self._index = 0
+        self._file = None
+        self._ser = None
+        self._index = None
 
     def __iter__(self):
         return self
 
     def next(self):
+        if self._index is None:
+            # begin of iterator
+            if self._file is not None:
+                if self.values:
+                    self._spill()
+                self._file.flush()
+                self._file.seek(0)
+            self._index = 0
+
+        if self._index >= len(self.values) and self._file is not None:
+            try:
+                self.values = next(self._ser.load_stream(self._file))
+                self._index = 0
+            except StopIteration:
+                self._file.close()
+                self._file = None
+
         if self._index < len(self.values):
             value = self.values[self._index]
             self._index += 1
@@ -503,6 +531,29 @@ def next(self):
             raise StopIteration
         return value
 
+    def append(self, value):
+        if self._index is not None:
+            raise ValueError("Can not append value while iterating")
+
+        self.values.append(value)
+        # dump them into disk if the key is huge
+        if len(self.values) >= 10240:
+            self._spill()
+
+    def _spill(self):
+        if self._file is None:
+            dirs = _get_local_dirs("objects")
+            d = dirs[id(self) % len(dirs)]
+            if not os.path.exists(d):
+                os.makedirs(d)
+            p = os.path.join(d, str(id))
+            self._file = open(p, "w+")
+            self._ser = CompressedSerializer(PickleSerializer())
+
+        self._ser.dump_stream([self.values], self._file)
+        self.values = []
+        gc.collect()
+
 
 class GroupByKey(object):
     """
@@ -528,16 +579,12 @@ def next(self):
                 key, value = self.it.next()
                 if self.current is None or key != self.current.key:
                     break
-                self.current.values.append(value)
-
+                self.current.append(value)
         else:
             key, value = self._next_item
             self._next_item = None
 
-        if self.current is not None:
-            self.current.it = None
         self.current = SameKey(key, [value], self.it, self)
-
         return key, (v for vs in self.current for v in vs)
 
 
@@ -557,6 +604,9 @@ def _flatted_serializer(self):
         ser = FlattedValuesSerializer(ser, 20)
         return ser
 
+    def _object_size(self, obj):
+        return len(obj)
+
     def _spill(self):
         """
         dump already partitioned data into disks.
@@ -615,8 +665,9 @@ def _merged_items(self, index, limit=0):
         size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
                    for j in range(self.spills))
         # if the memory can not hold all the partition,
-        # then use sort based merge
-        if (size >> 20) > self.memory_limit / 2:
+        # then use sort based merge. Because of compression,
+        # the data on disks will be much smaller than needed memory
+        if (size >> 20) > self.memory_limit / 10:
             return self._sorted_items(index)
 
         self.data = {}
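
The recurring change in mergeValues() and mergeCombiners() above is an adaptive check interval: batch now starts from self.batch, is halved (and the counter reset) whenever a spill happens, and grows by 1.5x while memory stays under the limit. A standalone sketch of that heuristic, with illustrative names (merge_pairs, used_mb, spill are stand-ins, not shuffle.py APIs):

    def merge_pairs(pairs, merge, spill, used_mb, limit_mb, batch=1000):
        # Merge (key, value) pairs into a dict, checking memory only every
        # `batch` items; shrink the interval under pressure, grow it otherwise.
        data = {}
        c = 0
        for k, v in pairs:
            data[k] = merge(data[k], v) if k in data else v
            c += 1
            if c > batch:
                if used_mb() >= limit_mb:
                    spill(data)          # caller persists the dict's contents
                    data = {}
                    batch /= 2           # under pressure: check more often
                    c = 0
                else:
                    batch *= 1.5         # headroom: check less often
        return data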

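The SameKey changes keep memory bounded even when a single key has millions of values: append() spills the in-memory list to a local file once it reaches 10240 items, and iteration streams the spilled batches back. A simplified, self-contained sketch of that pattern, using pickle and tempfile in place of Spark's CompressedSerializer and _get_local_dirs (the class and its names are illustrative):

    import pickle
    import tempfile

    class DiskBackedValues(object):
        """Accumulate values for one key, spilling to disk past a threshold."""

        def __init__(self, threshold=10240):
            self.values = []
            self.threshold = threshold
            self._file = None

        def append(self, value):
            self.values.append(value)
            if len(self.values) >= self.threshold:
                self._spill()

        def _spill(self):
            if self._file is None:
                self._file = tempfile.TemporaryFile()
            pickle.dump(self.values, self._file)   # one batch per dump call
            self.values = []

        def __iter__(self):
            # replay spilled batches first, then whatever is still in memory
            if self._file is not None:
                self._file.seek(0)
                while True:
                    try:
                        batch = pickle.load(self._file)
                    except EOFError:
                        break
                    for v in batch:
                        yield v
            for v in self.values:
                yield v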