diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index bd400f86ea2c..8da67e086320 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.immutable.HashSet
 import scala.collection.mutable.{ArrayBuffer, Stack}
+import scala.util.{Failure, Success, Try}
 
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.expressions._
@@ -50,8 +51,20 @@ object ConstantFolding extends Rule[LogicalPlan] {
 
       // Fold expressions that are foldable.
       case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType)
+
+      // Fold ScalaUDFs if they are deterministic and all of their arguments are foldable.
+      // Watch out for potentially exception-throwing UDFs: some Scala UDFs may have been
+      // mis-declared as deterministic but still throw exceptions at runtime. Do not fold
+      // them, so that the exception is thrown at the expected time, during execution.
+      case udf: ScalaUDF if maybeFoldable(udf) => Try(udf.eval(EmptyRow)) match {
+        case Success(v) => Literal.create(v, udf.dataType)
+        case Failure(_) => udf // defer any exception throwing to the execution phase
+      }
     }
   }
+
+  private def maybeFoldable(udf: ScalaUDF): Boolean =
+    udf.deterministic && udf.children.forall(_.foldable)
 }
 
 /**
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 23ab6b2df3e6..1065e809aeed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
@@ -155,6 +156,59 @@ class ConstantFoldingSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }
 
+  test("Constant folding test: deterministic Scala UDFs") {
+    val normalFunc = (x: Int) => x + 41
+    val exceptionFunc = (x: Int) => x / 0
+
+    val intEncoder = Option(ExpressionEncoder[Int]())
+
+    val foldableUdf = ScalaUDF(
+      function = normalFunc,
+      dataType = IntegerType,
+      children = Seq(Literal(1)),
+      inputEncoders = Seq(intEncoder),
+      udfName = None,
+      nullable = false,
+      udfDeterministic = true)
+
+    val deterministicUnfoldableUdf = ScalaUDF(
+      function = normalFunc,
+      dataType = IntegerType,
+      children = Seq[Expression]('a),
+      inputEncoders = Seq(intEncoder),
+      udfName = None,
+      nullable = false,
+      udfDeterministic = true)
+
+    val exceptionUdf = ScalaUDF(
+      function = exceptionFunc,
+      dataType = IntegerType,
+      children = Seq(Literal(1)),
+      inputEncoders = Seq(intEncoder),
+      udfName = None,
+      nullable = false,
+      udfDeterministic = true) // intentionally mis-declared as deterministic
+
+    val originalQuery =
+      testRelation
+        .select(
+          foldableUdf as Symbol("c1"),
+          deterministicUnfoldableUdf as Symbol("c2"),
+          exceptionUdf as Symbol("c3"))
+
+    val optimized = Optimize.execute(originalQuery.analyze)
+
+    val correctAnswer =
+      testRelation
+        .select(
+          Literal(42) as Symbol("c1"),
+          deterministicUnfoldableUdf as Symbol("c2"),
+          exceptionUdf as Symbol("c3"))
+        .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
   test("Constant folding test: expressions have nonfoldable functions") {
     val originalQuery =
       testRelation