diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
index 0f6d86691b4d..07fa813a9892 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallback.scala
@@ -17,24 +17,12 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.codehaus.commons.compiler.CompileException
-import org.codehaus.janino.InternalCompilerException
+import scala.util.control.NonFatal
 
-import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
 
-/**
- * Catches compile error during code generation.
- */
-object CodegenError {
-  def unapply(throwable: Throwable): Option[Exception] = throwable match {
-    case e: InternalCompilerException => Some(e)
-    case e: CompileException => Some(e)
-    case _ => None
-  }
-}
-
 /**
  * Defines values for `SQLConf` config of fallback mode. Use for test only.
  */
@@ -47,7 +35,7 @@ object CodegenObjectFactoryMode extends Enumeration {
  * error happens, it can fallback to interpreted implementation. In tests, we can use a SQL config
  * `SQLConf.CODEGEN_FACTORY_MODE` to control fallback behavior.
  */
-abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] {
+abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] extends Logging {
 
   def createObject(in: IN): OUT = {
     // We are allowed to choose codegen-only or no-codegen modes if under tests.
@@ -63,7 +51,10 @@ abstract class CodeGeneratorWithInterpretedFallback[IN, OUT] {
         try {
           createCodeGeneratedObject(in)
         } catch {
-          case CodegenError(_) => createInterpretedObject(in)
+          case NonFatal(_) =>
+            // We should have already seen the error message in `CodeGenerator`
+            logWarning("Expr codegen error and falling back to interpreter mode")
+            createInterpretedObject(in)
         }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index 6493f0910057..226a4ddcffaa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.util.control.NonFatal
+
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -180,7 +182,10 @@ object UnsafeProjection
     try {
       GenerateUnsafeProjection.generate(unsafeExprs, subexpressionEliminationEnabled)
     } catch {
-      case CodegenError(_) => InterpretedUnsafeProjection.createProjection(unsafeExprs)
+      case NonFatal(_) =>
+        // We should have already seen the error message in `CodeGenerator`
+        logWarning("Expr codegen error and falling back to interpreter mode")
+        InterpretedUnsafeProjection.createProjection(unsafeExprs)
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
index 531ca9a87370..28edd85ab6e8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGeneratorWithInterpretedFallbackSuite.scala
@@ -17,17 +17,33 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.util.concurrent.ExecutionException
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator}
 import org.apache.spark.sql.catalyst.plans.PlanTestBase
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{IntegerType, LongType}
+import org.apache.spark.sql.types.IntegerType
 
 class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanTestBase {
 
-  test("UnsafeProjection with codegen factory mode") {
-    val input = Seq(LongType, IntegerType)
-      .zipWithIndex.map(x => BoundReference(x._2, x._1, true))
+  object FailedCodegenProjection
+      extends CodeGeneratorWithInterpretedFallback[Seq[Expression], UnsafeProjection] {
+
+    override protected def createCodeGeneratedObject(in: Seq[Expression]): UnsafeProjection = {
+      val invalidCode = new CodeAndComment("invalid code", Map.empty)
+      // We assume this compilation throws an exception
+      CodeGenerator.compile(invalidCode)
+      null
+    }
+
+    override protected def createInterpretedObject(in: Seq[Expression]): UnsafeProjection = {
+      InterpretedUnsafeProjection.createProjection(in)
+    }
+  }
 
+  test("UnsafeProjection with codegen factory mode") {
+    val input = Seq(BoundReference(0, IntegerType, nullable = true))
     val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
     withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
       val obj = UnsafeProjection.createObject(input)
@@ -40,4 +56,24 @@ class CodeGeneratorWithInterpretedFallbackSuite extends SparkFunSuite with PlanT
       assert(obj.isInstanceOf[InterpretedUnsafeProjection])
     }
   }
+
+  test("fallback to the interpreter mode") {
+    val input = Seq(BoundReference(0, IntegerType, nullable = true))
+    val fallback = CodegenObjectFactoryMode.FALLBACK.toString
+    withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> fallback) {
+      val obj = FailedCodegenProjection.createObject(input)
+      assert(obj.isInstanceOf[InterpretedUnsafeProjection])
+    }
+  }
+
+  test("codegen failures in the CODEGEN_ONLY mode") {
+    val errMsg = intercept[ExecutionException] {
+      val input = Seq(BoundReference(0, IntegerType, nullable = true))
+      val codegenOnly = CodegenObjectFactoryMode.CODEGEN_ONLY.toString
+      withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenOnly) {
+        FailedCodegenProjection.createObject(input)
+      }
+    }.getMessage
+    assert(errMsg.contains("failed to compile: org.codehaus.commons.compiler.CompileException:"))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index 80f886ea1adc..1fc4de9e5601 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -21,6 +21,7 @@ import java.util.Locale
 import java.util.function.Supplier
 
 import scala.collection.mutable
+import scala.util.control.NonFatal
 
 import org.apache.spark.broadcast
 import org.apache.spark.rdd.RDD
@@ -582,7 +583,7 @@ case class WholeStageCodegenExec(child: SparkPlan)(val codegenStageId: Int)
     val (_, maxCodeSize) = try {
      CodeGenerator.compile(cleanedSource)
     } catch {
-      case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback =>
+      case NonFatal(_) if !Utils.isTesting && sqlContext.conf.codegenFallback =>
        // We should already saw the error message
        logWarning(s"Whole-stage codegen disabled for plan (id=$codegenStageId):\n $treeString")
        return child.execute()
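
For reviewers: the listing below is a minimal, standalone sketch (not part of the patch) of the pattern this change adopts, i.e. try the code-generated path first and, on any NonFatal error, log a warning and fall back to the interpreted implementation. The names FallbackFactory, FallbackDemo, and FailingFactory are hypothetical, and println stands in for the logWarning call provided by the Logging trait.

import scala.util.control.NonFatal

// Hypothetical stand-in for CodeGeneratorWithInterpretedFallback: try the
// code-generated object first, fall back to an interpreted one on failure.
abstract class FallbackFactory[IN, OUT] {
  protected def createCodeGeneratedObject(in: IN): OUT
  protected def createInterpretedObject(in: IN): OUT

  def createObject(in: IN): OUT = {
    try {
      createCodeGeneratedObject(in)
    } catch {
      case NonFatal(e) =>
        // The patch uses logWarning from the Logging trait; println stands in here.
        println(s"Expr codegen error and falling back to interpreter mode: ${e.getMessage}")
        createInterpretedObject(in)
    }
  }
}

object FallbackDemo extends App {
  // A toy factory whose "codegen" path always fails, so createObject falls back.
  object FailingFactory extends FallbackFactory[Int, String] {
    override protected def createCodeGeneratedObject(in: Int): String =
      throw new RuntimeException("simulated compile failure")
    override protected def createInterpretedObject(in: Int): String = s"interpreted($in)"
  }

  assert(FailingFactory.createObject(42) == "interpreted(42)")
}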