Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,19 @@

package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression}

/**
* A trait that can be used to provide a fallback mode for expression code generation.
*/
trait CodegenFallback extends Expression {

protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
}

ctx.references += this
val objectTerm = ctx.freshName("obj")
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,40 @@ package org.apache.spark.sql.catalyst.expressions.codegen

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{BooleanType, DataType}

/**
* A test suite that makes sure code generation handles expression internally states correctly.
*/
class CodegenExpressionCachingSuite extends SparkFunSuite {

test("GenerateUnsafeProjection") {
test("GenerateUnsafeProjection should initialize expressions") {
// Use an Add to wrap two of them together in case we only initialize the top level expressions.
val expr = And(NondeterministicExpression(), NondeterministicExpression())
val instance = UnsafeProjection.create(Seq(expr))
assert(instance.apply(null).getBoolean(0) === false)
}

test("GenerateProjection should initialize expressions") {
val expr = And(NondeterministicExpression(), NondeterministicExpression())
val instance = GenerateProjection.generate(Seq(expr))
assert(instance.apply(null).getBoolean(0) === false)
}

test("GenerateMutableProjection should initialize expressions") {
val expr = And(NondeterministicExpression(), NondeterministicExpression())
val instance = GenerateMutableProjection.generate(Seq(expr))()
assert(instance.apply(null).getBoolean(0) === false)
}

test("GeneratePredicate should initialize expressions") {
val expr = And(NondeterministicExpression(), NondeterministicExpression())
val instance = GeneratePredicate.generate(expr)
assert(instance.apply(null) === false)
}

test("GenerateUnsafeProjection should not share expression instances") {
val expr1 = MutableExpression()
val instance1 = UnsafeProjection.create(Seq(expr1))
assert(instance1.apply(null).getBoolean(0) === false)
Expand All @@ -39,7 +64,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
assert(instance2.apply(null).getBoolean(0) === true)
}

test("GenerateProjection") {
test("GenerateProjection should not share expression instances") {
val expr1 = MutableExpression()
val instance1 = GenerateProjection.generate(Seq(expr1))
assert(instance1.apply(null).getBoolean(0) === false)
Expand All @@ -51,7 +76,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
assert(instance2.apply(null).getBoolean(0) === true)
}

test("GenerateMutableProjection") {
test("GenerateMutableProjection should not share expression instances") {
val expr1 = MutableExpression()
val instance1 = GenerateMutableProjection.generate(Seq(expr1))()
assert(instance1.apply(null).getBoolean(0) === false)
Expand All @@ -63,7 +88,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
assert(instance2.apply(null).getBoolean(0) === true)
}

test("GeneratePredicate") {
test("GeneratePredicate should not share expression instances") {
val expr1 = MutableExpression()
val instance1 = GeneratePredicate.generate(expr1)
assert(instance1.apply(null) === false)
Expand All @@ -77,6 +102,17 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {

}

/**
* An expression that's non-deterministic and doesn't support codegen.
*/
case class NondeterministicExpression()
extends LeafExpression with Nondeterministic with CodegenFallback {
override protected def initInternal(): Unit = { }
override protected def evalInternal(input: InternalRow): Any = false
override def nullable: Boolean = false
override def dataType: DataType = BooleanType
}


/**
* An expression with mutable state so we can change it freely in our test suite.
Expand Down