diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index f90d3532995f3..b724e87776f87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -677,7 +677,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
       val end = System.nanoTime()
       session.listenerManager.onSuccess(name, qe, end - start)
     } catch {
-      case e: Exception =>
+      case e: Throwable =>
         session.listenerManager.onFailure(name, qe, e)
         throw e
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index c90b2e857e664..6bc0d0bfaa3f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3373,7 +3373,7 @@ class Dataset[T] private[sql](
       sparkSession.listenerManager.onSuccess(name, qe, end - start)
       result
     } catch {
-      case e: Exception =>
+      case e: Throwable =>
         sparkSession.listenerManager.onFailure(name, qe, e)
         throw e
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
index abfc6773e361f..4d51f2751d01d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala
@@ -25,7 +25,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.SparkConf
 import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.execution.QueryExecution
+import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException}
 import org.apache.spark.sql.internal.StaticSQLConf._
 import org.apache.spark.util.Utils

@@ -59,7 +59,9 @@ trait QueryExecutionListener {
    * @param funcName the name of the action that triggered this query.
    * @param qe the QueryExecution object that carries detail information like logical plan,
    *           physical plan, etc.
-   * @param exception the exception that failed this query.
+   * @param exception the exception that failed this query. If `java.lang.Error` is thrown during
+   *                  execution, it will be wrapped with an `Exception` and it can be accessed by
+   *                  `exception.getCause`.
    *
    * @note This can be invoked by multiple different threads.
    */
@@ -129,7 +131,14 @@ class ExecutionListenerManager private extends Logging {
     }
   }

-  private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+  private[sql] def onFailure(funcName: String, qe: QueryExecution, t: Throwable): Unit = {
+    val exception = t match {
+      case e: Exception => e
+      case other: Throwable =>
+        val message = "Hit an error when executing a query" +
+          (if (other.getMessage == null) "" else s": ${other.getMessage}")
+        new QueryExecutionException(message, other)
+    }
     readLock {
       withErrorHandling { listener =>
         listener.onFailure(funcName, qe, exception)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
index a239e39d9c5a3..8e7274ac7456e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala
@@ -20,13 +20,16 @@ package org.apache.spark.sql.util
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark._
-import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
+import org.apache.spark.sql.{functions, AnalysisException, Dataset, QueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
-import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec}
+import org.apache.spark.sql.execution.command.RunnableCommand
 import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
 import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.StringType

 class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
   import testImplicits._
@@ -213,4 +216,31 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
       assert(exceptions.head._2 == e)
     }
   }
+
+  testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
+    var e: Exception = null
+    val listener = new QueryExecutionListener {
+      override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
+        e = exception
+      }
+      override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
+    }
+    spark.listenerManager.register(listener)
+
+    intercept[Error] {
+      Dataset.ofRows(spark, ErrorTestCommand("foo")).collect()
+    }
+    assert(e != null && e.isInstanceOf[QueryExecutionException]
+      && e.getCause.isInstanceOf[Error] && e.getCause.getMessage == "foo")
+    spark.listenerManager.unregister(listener)
+  }
+}
+
+/** A test command that throws `java.lang.Error` during execution. */
+case class ErrorTestCommand(foo: String) extends RunnableCommand {
+
+  override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)())
+
+  override def run(sparkSession: SparkSession): Seq[Row] =
+    throw new java.lang.Error(foo)
 }
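
Usage note (illustrative, not part of the patch): after this change, a registered QueryExecutionListener is notified of fatal JVM errors instead of being skipped, and the updated scaladoc guarantees the original Error is reachable via exception.getCause. Below is a minimal sketch of a listener that takes advantage of this, using only the public listener API; the class name ErrorAwareListener, the demo object, and the println reporting are hypothetical choices for illustration.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

// Illustrative listener: with this patch, an Error thrown during query execution
// reaches onFailure wrapped in an Exception, so the fatal cause is inspectable
// via getCause. Before the patch, onFailure was never invoked for Errors at all.
class ErrorAwareListener extends QueryExecutionListener {

  override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = ()

  override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
    exception.getCause match {
      case fatal: Error =>
        // The wrapped case introduced by this patch.
        println(s"Query '$funcName' died with a fatal error: ${fatal.getMessage}")
      case _ =>
        // Ordinary query failures arrive unwrapped, as before.
        println(s"Query '$funcName' failed: ${exception.getMessage}")
    }
  }
}

// Hypothetical local-mode driver showing registration, mirroring the test above.
object ErrorAwareListenerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("listener-demo").getOrCreate()
    spark.listenerManager.register(new ErrorAwareListener)
    spark.range(10).count() // a successful action; failures would route through onFailure
    spark.stop()
  }
}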