Skip to content

Commit d05060d

Browse files
committed
Group values with the same key before the shuffle, reducing the number of comparisons during sorting.
1 parent 083d842 commit d05060d

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

python/pyspark/rdd.py

Lines changed: 36 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -218,19 +218,19 @@ def __iter__(self):
218218
return self
219219

220220
def next(self):
221-
if self._index < len(self.values):
222-
self._index += 1
223-
return self.values[self._index - 1]
221+
if self._index >= len(self.values):
222+
if self.it is None:
223+
raise StopIteration
224224

225-
if self.it is None:
226-
raise StopIteration
225+
key, values = self.it.next()
226+
if key != self.key:
227+
self.groupBy._next_item = (key, values)
228+
raise StopIteration
229+
self.values = values
230+
self._index = 0
227231

228-
key, value = self.it.next()
229-
if key == self.key:
230-
return value
231-
232-
self.groupBy._next_item = (key, value)
233-
raise StopIteration
232+
self._index += 1
233+
return self.values[self._index - 1]
234234

235235

236236
class GroupByKey(object):
@@ -248,20 +248,20 @@ def __iter__(self):
248248
def next(self):
249249
if self._next_item is None:
250250
while True:
251-
key, value = self.it.next()
251+
key, values = self.it.next()
252252
if self.current is None:
253253
break
254254
if key != self.current.key:
255255
break
256-
self.current.values.append(value)
256+
self.current.values.extend(values)
257257

258258
else:
259-
key, value = self._next_item
259+
key, values = self._next_item
260260
self._next_item = None
261261

262262
if self.current is not None:
263263
self.current.it = None
264-
self.current = SameKey(key, [value], self.it, self)
264+
self.current = SameKey(key, values, self.it, self)
265265

266266
return key, self.current
267267

@@ -1581,9 +1581,30 @@ def groupByKey(self, numPartitions=None):
15811581
>>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect()))
15821582
[('a', [1, 1]), ('b', [1])]
15831583
"""
1584+
def createCombiner(x):
1585+
return [x]
1586+
1587+
def mergeValue(xs, x):
1588+
xs.append(x)
1589+
return xs
1590+
1591+
def mergeCombiners(a, b):
1592+
a.extend(b)
1593+
return a
1594+
15841595
serializer = self.ctx.serializer
15851596
spill = self._can_spill()
15861597
memory = self._memory_limit()
1598+
agg = Aggregator(createCombiner, mergeValue, mergeCombiners)
1599+
1600+
def combineLocally(iterator):
1601+
merger = ExternalMerger(agg, memory * 0.9, serializer) \
1602+
if spill else InMemoryMerger(agg)
1603+
merger.mergeValues(iterator)
1604+
return merger.iteritems()
1605+
1606+
locally_combined = self.mapPartitions(combineLocally)
1607+
shuffled = locally_combined.partitionBy(numPartitions)
15871608

15881609
def groupByKey(it):
15891610
if spill:
@@ -1592,8 +1613,6 @@ def groupByKey(it):
15921613
for k, v in GroupByKey(it):
15931614
yield k, ResultIterable(v)
15941615

1595-
# TODO: combine before shuffle ?
1596-
shuffled = self.partitionBy(numPartitions)
15971616
return shuffled.mapPartitions(groupByKey)
15981617

15991618
# TODO: add tests

0 commit comments

Comments
 (0)