diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index e4167c43ab9f..7f61b3f0b2c2 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -231,7 +231,7 @@ private[spark] object ThreadUtils {
   /**
    * Run a piece of code in a new thread and return the result. Exception in the new thread is
    * thrown in the caller thread with an adjusted stack trace that removes references to this
-   * method for clarity. The exception stack traces will be like the following
+   * method for clarity. The exception stack traces will be like the following:
    *
    * SomeException: exception-message
    *   at CallerClass.body-method (sourcefile.scala)
@@ -261,31 +261,51 @@ private[spark] object ThreadUtils {
 
     exception match {
       case Some(realException) =>
-        // Remove the part of the stack that shows method calls into this helper method
-        // This means drop everything from the top until the stack element
-        // ThreadUtils.runInNewThread(), and then drop that as well (hence the `drop(1)`).
-        val baseStackTrace = Thread.currentThread().getStackTrace().dropWhile(
-          ! _.getClassName.contains(this.getClass.getSimpleName)).drop(1)
-
-        // Remove the part of the new thread stack that shows methods call from this helper method
-        val extraStackTrace = realException.getStackTrace.takeWhile(
-          ! _.getClassName.contains(this.getClass.getSimpleName))
-
-        // Combine the two stack traces, with a place holder just specifying that there
-        // was a helper method used, without any further details of the helper
-        val placeHolderStackElem = new StackTraceElement(
-          s"... run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")} ..",
-          " ", "", -1)
-        val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace
-
-        // Update the stack trace and rethrow the exception in the caller thread
-        realException.setStackTrace(finalStackTrace)
-        throw realException
+        throw wrapCallerStacktrace(realException, dropStacks = 2)
       case None =>
         result
     }
   }
 
+  /**
+   * Adjust the exception stack trace to wrap it with the caller-side thread stack trace.
+   * The exception stack traces will be like the following:
+   *
+   * SomeException: exception-message
+   *   at CallerClass.body-method (sourcefile.scala)
+   *   at ... run in separate thread using org.apache.spark.util.ThreadUtils ... ()
+   *   at CallerClass.caller-method (sourcefile.scala)
+   *   ...
+   */
+  def wrapCallerStacktrace[T <: Throwable](
+      realException: T,
+      combineMessage: String =
+        s"run in separate thread using ${ThreadUtils.getClass.getName.stripSuffix("$")}",
+      dropStacks: Int = 1): T = {
+    require(dropStacks >= 0, "dropStacks must be zero or positive")
+    val simpleName = this.getClass.getSimpleName
+    // Remove the part of the stack that shows method calls into this helper method
+    // This means drop everything from the top until the stack element
+    // ThreadUtils.wrapCallerStacktrace(), and then drop dropStacks more elements (1 by default).
+    // A larger dropStacks lets callers drop extra helper frames, e.g. runInNewThread passes 2.
+    val baseStackTrace = Thread.currentThread().getStackTrace
+      .dropWhile(!_.getClassName.contains(simpleName))
+      .drop(dropStacks)
+
+    // Remove the part of the new thread stack that shows method calls from this helper method
+    val extraStackTrace = realException.getStackTrace
+      .takeWhile(!_.getClassName.contains(simpleName))
+
+    // Combine the two stack traces, with a place holder just specifying that there
+    // was a helper method used, without any further details of the helper
+    val placeHolderStackElem = new StackTraceElement(s"... $combineMessage ..", " ", "", -1)
+    val finalStackTrace = extraStackTrace ++ Seq(placeHolderStackElem) ++ baseStackTrace
+
+    // Update the stack trace and return the exception so the caller can rethrow it
+    realException.setStackTrace(finalStackTrace)
+    realException
+  }
+
   /**
    * Construct a new ForkJoinPool with a specified max parallelism and name prefix.
    */
diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
index d907fe1a27c8..04f661db691e 100644
--- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala
@@ -119,11 +119,46 @@ class ThreadUtilsSuite extends SparkFunSuite {
       runInNewThread("thread-name") { throw new IllegalArgumentException(uniqueExceptionMessage) }
     }
     assert(exception.getMessage === uniqueExceptionMessage)
-    assert(exception.getStackTrace.mkString("\n").contains(
+    val stacktrace = exception.getStackTrace.mkString("\n")
+    assert(stacktrace.contains(
       "... run in separate thread using org.apache.spark.util.ThreadUtils ..."),
       "stack trace does not contain expected place holder"
     )
-    assert(exception.getStackTrace.mkString("\n").contains("ThreadUtils.scala") === false,
+    assert(!stacktrace.contains("ThreadUtils.scala"),
+      "stack trace contains unexpected references to ThreadUtils"
+    )
+  }
+
+  test("SPARK-47833: wrapCallerStacktrace") {
+    var runnerThreadName: String = null
+    var exception: Throwable = null
+    val t = new Thread() {
+      override def run(): Unit = {
+        runnerThreadName = Thread.currentThread().getName
+        internalMethod()
+      }
+      private def internalMethod(): Unit = {
+        throw new RuntimeException(s"Error occurred on $runnerThreadName")
+      }
+    }
+    t.setDaemon(true)
+    t.setUncaughtExceptionHandler { case (_, e) => exception = e }
+    t.start()
+    t.join()
+
+    ThreadUtils.wrapCallerStacktrace(exception, s"run in separate thread: $runnerThreadName")
+
+    val stacktrace = exception.getStackTrace.mkString("\n")
+    assert(stacktrace.contains("internalMethod"),
+      "stack trace does not contain real exception stack trace"
+    )
+    assert(stacktrace.contains(s"... run in separate thread: $runnerThreadName ..."),
+      "stack trace does not contain expected place holder"
+    )
+    assert(stacktrace.contains("org.scalatest.Suite.run"),
+      "stack trace does not contain caller stack trace"
+    )
+    assert(!stacktrace.contains("ThreadUtils.scala"),
       "stack trace contains unexpected references to ThreadUtils"
     )
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 4c2d6a4cdf5e..b03521507a06 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -784,7 +784,7 @@ object DataSource extends Logging {
           globResult
         }.flatten
       } catch {
-        case e: SparkException => throw e.getCause
+        case e: SparkException => throw ThreadUtils.wrapCallerStacktrace(e.getCause)
       }
 
     if (checkFilesExist) {
@@ -796,7 +796,7 @@ object DataSource extends Logging {
            }
          }
        }
      } catch {
-        case e: SparkException => throw e.getCause
+        case e: SparkException => throw ThreadUtils.wrapCallerStacktrace(e.getCause)
      }
    }
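
For illustration only, and not part of the patch above: a minimal sketch of how caller code might use the new ThreadUtils.wrapCallerStacktrace helper to graft the calling thread's frames onto an exception raised in a worker thread. The package, object name, thread name, and messages below are made up for the example; since ThreadUtils is private[spark], the sketch assumes it lives somewhere under the org.apache.spark package (e.g. inside Spark's own sources or tests).

package org.apache.spark.examples // hypothetical location; any org.apache.spark subpackage works

import org.apache.spark.util.ThreadUtils

object WrapCallerStacktraceSketch {
  def main(args: Array[String]): Unit = {
    var failure: Throwable = null

    // Run some work on a separate thread and capture any failure.
    // Thread.join() below guarantees the write to `failure` is visible here.
    val worker = new Thread("sketch-worker") {
      override def run(): Unit = {
        try {
          throw new IllegalStateException("boom in worker")
        } catch {
          case t: Throwable => failure = t
        }
      }
    }
    worker.start()
    worker.join()

    if (failure != null) {
      // Rewrites failure's stack trace in place: the worker-side frames, then a single
      // "... run in separate thread: sketch-worker .." placeholder element, then the
      // frames of the calling (main) thread below this point.
      throw ThreadUtils.wrapCallerStacktrace(failure, "run in separate thread: sketch-worker")
    }
  }
}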