2020import traceback
2121
2222
# Immutable record describing a call site: the function name, the file name
# and the line number.
CallSite = namedtuple("CallSite", "function file linenum")
2424
2525
26- def extract_concise_traceback ():
26+ def first_spark_call ():
2727 """
28- This function returns the traceback info for a callsite, returns a dict
29- with function name, file name and line number
28+ Return a CallSite representing the first Spark call in the current call stack.
3029 """
3130 tb = traceback .extract_stack ()
32- callsite = namedtuple ("Callsite" , "function file linenum" )
3331 if len (tb ) == 0 :
3432 return None
3533 file , line , module , what = tb [len (tb ) - 1 ]
@@ -42,39 +40,39 @@ def extract_concise_traceback():
4240 break
4341 if first_spark_frame == 0 :
4442 file , line , fun , what = tb [0 ]
45- return callsite (function = fun , file = file , linenum = line )
43+ return CallSite (function = fun , file = file , linenum = line )
4644 sfile , sline , sfun , swhat = tb [first_spark_frame ]
4745 ufile , uline , ufun , uwhat = tb [first_spark_frame - 1 ]
48- return callsite (function = sfun , file = ufile , linenum = uline )
46+ return CallSite (function = sfun , file = ufile , linenum = uline )
4947
5048
class SCCallSiteSync(object):
    """
    Helper for setting the spark context call site.

    On construction the call site description is computed from
    first_spark_call(). Entering the outermost ``with`` block pushes that
    description to the context via ``sc._jsc.setCallSite``; leaving the
    outermost block resets it to None. Nested blocks only adjust the
    class-wide depth counter.

    Example usage:
    from pyspark.context import SCCallSiteSync
    with SCCallSiteSync(<relevant SparkContext>) as css:
        <a Spark call>
    """

    # Nesting depth shared by all instances; only the 0 -> 1 and 1 -> 0
    # transitions touch the call site on the context.
    _spark_stack_depth = 0

    def __init__(self, sc):
        """
        Capture the current call site description and remember the context.

        :param sc: object exposing ``_jsc.setCallSite``
        """
        call_site = first_spark_call()
        if call_site is not None:
            self._call_site = "%s at %s:%s" % (
                call_site.function, call_site.file, call_site.linenum)
        else:
            self._call_site = "Error! Could not extract traceback info"
        self._context = sc

    def __enter__(self):
        """
        Set the call site if this is the outermost block, then return self
        so that ``with ... as css`` binds the helper instead of None.
        """
        if SCCallSiteSync._spark_stack_depth == 0:
            self._context._jsc.setCallSite(self._call_site)
        SCCallSiteSync._spark_stack_depth += 1
        return self

    def __exit__(self, type, value, tb):
        """
        Clear the call site when the outermost block is left.

        Returns None, so an exception raised inside the block propagates.
        """
        SCCallSiteSync._spark_stack_depth -= 1
        if SCCallSiteSync._spark_stack_depth == 0:
            self._context._jsc.setCallSite(None)
0 commit comments