
Commit fec0166

Make Python function/line appear in the UI.
1 parent d812aee commit fec0166

1 file changed: +55 -11 lines changed

python/pyspark/rdd.py

Lines changed: 55 additions & 11 deletions
@@ -23,6 +23,7 @@
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@
 
 __all__ = ["RDD"]
 
+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark. Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)
 
 class RDD(object):
     """
@@ -401,7 +442,8 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))
 
     def _collect_iterator_through_file(self, iterator):
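
The module-level _spark_stack_depth counter makes the wrapper safe to nest: if one wrapped method calls another, only the outermost with block sets and clears the JVM call site. A small self-contained sketch of that behavior, using a hypothetical FakeJsc as a stand-in for the py4j _jsc handle:

class FakeJsc(object):
    # Hypothetical stand-in for the py4j JavaSparkContext handle.
    def setCallSite(self, site):
        print("setCallSite(%r)" % site)

_depth = 0

class Guard(object):
    # Same enter/exit shape as _JavaStackTrace above.
    def __init__(self, jsc, site):
        self._jsc, self._site = jsc, site

    def __enter__(self):
        global _depth
        if _depth == 0:
            self._jsc.setCallSite(self._site)
        _depth += 1

    def __exit__(self, type, value, tb):
        global _depth
        _depth -= 1
        if _depth == 0:
            self._jsc.setCallSite(None)

jsc = FakeJsc()
with Guard(jsc, "collect at app.py:10"):
    with Guard(jsc, "take at app.py:11"):  # nested: no extra JVM calls
        pass
# Prints only the outer pair:
# setCallSite('collect at app.py:10')
# setCallSite(None)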
@@ -582,13 +624,14 @@ def takeUpToNum(iterator):
         # TODO(shivaram): Similar to the scala implementation, update the take
         # method to scan multiple splits based on an estimate of how many elements
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]
 
     def first(self):
@@ -765,9 +808,10 @@ def add_shuffle_key(split, iterator):
             yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                      id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
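
With collect, take, and partitionBy all wrapped, any job launched from user code picks up a Python call site. A hypothetical driver script showing the effect (file name and UI text are illustrative):

# my_script.py (hypothetical)
from pyspark import SparkContext

sc = SparkContext("local", "callsite-demo")
print(sc.parallelize(range(100)).collect())
# Before this commit the UI described the job by a Java-side call site;
# with it, the description reads along the lines of "collect at my_script.py:5".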
