
Commit 1ea0669

choose sort based groupByKey() automatically
1 parent b40bae7 commit 1ea0669

File tree

4 files changed: +233 / -190 lines changed


python/pyspark/rdd.py

Lines changed: 8 additions & 100 deletions
@@ -44,7 +44,7 @@
 from pyspark.storagelevel import StorageLevel
 from pyspark.resultiterable import ResultIterable
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
-    get_used_memory, ExternalSorter
+    get_used_memory, ExternalSorter, ExternalGroupBy

 from py4j.java_collections import ListConverter, MapConverter

@@ -201,71 +201,6 @@ def _replaceRoot(self, value):
         self._sink(1)


-class SameKey(object):
-    """
-    take the first few items which has the same expected key
-
-    This is used by GroupByKey.
-    """
-    def __init__(self, key, values, it, groupBy):
-        self.key = key
-        self.values = values
-        self.it = it
-        self.groupBy = groupBy
-        self._index = 0
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        if self._index >= len(self.values):
-            if self.it is None:
-                raise StopIteration
-
-            key, values = self.it.next()
-            if key != self.key:
-                self.groupBy._next_item = (key, values)
-                raise StopIteration
-            self.values = values
-            self._index = 0
-
-        self._index += 1
-        return self.values[self._index - 1]
-
-
-class GroupByKey(object):
-    """
-    group a sorted iterator into [(k1, it1), (k2, it2), ...]
-    """
-    def __init__(self, it):
-        self.it = iter(it)
-        self._next_item = None
-        self.current = None
-
-    def __iter__(self):
-        return self
-
-    def next(self):
-        if self._next_item is None:
-            while True:
-                key, values = self.it.next()
-                if self.current is None:
-                    break
-                if key != self.current.key:
-                    break
-                self.current.values.extend(values)
-
-        else:
-            key, values = self._next_item
-            self._next_item = None
-
-        if self.current is not None:
-            self.current.it = None
-        self.current = SameKey(key, values, self.it, self)
-
-        return key, self.current
-
-
 def _parse_memory(s):
     """
     Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
@@ -1561,9 +1496,6 @@ def createZero():
     def _can_spill(self):
         return self.ctx._conf.get("spark.shuffle.spill", "True").lower() == "true"

-    def _sort_based(self):
-        return self.ctx._conf.get("spark.shuffle.sort", "False").lower() == "true"
-
     def _memory_limit(self):
         return _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))

@@ -1577,14 +1509,6 @@ def groupByKey(self, numPartitions=None):
         sum or average) over each key, using reduceByKey will provide much
         better performance.

-        By default, it will use hash based aggregation, it can spill the items
-        into disks when the memory can not hold all the items, but it still
-        need to hold all the values for single key in memory.
-
-        When spark.shuffle.sort is True, it will switch to sort based approach,
-        then it can support single key with large number of values under small
-        amount of memory. But it is slower than hash based approach.
-
         >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
         >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
         [('a', [1, 1]), ('b', [1])]
@@ -1601,42 +1525,26 @@ def mergeCombiners(a, b):
             return a

         spill = self._can_spill()
-        sort_based = self._sort_based()
-        if sort_based and not spill:
-            raise ValueError("can not use sort based group when"
-                             " spark.executor.spill is false")
         memory = self._memory_limit()
         serializer = self._jrdd_deserializer
         agg = Aggregator(createCombiner, mergeValue, mergeCombiners)

-        def combineLocally(iterator):
+        def combine(iterator):
             merger = ExternalMerger(agg, memory * 0.9, serializer) \
                 if spill else InMemoryMerger(agg)
             merger.mergeValues(iterator)
             return merger.iteritems()

-        # combine them before shuffle could reduce the comparison later
-        locally_combined = self.mapPartitions(combineLocally)
+        locally_combined = self.mapPartitions(combine)
         shuffled = locally_combined.partitionBy(numPartitions)

         def groupByKey(it):
-            if sort_based:
-                # Flatten the combined values, so it will not consume huge
-                # memory during merging sort.
-                ser = FlattedValuesSerializer(
-                    BatchedSerializer(PickleSerializer(), 1024), 10)
-                sorter = ExternalSorter(memory * 0.9, ser)
-                it = sorter.sorted(it, key=operator.itemgetter(0))
-                return imap(lambda (k, v): (k, ResultIterable(v)), GroupByKey(it))
-
-            else:
-                # this is faster than sort based
-                merger = ExternalMerger(agg, memory * 0.9, serializer) \
-                    if spill else InMemoryMerger(agg)
-                merger.mergeCombiners(it)
-                return merger.iteritems()
+            merger = ExternalGroupBy(agg, memory, serializer)\
+                if spill else InMemoryMerger(agg)
+            merger.mergeCombiners(it)
+            return merger.iteritems()

-        return shuffled.mapPartitions(groupByKey)
+        return shuffled.mapPartitions(groupByKey).mapValues(ResultIterable)

     # TODO: add tests
     def flatMapValues(self, f):
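To make the surviving code path concrete, here is a small self-contained sketch (plain Python, no Spark, and not this patch's ExternalGroupBy; the helper name group_by_key_local is illustrative only) of the combiner contract groupByKey() builds on: values are combined into lists locally, the lists are merged after the shuffle, and each key's list is then wrapped by ResultIterable:

# Stand-alone sketch of the createCombiner/mergeValue/mergeCombiners contract
# used by groupByKey(); this mirrors only the in-memory (non-spilling) path.

def createCombiner(x):
    return [x]

def mergeValue(xs, x):
    xs.append(x)
    return xs

def mergeCombiners(a, b):
    a.extend(b)
    return a

def group_by_key_local(pairs):
    # Rough stand-in for InMemoryMerger.mergeValues(): one combiner per key.
    combined = {}
    for k, v in pairs:
        combined[k] = mergeValue(combined[k], v) if k in combined else createCombiner(v)
    return combined.items()

print(sorted((k, list(vs)) for k, vs in
             group_by_key_local([("a", 1), ("b", 1), ("a", 1)])))
# [('a', [1, 1]), ('b', [1])]
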

python/pyspark/resultiterable.py

Lines changed: 4 additions & 1 deletion
@@ -33,7 +33,10 @@ def __iter__(self):
         return iter(self.it)

     def __len__(self):
-        return sum(1 for _ in self.it)
+        try:
+            return len(self.it)
+        except TypeError:
+            return sum(1 for _ in self.it)

     def __reduce__(self):
         return (ResultIterable, (list(self.it),))
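The new __len__ prefers a constant-time len() when the wrapped object supports it (e.g. a list) and falls back to counting for plain iterators. A quick illustration of the same pattern outside the class (the helper name length is illustrative only):

# Illustration of the try/len()/except-TypeError fallback added above.
def length(it):
    try:
        return len(it)               # works for sized containers such as lists
    except TypeError:                # generators and plain iterators have no len()
        return sum(1 for _ in it)    # fall back to counting (consumes the iterator)

print(length([1, 2, 3]))             # 3, via len()
print(length(x for x in range(3)))   # 3, via counting
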
