diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 6b187f05604f..3492d2c6189e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.expressions.Expression trait CodegenFallback extends Expression { protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + foreach { + case n: Nondeterministic => n.setInitialValues() + case _ => + } + ctx.references += this val objectTerm = ctx.freshName("obj") s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 866bf904e4a4..2d3f98dbbd3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{BooleanType, DataType} /** @@ -27,7 +27,32 @@ import org.apache.spark.sql.types.{BooleanType, DataType} */ class CodegenExpressionCachingSuite extends SparkFunSuite { - test("GenerateUnsafeProjection") { + test("GenerateUnsafeProjection should initialize expressions") { + // Use an Add to wrap two of them together in case we only initialize the top level expressions. + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = UnsafeProjection.create(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateProjection.generate(Seq(expr)) + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GenerateMutableProjection should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GenerateMutableProjection.generate(Seq(expr))() + assert(instance.apply(null).getBoolean(0) === false) + } + + test("GeneratePredicate should initialize expressions") { + val expr = And(NondeterministicExpression(), NondeterministicExpression()) + val instance = GeneratePredicate.generate(expr) + assert(instance.apply(null) === false) + } + + test("GenerateUnsafeProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = UnsafeProjection.create(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) @@ -39,7 +64,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateProjection") { + test("GenerateProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateProjection.generate(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) @@ -51,7 +76,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GenerateMutableProjection") { + test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() val instance1 = GenerateMutableProjection.generate(Seq(expr1))() assert(instance1.apply(null).getBoolean(0) === false) @@ -63,7 +88,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { assert(instance2.apply(null).getBoolean(0) === true) } - test("GeneratePredicate") { + test("GeneratePredicate should not share expression instances") { val expr1 = MutableExpression() val instance1 = GeneratePredicate.generate(expr1) assert(instance1.apply(null) === false) @@ -77,6 +102,17 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { } +/** + * An expression that's non-deterministic and doesn't support codegen. + */ +case class NondeterministicExpression() + extends LeafExpression with Nondeterministic with CodegenFallback { + override protected def initInternal(): Unit = { } + override protected def evalInternal(input: InternalRow): Any = false + override def nullable: Boolean = false + override def dataType: DataType = BooleanType +} + /** * An expression with mutable state so we can change it freely in our test suite.