Commit fbe9029

show spilled bytes in Python in web ui
1 parent f0f1ba0 commit fbe9029

5 files changed: +39 -13 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 0 deletions
@@ -115,6 +115,10 @@ private[spark] class PythonRDD(
           val total = finishTime - startTime
           logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
             init, finish))
+          val memoryBytesSpilled = stream.readLong()
+          val diskBytesSpilled = stream.readLong()
+          context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
+          context.taskMetrics.diskBytesSpilled += diskBytesSpilled
           read()
         case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
           // Signals that an exception has been thrown in python
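
The two new readLong() calls consume 16 bytes that the Python worker writes immediately before the end-of-data marker (see worker.py below), so the read order here has to match the write order there. As an illustration of the framing only, assuming the stream uses the usual DataInputStream format (8-byte, big-endian, signed longs), the equivalent decoding in Python would look like this; read_long here is a hypothetical helper written for this sketch, not pyspark's:

# Illustration only, not pyspark code: decode two 8-byte big-endian longs,
# the format that DataInputStream.readLong() expects on the JVM side.
import struct

def read_long(stream):
    # 8 bytes, network byte order, signed
    return struct.unpack("!q", stream.read(8))[0]

# memory_bytes_spilled = read_long(stream)
# disk_bytes_spilled = read_long(stream)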

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala

Lines changed: 2 additions & 1 deletion
@@ -242,7 +242,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
             t.taskMetrics)

           // Overwrite task metrics
-          t.taskMetrics = Some(taskMetrics)
+          // FIXME: deepcopy the metrics, or they will be the same object in local mode
+          t.taskMetrics = Some(scala.util.Marshal.load[TaskMetrics](scala.util.Marshal.dump(taskMetrics)))
         }
       }
     }
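
The FIXME works around aliasing: in local mode the listener and the executor can end up holding the same TaskMetrics instance, so storing the live object would let later updates leak into the stored snapshot; round-tripping it through scala.util.Marshal yields an independent copy. The pitfall itself is easy to reproduce in plain Python, with an ordinary dict standing in for the metrics object (illustration only, not Spark code):

# Illustration of the aliasing problem the FIXME refers to (not Spark code).
import copy

live_metrics = {"diskBytesSpilled": 0}

aliased = live_metrics                    # same object, not a snapshot
snapshot = copy.deepcopy(live_metrics)    # independent deep copy

live_metrics["diskBytesSpilled"] += 1024  # the "executor" keeps mutating it

assert aliased["diskBytesSpilled"] == 1024   # the alias changed as well
assert snapshot["diskBytesSpilled"] == 0     # the deep copy stayed put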

python/pyspark/shuffle.py

Lines changed: 16 additions & 3 deletions
@@ -68,6 +68,11 @@ def _get_local_dirs(sub):
     return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


+# global stats
+MemoryBytesSpilled = 0L
+DiskBytesSpilled = 0L
+
+
 class Aggregator(object):

     """
@@ -313,10 +318,12 @@ def _spill(self):

         It will dump the data in batch for better performance.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         path = self._get_spill_dir(self.spills)
         if not os.path.exists(path):
             os.makedirs(path)

+        used_memory = get_used_memory()
         if not self.pdata:
             # The data has not been partitioned, it will iterator the
             # dataset once, write them into different files, has no
@@ -334,6 +341,7 @@ def _spill(self):
                 self.serializer.dump_stream([(k, v)], streams[h])

             for s in streams:
+                DiskBytesSpilled += s.tell()
                 s.close()

             self.data.clear()
@@ -346,9 +354,11 @@ def _spill(self):
                     # dump items in batch
                     self.serializer.dump_stream(self.pdata[i].iteritems(), f)
                 self.pdata[i].clear()
+                DiskBytesSpilled += os.path.getsize(p)

         self.spills += 1
         gc.collect()  # release the memory as much as possible
+        MemoryBytesSpilled += (used_memory - get_used_memory()) << 20

     def iteritems(self):
         """ Return all merged items as iterator """
@@ -462,7 +472,6 @@ def __init__(self, memory_limit, serializer=None):
         self.memory_limit = memory_limit
         self.local_dirs = _get_local_dirs("sort")
         self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
-        self._spilled_bytes = 0

     def _get_path(self, n):
         """ Choose one directory for spill by number n """
@@ -476,6 +485,7 @@ def sorted(self, iterator, key=None, reverse=False):
         Sort the elements in iterator, do external sort when the memory
         goes above the limit.
         """
+        global MemoryBytesSpilled, DiskBytesSpilled
         batch = 10
         chunks, current_chunk = [], []
         iterator = iter(iterator)
@@ -486,15 +496,18 @@ def sorted(self, iterator, key=None, reverse=False):
             if len(chunk) < batch:
                 break

-            if get_used_memory() > self.memory_limit:
+            used_memory = get_used_memory()
+            if used_memory > self.memory_limit:
                 # sort them inplace will save memory
                 current_chunk.sort(key=key, reverse=reverse)
                 path = self._get_path(len(chunks))
                 with open(path, 'w') as f:
                     self.serializer.dump_stream(current_chunk, f)
-                self._spilled_bytes += os.path.getsize(path)
                 chunks.append(self.serializer.load_stream(open(path)))
                 current_chunk = []
+                gc.collect()
+                MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
+                DiskBytesSpilled += os.path.getsize(path)

             elif not chunks:
                 batch = min(batch * 2, 10000)
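
The pattern added throughout shuffle.py is a pair of module-level counters that every spill path updates: disk usage comes from what was actually written (the stream position via tell() or the spill file's size), and memory spilled is estimated as the drop in used memory measured after the spill, shifted by 20 because get_used_memory() reports megabytes. A condensed sketch of that accounting, with hypothetical names and none of ExternalMerger's partitioning logic:

# Simplified sketch, not the real ExternalMerger/ExternalSorter code.
import gc

MemoryBytesSpilled = 0
DiskBytesSpilled = 0

def spill(records, path, serializer, get_used_memory):
    """Dump records to path and charge the spill to the global counters."""
    global MemoryBytesSpilled, DiskBytesSpilled
    used_before = get_used_memory()       # used memory in MB before spilling
    with open(path, 'wb') as f:
        serializer.dump_stream(records, f)
        DiskBytesSpilled += f.tell()      # bytes actually written to disk
    del records[:]                        # drop the in-memory copy
    gc.collect()
    # get_used_memory() is in MB, so shift by 20 to record the delta in bytes
    MemoryBytesSpilled += (used_before - get_used_memory()) << 20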

python/pyspark/tests.py

Lines changed: 8 additions & 7 deletions
@@ -44,6 +44,7 @@
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType
+from pyspark import shuffle

 _have_scipy = False
 _have_numpy = False
@@ -136,17 +137,17 @@ def test_external_sort(self):
         random.shuffle(l)
         sorter = ExternalSorter(1)
         self.assertEquals(sorted(l), list(sorter.sorted(l)))
-        self.assertGreater(sorter._spilled_bytes, 0)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, 0)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
-        self.assertGreater(sorter._spilled_bytes, last)
-        last = sorter._spilled_bytes
+        self.assertGreater(shuffle.DiskBytesSpilled, last)
+        last = shuffle.DiskBytesSpilled
         self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
                           list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
-        self.assertGreater(sorter._spilled_bytes, last)
+        self.assertGreater(shuffle.DiskBytesSpilled, last)

     def test_external_sort_in_rdd(self):
         conf = SparkConf().set("spark.python.worker.memory", "1m")

python/pyspark/worker.py

Lines changed: 9 additions & 2 deletions
@@ -27,12 +27,11 @@
 # copy_reg module.
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.cloudpickle import CloudPickler
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     CompressedSerializer
-
+from pyspark import shuffle

 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
@@ -52,6 +51,11 @@ def main(infile, outfile):
     if split_index == -1:  # for unit tests
         return

+    # initialize global state
+    shuffle.MemoryBytesSpilled = 0
+    shuffle.DiskBytesSpilled = 0
+    _accumulatorRegistry.clear()
+
     # fetch name of workdir
     spark_files_dir = utf8_deserializer.loads(infile)
     SparkFiles._root_directory = spark_files_dir
@@ -92,6 +96,9 @@ def main(infile, outfile):
         exit(-1)
     finish_time = time.time()
     report_times(outfile, boot_time, init_time, finish_time)
+    write_long(shuffle.MemoryBytesSpilled, outfile)
+    write_long(shuffle.DiskBytesSpilled, outfile)
+
     # Mark the beginning of the accumulators section of the output
     write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
     write_int(len(_accumulatorRegistry), outfile)
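
Taken together, the worker now resets shuffle.MemoryBytesSpilled and shuffle.DiskBytesSpilled at the start of each task, lets the shuffle code accumulate into them while the task runs, and writes the totals as two longs right before END_OF_DATA_SECTION, which is exactly where the new PythonRDD.scala reads land. A stripped-down sketch of that per-task cycle; run_task is hypothetical, not worker.py itself, and the framing is approximated here with struct as an 8-byte big-endian long:

# Sketch only: reset, accumulate, then report the spill totals for one task.
import struct
from pyspark import shuffle

def run_task(process_partition, outfile):
    shuffle.MemoryBytesSpilled = 0      # per-task reset, as in main()
    shuffle.DiskBytesSpilled = 0
    process_partition()                 # any spills bump the module-level counters
    outfile.write(struct.pack("!q", shuffle.MemoryBytesSpilled))
    outfile.write(struct.pack("!q", shuffle.DiskBytesSpilled))
    # END_OF_DATA_SECTION and the accumulator section would follow here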
