diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
index 2550634681453..e77128755363d 100644
--- a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -48,11 +48,26 @@ private[spark] class SparkUncaughtExceptionHandler(val exitOnUncaughtException:
             System.exit(SparkExitCode.OOM)
           case _ if exitOnUncaughtException =>
             System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
+          case _ =>
+            // SPARK-30310: Don't System.exit() when exitOnUncaughtException is false
         }
       }
     } catch {
-      case oom: OutOfMemoryError => Runtime.getRuntime.halt(SparkExitCode.OOM)
-      case t: Throwable => Runtime.getRuntime.halt(SparkExitCode.UNCAUGHT_EXCEPTION_TWICE)
+      case oom: OutOfMemoryError =>
+        try {
+          logError(s"Uncaught OutOfMemoryError in thread $thread, process halted.", oom)
+        } catch {
+          // absorb any exception/error since we're halting the process
+          case _: Throwable =>
+        }
+        Runtime.getRuntime.halt(SparkExitCode.OOM)
+      case t: Throwable =>
+        try {
+          logError(s"Another uncaught exception in thread $thread, process halted.", t)
+        } catch {
+          case _: Throwable =>
+        }
+        Runtime.getRuntime.halt(SparkExitCode.UNCAUGHT_EXCEPTION_TWICE)
     }
   }
 
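Note on the pattern above: the fallback arms use Runtime.getRuntime.halt rather than System.exit because halt() terminates the JVM immediately, without running shutdown hooks, so it cannot loop back into this handler; the trade-off is that nothing after the call executes, which is why the patch logs first and absorbs any failure from the logger itself. A minimal, self-contained sketch of that shape (illustrative only, not the Spark class; the hard-coded 50 assumes SparkExitCode.UNCAUGHT_EXCEPTION = 50):

    object HaltVsExitDemo {
      def main(args: Array[String]): Unit = {
        // System.exit() would run this hook; Runtime.getRuntime.halt() never does.
        sys.addShutdownHook(System.err.println("shutdown hook ran"))
        Thread.setDefaultUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler {
          override def uncaughtException(t: Thread, e: Throwable): Unit = {
            try {
              // Log before halting: nothing after halt() gets a chance to run.
              System.err.println(s"Uncaught exception in thread $t, process halted.")
            } catch {
              case _: Throwable => // absorb, as in the patch: we are halting anyway
            }
            Runtime.getRuntime.halt(50)
          }
        })
        throw new RuntimeException("boom")
      }
    }
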
diff --git a/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala b/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala
new file mode 100644
index 0000000000000..90741a6bde7f0
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/SparkUncaughtExceptionHandlerSuite.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io.File
+
+import scala.util.Try
+
+import org.apache.spark.SparkFunSuite
+
+class SparkUncaughtExceptionHandlerSuite extends SparkFunSuite {
+
+  private val sparkHome =
+    sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
+
+  Seq(
+    (ThrowableTypes.RuntimeException, true, SparkExitCode.UNCAUGHT_EXCEPTION),
+    (ThrowableTypes.RuntimeException, false, 0),
+    (ThrowableTypes.OutOfMemoryError, true, SparkExitCode.OOM),
+    (ThrowableTypes.OutOfMemoryError, false, SparkExitCode.OOM),
+    (ThrowableTypes.SparkFatalRuntimeException, true, SparkExitCode.UNCAUGHT_EXCEPTION),
+    (ThrowableTypes.SparkFatalRuntimeException, false, 0),
+    (ThrowableTypes.SparkFatalOutOfMemoryError, true, SparkExitCode.OOM),
+    (ThrowableTypes.SparkFatalOutOfMemoryError, false, SparkExitCode.OOM)
+  ).foreach {
+    case (throwable: ThrowableTypes.ThrowableTypesVal,
+        exitOnUncaughtException: Boolean, expectedExitCode) =>
+      test(s"SPARK-30310: Test uncaught $throwable, " +
+          s"exitOnUncaughtException = $exitOnUncaughtException") {
+
+        // creates a ThrowableThrower process via spark-class and verifies the exit code
+        val process = Utils.executeCommand(
+          Seq(s"$sparkHome/bin/spark-class",
+            ThrowableThrower.getClass.getCanonicalName.dropRight(1), // drops the "$" at the end
+            throwable.name,
+            exitOnUncaughtException.toString),
+          new File(sparkHome),
+          Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)
+        )
+        assert(process.waitFor == expectedExitCode)
+      }
+  }
+}
+
+// enumeration object for the Throwable types that SparkUncaughtExceptionHandler handles
+object ThrowableTypes extends Enumeration {
+
+  sealed case class ThrowableTypesVal(name: String, t: Throwable) extends Val(name)
+
+  val RuntimeException = ThrowableTypesVal("RuntimeException", new RuntimeException)
+  val OutOfMemoryError = ThrowableTypesVal("OutOfMemoryError", new OutOfMemoryError)
+  val SparkFatalRuntimeException = ThrowableTypesVal("SparkFatalException(RuntimeException)",
+    new SparkFatalException(new RuntimeException))
+  val SparkFatalOutOfMemoryError = ThrowableTypesVal("SparkFatalException(OutOfMemoryError)",
+    new SparkFatalException(new OutOfMemoryError))
+
+  // returns the actual Throwable by its name
+  def getThrowableByName(name: String): Throwable = {
+    super.withName(name).asInstanceOf[ThrowableTypesVal].t
+  }
+}
+
+// Invoked by spark-class for throwing a Throwable
+object ThrowableThrower {
+
+  // a thread that uses SparkUncaughtExceptionHandler and throws a Throwable by name
+  class ThrowerThread(name: String, exitOnUncaughtException: Boolean) extends Thread {
+    override def run(): Unit = {
+      Thread.setDefaultUncaughtExceptionHandler(
+        new SparkUncaughtExceptionHandler(exitOnUncaughtException))
+      throw ThrowableTypes.getThrowableByName(name)
+    }
+  }
+
+  // main() requires 2 args:
+  // - args(0): name of the Throwable defined in ThrowableTypes
+  // - args(1): exitOnUncaughtException (true/false)
+  //
+  // it exits with the exit code dictated by either:
+  // - SparkUncaughtExceptionHandler (SparkExitCode)
+  // - main() (0, or -1 when the number of args is wrong)
+  def main(args: Array[String]): Unit = {
+    if (args.length == 2) {
+      val t = new ThrowerThread(args(0),
+        Try(args(1).toBoolean).getOrElse(false))
+      t.start()
+      t.join()
+      System.exit(0)
+    } else {
+      System.exit(-1)
+    }
+  }
+}
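The suite has to observe System.exit/halt from outside the JVM under test, so each case launches a real child process through bin/spark-class and asserts on its exit status; note the .dropRight(1), which strips the trailing "$" from the compiled Scala object's class name, since the launcher expects the plain class whose static main forwards to the object. A reduced sketch of the launch-and-assert pattern using scala.sys.process instead of Utils.executeCommand ("sh -c 'exit 52'" is a stand-in for a child JVM halting with SparkExitCode.OOM, assumed here to be 52):

    import scala.sys.process._

    object ExitCodeCheckDemo {
      def main(args: Array[String]): Unit = {
        // Process(...).! runs the command and returns its exit status,
        // playing the role of process.waitFor in the suite.
        val exitCode = Process(Seq("sh", "-c", "exit 52")).!
        assert(exitCode == 52, s"unexpected exit code $exitCode")
      }
    }
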
diff --git a/python/pyspark/tests/test_worker.py b/python/pyspark/tests/test_worker.py
index ccbe21f3a6f38..09ba70e295f8e 100644
--- a/python/pyspark/tests/test_worker.py
+++ b/python/pyspark/tests/test_worker.py
@@ -183,7 +183,7 @@ def test_reuse_worker_of_parallelize_xrange(self):
 class WorkerMemoryTest(PySparkTestCase):
 
     def test_memory_limit(self):
-        self.sc._conf.set("spark.executor.pyspark.memory", "1m")
+        self.sc._conf.set("spark.executor.pyspark.memory", "8m")
         rdd = self.sc.parallelize(xrange(1), 1)
 
         def getrlimit():
@@ -194,8 +194,8 @@ def getrlimit():
         self.assertTrue(len(actual) == 1)
         self.assertTrue(len(actual[0]) == 2)
         [(soft_limit, hard_limit)] = actual
-        self.assertEqual(soft_limit, 1024 * 1024)
-        self.assertEqual(hard_limit, 1024 * 1024)
+        self.assertEqual(soft_limit, 8 * 1024 * 1024)
+        self.assertEqual(hard_limit, 8 * 1024 * 1024)
 
 
 if __name__ == "__main__":
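On the bump from 1m to 8m: PySpark enforces spark.executor.pyspark.memory on the worker as a process resource limit, which is what the test's getrlimit() helper reads back from inside the task, and a 1 MiB cap is below what a CPython interpreter typically needs just to start, so the tighter value risked killing the worker before the assertions ran. A standalone sketch of the check, assuming the limit is applied via RLIMIT_AS as in current PySpark (Unix-only):

    import resource

    # "8m" should surface as 8 * 1024 * 1024 bytes for both the soft and
    # hard limits once the worker has applied the configured cap.
    soft_limit, hard_limit = resource.getrlimit(resource.RLIMIT_AS)
    print(soft_limit, hard_limit)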