diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index bb13de563cdd..e1c5f007268a 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1,3 +1,4 @@
+# coding=utf-8
 #
 # Licensed to the Apache Software Foundation (ASF) under one or more
 # contributor license agreements.  See the NOTICE file distributed with
@@ -1859,6 +1860,32 @@ def test_with_different_versions_of_python(self):
         finally:
             self.sc.pythonVer = version
 
+    def test_exception_blocking(self):
+        """
+        SPARK-21045
+        Test that the driver is not blocked when a worker raises an exception
+        whose message contains non-ASCII characters.
+
+        """
+        import threading
+
+        def run():
+            try:
+
+                def f():
+                    raise Exception("中")
+
+                self.sc.parallelize([1]).map(lambda x: f()).count()
+            except Exception:
+                pass
+
+        t = threading.Thread(target=run)
+        t.daemon = True
+        t.start()
+        # If the worker hangs while reporting the exception to the JVM, the
+        # job never finishes and the thread is still alive after the timeout.
+        t.join(10)
+        self.assertFalse(t.is_alive(), 'Spark executor is blocked.')
 
 
 class SparkSubmitTests(unittest.TestCase):
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index baaa3fe074e9..11c6555b1fdc 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -36,6 +36,11 @@
 pickleSer = PickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
+# Python 3 has no separate ``unicode`` type; alias it to ``str`` so the
+# isinstance() check in the exception-reporting path works on both versions.
+if sys.version >= '3':
+    unicode = str
+
 
 def report_times(outfile, boot, init, finish):
     write_int(SpecialLengths.TIMING_DATA, outfile)
@@ -177,8 +182,15 @@ def process():
         process()
     except Exception:
         try:
+            # On Python 2 format_exc() returns bytes (str); calling
+            # .encode('utf-8') on non-ASCII bytes implicitly decodes them as
+            # ASCII and raises, so the traceback never reaches the JVM and
+            # the executor blocks (SPARK-21045). Encode only unicode text.
+            exc_info = traceback.format_exc()
+            if isinstance(exc_info, unicode):
+                exc_info = exc_info.encode('utf-8')
             write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(traceback.format_exc().encode("utf-8"), outfile)
+            write_with_length(exc_info, outfile)
         except IOError:
             # JVM close the socket
             pass