Commit 8ef965e

Merge branch 'master' into groupby
Conflicts: python/pyspark/shuffle.py
2 parents: fbc504a + 4e3fbe8

4 files changed: 47 additions & 14 deletions
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 0 deletions

@@ -124,6 +124,10 @@ private[spark] class PythonRDD(
           val total = finishTime - startTime
           logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
             init, finish))
+          val memoryBytesSpilled = stream.readLong()
+          val diskBytesSpilled = stream.readLong()
+          context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+          context.taskMetrics.diskBytesSpilled += diskBytesSpilled
           read()
         case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
           // Signals that an exception has been thrown in python
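The JVM side now expects two extra 64-bit values from the worker's output stream (memory bytes spilled, then disk bytes spilled) between the timing report and the end-of-data marker, and adds them to the task's metrics. Below is a minimal sketch of that framing in Python, assuming the big-endian signed 8-byte encoding that a DataInputStream.readLong() consumes; the helper names are illustrative, not Spark's:

# Hedged sketch (not Spark code): the two-long framing that the
# readLong() calls above consume. Big-endian signed 8-byte integers
# match what a JVM DataInputStream.readLong expects.
import struct

def write_spill_metrics(stream, memory_bytes_spilled, disk_bytes_spilled):
    stream.write(struct.pack("!q", memory_bytes_spilled))
    stream.write(struct.pack("!q", disk_bytes_spilled))

def read_spill_metrics(stream):
    memory_bytes_spilled = struct.unpack("!q", stream.read(8))[0]
    disk_bytes_spilled = struct.unpack("!q", stream.read(8))[0]
    return memory_bytes_spilled, disk_bytes_spilled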

python/pyspark/shuffle.py

Lines changed: 25 additions & 3 deletions

@@ -78,6 +78,11 @@ def _get_local_dirs(sub):
     return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
 
 
+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
 class Aggregator(object):
 
     """
@@ -318,10 +323,12 @@ def _spill(self):
 
         It will dump the data in batch for better performance.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)
 
+        used_memory = get_used_memory()
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
             # dataset once, write them into different files, has no
@@ -339,6 +346,7 @@ def _spill(self):
                 self.serializer.dump_stream([(k, v)], streams[h])
 
             for s in streams:
+                DiskBytesSpilled += s.tell()
                 s.close()
 
             self.data.clear()
@@ -351,9 +359,11 @@ def _spill(self):
                     # dump items in batch
                     self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                 self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)
 
         self.spills += 1
         gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
 
     def iteritems(self):
         """ Return all merged items as iterator """
@@ -454,7 +464,6 @@ def __init__(self, memory_limit, serializer=None):
         self.local_dirs = _get_local_dirs("sort")
         self.serializer = serializer or BatchedSerializer(
             CompressedSerializer(PickleSerializer()), 1024)
-        self._spilled_bytes = 0
 
     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -468,6 +477,7 @@ def sorted(self, iterator, key=None, reverse=False):
        Sort the elements in iterator, do external sort when the memory
        goes above the limit.
        """
+        global MemoryBytesSpilled, DiskBytesSpilled
        batch = 10
        chunks, current_chunk = [], []
        iterator = iter(iterator)
@@ -478,17 +488,19 @@ def sorted(self, iterator, key=None, reverse=False):
            if len(chunk) < batch:
                break
 
-            if get_used_memory() > self.memory_limit:
+            used_memory = get_used_memory()
+            if used_memory > self.memory_limit:
                # sort them inplace will save memory
                current_chunk.sort(key=key, reverse=reverse)
                path = self._get_path(len(chunks))
                with open(path, 'w') as f:
                    self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
                chunks.append(self.serializer.load_stream(open(path)))
                os.unlink(path)  # data will be deleted after close
                current_chunk = []
                gc.collect()
+                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+                DiskBytesSpilled += os.path.getsize(path)
 
            elif not chunks:
                batch = min(batch * 2, 10000)
@@ -569,6 +581,7 @@ def append(self, value):
 
     def _spill(self):
         """ dump the values into disk """
+        global MemoryBytesSpilled, DiskBytesSpilled
         if self._file is None:
             dirs = _get_local_dirs("objects")
             d = dirs[id(self) % len(dirs)]
@@ -578,9 +591,13 @@ def _spill(self):
             self._file = open(p, "w+", 65536)
             self._ser = CompressedSerializer(PickleSerializer())
 
+        used_memory = get_used_memory()
+        pos = self._file.tell()
         self._ser.dump_stream([self.values], self._file)
+        DiskBytesSpilled += self._file.tell() - pos
         self.values = []
         gc.collect()
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
 
 
 class GroupByKey(object):
@@ -641,10 +658,12 @@ def _spill(self):
        """
        dump already partitioned data into disks.
        """
+        global MemoryBytesSpilled, DiskBytesSpilled
        path = self._get_spill_dir(self.spills)
        if not os.path.exists(path):
            os.makedirs(path)
 
+        used_memory = get_used_memory()
        if not self.pdata:
            # The data has not been partitioned, it will iterator the
            # data once, write them into different files, has no
@@ -669,6 +688,7 @@ def _spill(self):
                self.serializer.dump_stream([(k, v)], streams[h])
 
            for s in streams:
+                DiskBytesSpilled += s.tell()
                s.close()
 
            self.data.clear()
@@ -687,9 +707,11 @@ def _spill(self):
                    else:
                        self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)
 
        self.spills += 1
        gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
 
    def _merged_items(self, index, limit=0):
        size = sum(os.path.getsize(os.path.join(self._get_spill_dir(j), str(index)))
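This change replaces the sorter-private _spilled_bytes attribute with two module-level counters, MemoryBytesSpilled and DiskBytesSpilled, so every spilling structure reports into the same totals and worker.py can ship one pair of numbers per task. Each _spill() follows the same bookkeeping: snapshot used memory, write the data, add the bytes written to DiskBytesSpilled, and after gc.collect() credit the freed memory to MemoryBytesSpilled (get_used_memory() returns megabytes, hence the << 20). A standalone sketch of that pattern; the psutil-based memory probe is an assumption standing in for pyspark.shuffle.get_used_memory:

# Hedged sketch of the spill-accounting pattern applied in each _spill();
# not Spark code, but the same bookkeeping steps.
import gc
import os
import pickle

MemoryBytesSpilled = 0
DiskBytesSpilled = 0

def get_used_memory():
    """Resident memory of this process in MB (assumes psutil is installed)."""
    import psutil
    return psutil.Process(os.getpid()).memory_info().rss >> 20

def spill_to_disk(values, path):
    global MemoryBytesSpilled, DiskBytesSpilled
    used_memory = get_used_memory()              # MB held before the spill
    with open(path, "wb") as f:
        pickle.dump(values, f)
        DiskBytesSpilled += f.tell()             # bytes actually written
    values[:] = []                               # drop the in-memory copy
    gc.collect()                                 # make the freed memory visible
    MemoryBytesSpilled += (used_memory - get_used_memory()) << 20  # MB -> bytes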

python/pyspark/tests.py

Lines changed: 8 additions & 7 deletions

@@ -46,6 +46,7 @@
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType
+from pyspark import shuffle
 
 _have_scipy = False
 _have_numpy = False
@@ -138,17 +139,17 @@ def test_external_sort(self):
         random.shuffle(l)
         sorter = ExternalSorter(1)
         self.assertEquals(sorted(l), list(sorter.sorted(l)))
-        self.assertGreater(sorter._spilled_bytes, 0)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, 0)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
                           list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
 
     def test_external_sort_in_rdd(self):
         conf = SparkConf().set("spark.python.worker.memory", "1m")
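With _spilled_bytes removed, the test reads the module-level shuffle.DiskBytesSpilled instead. That counter is cumulative for the whole process, which is why each assertion compares against the previously recorded value rather than expecting a reset. A hedged sketch that mirrors the test's own assertions outside unittest (assumes pyspark is importable; the list size is an arbitrary stand-in for the `l` the test builds earlier):

# Hedged usage sketch: watch the cumulative spill counter grow across sorts.
import random
from pyspark import shuffle
from pyspark.shuffle import ExternalSorter

l = list(range(1024 * 100))
random.shuffle(l)

sorter = ExternalSorter(1)                      # 1 MB limit forces spilling
last = shuffle.DiskBytesSpilled
assert sorted(l) == list(sorter.sorted(l))
assert shuffle.DiskBytesSpilled > last          # the counter only grows

last = shuffle.DiskBytesSpilled
assert sorted(l, reverse=True) == list(sorter.sorted(l, reverse=True))
assert shuffle.DiskBytesSpilled > last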

python/pyspark/worker.py

Lines changed: 10 additions & 4 deletions

@@ -23,16 +23,14 @@
 import time
 import socket
 import traceback
-# CloudPickler needs to be imported so that depicklers are registered using the
-# copy_reg module.
+
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     CompressedSerializer
-
+from pyspark import shuffle
 
 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -52,6 +50,11 @@ def main(infile, outfile):
     if split_index == -1:  # for unit tests
         return
 
+    # initialize global state
+    shuffle.MemoryBytesSpilled = 0
+    shuffle.DiskBytesSpilled = 0
+    _accumulatorRegistry.clear()
+
     # fetch name of workdir
     spark_files_dir = utf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
@@ -97,6 +100,9 @@ def main(infile, outfile):
         exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
+    write_long(shuffle.MemoryBytesSpilled, outfile)
+    write_long(shuffle.DiskBytesSpilled, outfile)
+
     # Mark the beginning of the accumulators section of the output
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
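worker.py zeroes the counters (and the accumulator registry) at the start of every task because worker processes are reused, then writes the totals right after the timing report so the byte order lines up with the two readLong() calls added in PythonRDD.scala. A minimal sketch of that per-task lifecycle; run_task is a hypothetical stand-in for the body of main(), while write_long is the real helper the diff imports from pyspark.serializers:

# Hedged sketch of the reset-then-report lifecycle inside a reused worker.
from pyspark import shuffle
from pyspark.serializers import write_long

def run_task_and_report(run_task, infile, outfile):
    # Reset per task: a reused worker must not leak a previous task's totals.
    shuffle.MemoryBytesSpilled = 0
    shuffle.DiskBytesSpilled = 0

    run_task(infile, outfile)    # any spill during the task bumps the globals

    # Report in the order the JVM reads them: memory first, then disk.
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)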
