|
23 | 23 | import os |
24 | 24 | import sys |
25 | 25 | import shlex |
| 26 | +import traceback |
26 | 27 | from subprocess import Popen, PIPE |
27 | 28 | from tempfile import NamedTemporaryFile |
28 | 29 | from threading import Thread |
|
39 | 40 |
|
40 | 41 | __all__ = ["RDD"] |
41 | 42 |
|
| 43 | +def _extract_concise_traceback(): |
| 44 | + tb = traceback.extract_stack() |
| 45 | + if len(tb) == 0: |
| 46 | + return "I'm lost!" |
| 47 | +    # HACK: This function lives in 'rdd.py' at the top level of the PySpark
| 48 | +    # source tree, so take its directory name and assume every frame whose
| 49 | +    # file is under that directory is PySpark internals.
| 50 | + file, line, module, what = tb[len(tb) - 1] |
| 51 | + sparkpath = os.path.dirname(file) |
| 52 | + first_spark_frame = len(tb) - 1 |
| 53 | + for i in range(0, len(tb)): |
| 54 | + file, line, fun, what = tb[i] |
| 55 | + if file.startswith(sparkpath): |
| 56 | + first_spark_frame = i |
| 57 | + break |
| 58 | + if first_spark_frame == 0: |
| 59 | + file, line, fun, what = tb[0] |
| 60 | + return "%s at %s:%d" % (fun, file, line) |
| 61 | + sfile, sline, sfun, swhat = tb[first_spark_frame] |
| 62 | + ufile, uline, ufun, uwhat = tb[first_spark_frame-1] |
| 63 | + return "%s at %s:%d" % (sfun, ufile, uline) |
| 64 | + |
| 65 | +_spark_stack_depth = 0 |
| 66 | + |
| 67 | +class _JavaStackTrace(object): |
| 68 | + def __init__(self, sc): |
| 69 | + self._traceback = _extract_concise_traceback() |
| 70 | + self._context = sc |
| 71 | + |
| 72 | + def __enter__(self): |
| 73 | + global _spark_stack_depth |
| 74 | + if _spark_stack_depth == 0: |
| 75 | + self._context._jsc.setCallSite(self._traceback) |
| 76 | + _spark_stack_depth += 1 |
| 77 | + |
| 78 | + def __exit__(self, type, value, tb): |
| 79 | + global _spark_stack_depth |
| 80 | + _spark_stack_depth -= 1 |
| 81 | + if _spark_stack_depth == 0: |
| 82 | + self._context._jsc.setCallSite(None) |
42 | 83 |
|
43 | 84 | class RDD(object): |
44 | 85 | """ |
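
The additions above boil down to two pieces: extract a concise, user-facing call site from the Python stack, and install it on the JVM via a reference-counted context manager so that only the outermost nesting level sets and clears it. Below is a minimal standalone sketch of the same pattern with no Spark dependency; the names concise_call_site, CallSiteScope, and the backend.set_call_site() hook are illustrative and not part of this patch.

    import traceback

    def concise_call_site(library_dir):
        # Walk the stack outermost-first; the first frame whose file lives
        # under library_dir is the library entry point, and the frame just
        # before it is the user code that called into the library.
        stack = traceback.extract_stack()
        if not stack:
            return "unknown call site"
        for i, frame in enumerate(stack):
            filename, lineno, func, _ = frame
            if filename.startswith(library_dir):
                if i == 0:
                    # Everything is inside the library; report the top frame.
                    return "%s at %s:%d" % (func, filename, lineno)
                caller_file, caller_line, _, _ = stack[i - 1]
                return "%s at %s:%d" % (func, caller_file, caller_line)
        # No library frame found; fall back to the innermost frame.
        filename, lineno, func, _ = stack[-1]
        return "%s at %s:%d" % (func, filename, lineno)

    _depth = 0

    class CallSiteScope(object):
        # Reference-counted scope: only the outermost `with` sets and clears
        # the call site, so nested library calls keep the original label.
        def __init__(self, backend, library_dir):
            self._backend = backend  # anything exposing set_call_site()
            self._call_site = concise_call_site(library_dir)

        def __enter__(self):
            global _depth
            if _depth == 0:
                self._backend.set_call_site(self._call_site)
            _depth += 1

        def __exit__(self, exc_type, exc_value, tb):
            global _depth
            _depth -= 1
            if _depth == 0:
                self._backend.set_call_site(None)
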
@@ -401,7 +442,8 @@ def collect(self): |
401 | 442 | """ |
402 | 443 | Return a list that contains all of the elements in this RDD. |
403 | 444 | """ |
404 | | - bytesInJava = self._jrdd.collect().iterator() |
| 445 | + with _JavaStackTrace(self.context) as st: |
| 446 | + bytesInJava = self._jrdd.collect().iterator() |
405 | 447 | return list(self._collect_iterator_through_file(bytesInJava)) |
406 | 448 |
|
407 | 449 | def _collect_iterator_through_file(self, iterator): |
@@ -582,13 +624,14 @@ def takeUpToNum(iterator): |
582 | 624 | # TODO(shivaram): Similar to the scala implementation, update the take |
583 | 625 | # method to scan multiple splits based on an estimate of how many elements |
584 | 626 | # we have per-split. |
585 | | - for partition in range(mapped._jrdd.splits().size()): |
586 | | - partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) |
587 | | - partitionsToTake[0] = partition |
588 | | - iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() |
589 | | - items.extend(mapped._collect_iterator_through_file(iterator)) |
590 | | - if len(items) >= num: |
591 | | - break |
| 627 | + with _JavaStackTrace(self.context) as st: |
| 628 | + for partition in range(mapped._jrdd.splits().size()): |
| 629 | + partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1) |
| 630 | + partitionsToTake[0] = partition |
| 631 | + iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator() |
| 632 | + items.extend(mapped._collect_iterator_through_file(iterator)) |
| 633 | + if len(items) >= num: |
| 634 | + break |
592 | 635 | return items[:num] |
593 | 636 |
|
594 | 637 | def first(self): |
@@ -765,9 +808,10 @@ def add_shuffle_key(split, iterator): |
765 | 808 | yield outputSerializer.dumps(items) |
766 | 809 | keyed = PipelinedRDD(self, add_shuffle_key) |
767 | 810 | keyed._bypass_serializer = True |
768 | | - pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() |
769 | | - partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, |
770 | | - id(partitionFunc)) |
| 811 | + with _JavaStackTrace(self.context) as st: |
| 812 | + pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() |
| 813 | + partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, |
| 814 | + id(partitionFunc)) |
771 | 815 | jrdd = pairRDD.partitionBy(partitioner).values() |
772 | 816 | rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) |
773 | 817 | # This is required so that id(partitionFunc) remains unique, even if |
|
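For context, a hypothetical usage sketch (the script name, app name, and line number are illustrative, not taken from this patch): with the call site installed around collect(), take(), and partitionBy(), the Java-side job triggered by an action can be labeled with the user's call site, e.g. "collect at demo.py:5".

    from pyspark import SparkContext

    sc = SparkContext("local", "callsite-demo")
    rdd = sc.parallelize(range(100)).map(lambda x: x * x)
    print(rdd.collect())  # this job should carry the call site of this line
    sc.stop()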