Commit ca67909

Merge pull request apache#311 from tmyklebu/master
SPARK-991: Report information gleaned from a Python stacktrace in the UI

Scala:
- Added setCallSite/clearCallSite to SparkContext and JavaSparkContext. These functions mutate a LocalProperty called "externalCallSite".
- Added a wrapper, getCallSite, that checks for an externalCallSite and, if none is found, falls back to the usual Utils.formatSparkCallSite.
- Changed everything that called Utils.formatSparkCallSite (except getCallSite itself) to call getCallSite instead.
- Added pass-through setCallSite/clearCallSite wrappers to JavaSparkContext.

Python:
- Added a gruesome hack to rdd.py that inspects the traceback and guesses what you want to see in the UI.
- Added an RAII wrapper around said gruesome hack that calls setCallSite/clearCallSite as appropriate.
- Wired that RAII wrapper up around three calls into the Scala code.

I'm not sure that I hit all the spots with the RAII wrapper. I'm also not sure that my gruesome hack does exactly what we want. One could also approach this change by refactoring runJob/submitJob/runApproximateJob to take a call site, then threading that parameter through everything that needs to know it.

One might object to the pointless-looking wrappers in JavaSparkContext. Unfortunately, I can't directly access the SparkContext from Python (or, if I can, I don't know how), so I need to wrap everything that matters in JavaSparkContext.

Conflicts:
    core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
2 parents: 3713f81 + fec0166
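
For orientation, a small end-to-end sketch of the intended behavior. The script name, app name, and reported line number below are hypothetical, not taken from this commit.

# my_script.py (hypothetical PySpark user script)
from pyspark import SparkContext

sc = SparkContext("local", "callsite-demo")
rdd = sc.parallelize(range(100))

# collect() is one of the three wrapped calls: the RAII wrapper inspects the
# Python stack, finds where user code entered pyspark, and pushes a label such
# as "collect at my_script.py:12" through JavaSparkContext.setCallSite before
# the job is submitted, so the Spark UI describes the job by the user's call
# site rather than by an internal Scala frame.
result = rdd.collect()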

File tree

4 files changed: +93 additions, -15 deletions

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 23 additions & 3 deletions
@@ -728,6 +728,26 @@ class SparkContext(
     conf.getOption("spark.home").orElse(Option(System.getenv("SPARK_HOME")))
   }

+  /**
+   * Support function for API backtraces.
+   */
+  def setCallSite(site: String) {
+    setLocalProperty("externalCallSite", site)
+  }
+
+  /**
+   * Support function for API backtraces.
+   */
+  def clearCallSite() {
+    setLocalProperty("externalCallSite", null)
+  }
+
+  private[spark] def getCallSite(): String = {
+    val callSite = getLocalProperty("externalCallSite")
+    if (callSite == null) return Utils.formatSparkCallSite
+    callSite
+  }
+
   /**
    * Run a function on a given set of partitions in an RDD and pass the results to the given
    * handler function. This is the main entry point for all actions in Spark. The allowLocal
@@ -740,7 +760,7 @@ class SparkContext(
       partitions: Seq[Int],
       allowLocal: Boolean,
       resultHandler: (Int, U) => Unit) {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val cleanedFunc = clean(func)
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
@@ -824,7 +844,7 @@ class SparkContext(
       func: (TaskContext, Iterator[T]) => U,
       evaluator: ApproximateEvaluator[U, R],
       timeout: Long): PartialResult[R] = {
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     logInfo("Starting job: " + callSite)
     val start = System.nanoTime
     val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -844,7 +864,7 @@ class SparkContext(
       resultFunc: => R): SimpleFutureAction[R] =
   {
     val cleanF = clean(processPartition)
-    val callSite = Utils.formatSparkCallSite
+    val callSite = getCallSite
     val waiter = dagScheduler.submitJob(
       rdd,
       (context: TaskContext, iter: Iterator[T]) => cleanF(iter),

core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala

Lines changed: 14 additions & 0 deletions
@@ -411,6 +411,20 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
    * changed at runtime.
    */
   def getConf: SparkConf = sc.getConf
+
+  /**
+   * Pass-through to SparkContext.setCallSite. For API support only.
+   */
+  def setCallSite(site: String) {
+    sc.setCallSite(site)
+  }
+
+  /**
+   * Pass-through to SparkContext.clearCallSite. For API support only.
+   */
+  def clearCallSite() {
+    sc.clearCallSite()
+  }
 }

 object JavaSparkContext {
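
These wrappers look redundant from the Scala side, but PySpark only holds a py4j handle to the JavaSparkContext (the _jsc attribute of pyspark.SparkContext), not to the wrapped SparkContext, so they are the only way Python code can reach setCallSite/clearCallSite. A hedged sketch of that path, mirroring what the _JavaStackTrace helper in rdd.py (further down) does around a raw call into the JVM; the label text is hypothetical:

from pyspark import SparkContext

sc = SparkContext("local", "callsite-demo")
rdd = sc.parallelize(range(10))

# Set the label, make the JVM call, then clear the label again.
sc._jsc.setCallSite("collect at my_script.py:12")   # hypothetical label text
try:
    bytes_in_java = rdd._jrdd.collect().iterator()  # this job shows the label in the UI
finally:
    sc._jsc.setCallSite(None)                       # None crosses py4j as null, clearing externalCallSite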

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 1 addition & 1 deletion
@@ -953,7 +953,7 @@ abstract class RDD[T: ClassTag](
   private var storageLevel: StorageLevel = StorageLevel.NONE

   /** Record user function generating this RDD. */
-  @transient private[spark] val origin = Utils.formatSparkCallSite
+  @transient private[spark] val origin = sc.getCallSite

   private[spark] def elementClassTag: ClassTag[T] = classTag[T]

python/pyspark/rdd.py

Lines changed: 55 additions & 11 deletions
@@ -23,6 +23,7 @@
 import os
 import sys
 import shlex
+import traceback
 from subprocess import Popen, PIPE
 from tempfile import NamedTemporaryFile
 from threading import Thread
@@ -39,6 +40,46 @@

 __all__ = ["RDD"]

+def _extract_concise_traceback():
+    tb = traceback.extract_stack()
+    if len(tb) == 0:
+        return "I'm lost!"
+    # HACK: This function is in a file called 'rdd.py' in the top level of
+    # everything PySpark. Just trim off the directory name and assume
+    # everything in that tree is PySpark guts.
+    file, line, module, what = tb[len(tb) - 1]
+    sparkpath = os.path.dirname(file)
+    first_spark_frame = len(tb) - 1
+    for i in range(0, len(tb)):
+        file, line, fun, what = tb[i]
+        if file.startswith(sparkpath):
+            first_spark_frame = i
+            break
+    if first_spark_frame == 0:
+        file, line, fun, what = tb[0]
+        return "%s at %s:%d" % (fun, file, line)
+    sfile, sline, sfun, swhat = tb[first_spark_frame]
+    ufile, uline, ufun, uwhat = tb[first_spark_frame - 1]
+    return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+    def __init__(self, sc):
+        self._traceback = _extract_concise_traceback()
+        self._context = sc
+
+    def __enter__(self):
+        global _spark_stack_depth
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(self._traceback)
+        _spark_stack_depth += 1
+
+    def __exit__(self, type, value, tb):
+        global _spark_stack_depth
+        _spark_stack_depth -= 1
+        if _spark_stack_depth == 0:
+            self._context._jsc.setCallSite(None)

 class RDD(object):
     """
@@ -401,7 +442,8 @@ def collect(self):
         """
         Return a list that contains all of the elements in this RDD.
         """
-        bytesInJava = self._jrdd.collect().iterator()
+        with _JavaStackTrace(self.context) as st:
+            bytesInJava = self._jrdd.collect().iterator()
         return list(self._collect_iterator_through_file(bytesInJava))

     def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ def takeUpToNum(iterator):
         # TODO(shivaram): Similar to the scala implementation, update the take
         # method to scan multiple splits based on an estimate of how many elements
         # we have per-split.
-        for partition in range(mapped._jrdd.splits().size()):
-            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-            partitionsToTake[0] = partition
-            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-            items.extend(mapped._collect_iterator_through_file(iterator))
-            if len(items) >= num:
-                break
+        with _JavaStackTrace(self.context) as st:
+            for partition in range(mapped._jrdd.splits().size()):
+                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+                partitionsToTake[0] = partition
+                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+                items.extend(mapped._collect_iterator_through_file(iterator))
+                if len(items) >= num:
+                    break
         return items[:num]

     def first(self):
@@ -765,9 +808,10 @@ def add_shuffle_key(split, iterator):
             yield outputSerializer.dumps(items)
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
-                                                      id(partitionFunc))
+        with _JavaStackTrace(self.context) as st:
+            pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+            partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+                                                          id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
         # This is required so that id(partitionFunc) remains unique, even if
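
The module-level _spark_stack_depth counter is what keeps nested wrapped sections from clobbering one another: only the outermost _JavaStackTrace sets the call site, and only the outermost exit clears it. A small sketch of that behavior, assuming a live local context (the app name is hypothetical):

from pyspark import SparkContext
from pyspark.rdd import _JavaStackTrace

sc = SparkContext("local", "nesting-demo")

with _JavaStackTrace(sc):        # depth 0 -> 1: _jsc.setCallSite(<user call site>)
    with _JavaStackTrace(sc):    # depth 1 -> 2: no-op, the outer label is kept
        pass                     # the innermost call into the JVM would go here
    # leaving the inner block: depth 2 -> 1, still a no-op
# leaving the outer block: depth 1 -> 0, _jsc.setCallSite(None) clears the label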
