Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
val end = System.nanoTime()
session.listenerManager.onSuccess(name, qe, end - start)
} catch {
case e: Exception =>
case e: Throwable =>
session.listenerManager.onFailure(name, qe, e)
throw e
}
Expand Down
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3373,7 +3373,7 @@ class Dataset[T] private[sql](
sparkSession.listenerManager.onSuccess(name, qe, end - start)
result
} catch {
case e: Exception =>
case e: Throwable =>
sparkSession.listenerManager.onFailure(name, qe, e)
throw e
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import scala.util.control.NonFatal
import org.apache.spark.SparkConf
import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException}
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.util.Utils

Expand Down Expand Up @@ -59,7 +59,9 @@ trait QueryExecutionListener {
* @param funcName the name of the action that triggered this query.
* @param qe the QueryExecution object that carries detail information like logical plan,
* physical plan, etc.
* @param exception the exception that failed this query.
* @param exception the exception that failed this query. If `java.lang.Error` is thrown during
* execution, it will be wrapped with an `Exception` and it can be accessed by
* `exception.getCause`.
*
* @note This can be invoked by multiple different threads.
*/
Expand Down Expand Up @@ -129,7 +131,14 @@ class ExecutionListenerManager private extends Logging {
}
}

private[sql] def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
private[sql] def onFailure(funcName: String, qe: QueryExecution, t: Throwable): Unit = {
val exception = t match {
case e: Exception => e
case other: Throwable =>
val message = "Hit an error when executing a query" +
(if (other.getMessage == null) "" else s": ${other.getMessage}")
new QueryExecutionException(message, other)
}
readLock {
withErrorHandling { listener =>
listener.onFailure(funcName, qe, exception)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,16 @@ package org.apache.spark.sql.util
import scala.collection.mutable.ArrayBuffer

import org.apache.spark._
import org.apache.spark.sql.{functions, AnalysisException, QueryTest}
import org.apache.spark.sql.{functions, AnalysisException, Dataset, QueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, InsertIntoTable, LogicalPlan, Project}
import org.apache.spark.sql.execution.{QueryExecution, WholeStageCodegenExec}
import org.apache.spark.sql.execution.{QueryExecution, QueryExecutionException, WholeStageCodegenExec}
import org.apache.spark.sql.execution.command.RunnableCommand
import org.apache.spark.sql.execution.datasources.{CreateTable, InsertIntoHadoopFsRelationCommand}
import org.apache.spark.sql.execution.datasources.json.JsonFileFormat
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.StringType

class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
import testImplicits._
Expand Down Expand Up @@ -213,4 +216,31 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext {
assert(exceptions.head._2 == e)
}
}

testQuietly("SPARK-31144: QueryExecutionListener should receive `java.lang.Error`") {
var e: Exception = null
val listener = new QueryExecutionListener {
override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {
e = exception
}
override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {}
}
spark.listenerManager.register(listener)

intercept[Error] {
Dataset.ofRows(spark, ErrorTestCommand("foo")).collect()
}
assert(e != null && e.isInstanceOf[QueryExecutionException]
&& e.getCause.isInstanceOf[Error] && e.getCause.getMessage == "foo")
spark.listenerManager.unregister(listener)
}
}

/** A test command that throws `java.lang.Error` during execution. */
case class ErrorTestCommand(foo: String) extends RunnableCommand {

override val output: Seq[Attribute] = Seq(AttributeReference("foo", StringType)())

override def run(sparkSession: SparkSession): Seq[Row] =
throw new java.lang.Error(foo)
}