Skip to content

Commit bb5c334

Browse files
committed
[SPARK-9462][SQL] Initialize nondeterministic expressions in code gen fallback mode.
1 parent 2a9fe4a commit bb5c334

File tree

5 files changed

+60
-5
lines changed

5 files changed

+60
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
122122

123123
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
124124

125+
ctx.references.foreach {
126+
case n: Nondeterministic => n.setInitialValues()
127+
case _ =>
128+
}
129+
125130
val c = compile(code)
126131
() => {
127132
c.generate(ctx.references.toArray).asInstanceOf[MutableProjection]

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool
6262

6363
logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}")
6464

65+
ctx.references.foreach {
66+
case n: Nondeterministic => n.setInitialValues()
67+
case _ =>
68+
}
69+
6570
val p = compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate]
6671
(r: InternalRow) => p.eval(r)
6772
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
233233
logDebug(s"MutableRow, initExprs: ${expressions.mkString(",")} code:\n" +
234234
CodeFormatter.format(code))
235235

236+
ctx.references.foreach {
237+
case n: Nondeterministic => n.setInitialValues()
238+
case _ =>
239+
}
240+
236241
compile(code).generate(ctx.references.toArray).asInstanceOf[Projection]
237242
}
238243
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
285285

286286
logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
287287

288+
ctx.references.foreach {
289+
case n: Nondeterministic => n.setInitialValues()
290+
case _ =>
291+
}
292+
288293
val c = compile(code)
289294
c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection]
290295
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,39 @@ package org.apache.spark.sql.catalyst.expressions.codegen
1919

2020
import org.apache.spark.SparkFunSuite
2121
import org.apache.spark.sql.catalyst.InternalRow
22-
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, LeafExpression}
22+
import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, UnsafeProjection, LeafExpression}
2323
import org.apache.spark.sql.types.{BooleanType, DataType}
2424

2525
/**
2626
* A test suite that makes sure code generation handles expression internally states correctly.
2727
*/
2828
class CodegenExpressionCachingSuite extends SparkFunSuite {
2929

30-
test("GenerateUnsafeProjection") {
30+
test("GenerateUnsafeProjection should initialize expressions") {
31+
val expr = NondeterministicExpression()
32+
val instance = UnsafeProjection.create(Seq(expr))
33+
assert(instance.apply(null).getBoolean(0) === false)
34+
}
35+
36+
test("GenerateProjection should initialize expressions") {
37+
val expr = NondeterministicExpression()
38+
val instance = GenerateProjection.generate(Seq(expr))
39+
assert(instance.apply(null).getBoolean(0) === false)
40+
}
41+
42+
test("GenerateMutableProjection should initialize expressions") {
43+
val expr = NondeterministicExpression()
44+
val instance = GenerateMutableProjection.generate(Seq(expr))()
45+
assert(instance.apply(null).getBoolean(0) === false)
46+
}
47+
48+
test("GeneratePredicate should initialize expressions") {
49+
val expr = NondeterministicExpression()
50+
val instance = GeneratePredicate.generate(expr)
51+
assert(instance.apply(null) === false)
52+
}
53+
54+
test("GenerateUnsafeProjection should not share expression instances") {
3155
val expr1 = MutableExpression()
3256
val instance1 = UnsafeProjection.create(Seq(expr1))
3357
assert(instance1.apply(null).getBoolean(0) === false)
@@ -39,7 +63,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
3963
assert(instance2.apply(null).getBoolean(0) === true)
4064
}
4165

42-
test("GenerateProjection") {
66+
test("GenerateProjection should not share expression instances") {
4367
val expr1 = MutableExpression()
4468
val instance1 = GenerateProjection.generate(Seq(expr1))
4569
assert(instance1.apply(null).getBoolean(0) === false)
@@ -51,7 +75,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
5175
assert(instance2.apply(null).getBoolean(0) === true)
5276
}
5377

54-
test("GenerateMutableProjection") {
78+
test("GenerateMutableProjection should not share expression instances") {
5579
val expr1 = MutableExpression()
5680
val instance1 = GenerateMutableProjection.generate(Seq(expr1))()
5781
assert(instance1.apply(null).getBoolean(0) === false)
@@ -63,7 +87,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
6387
assert(instance2.apply(null).getBoolean(0) === true)
6488
}
6589

66-
test("GeneratePredicate") {
90+
test("GeneratePredicate should not share expression instances") {
6791
val expr1 = MutableExpression()
6892
val instance1 = GeneratePredicate.generate(expr1)
6993
assert(instance1.apply(null) === false)
@@ -77,6 +101,17 @@ class CodegenExpressionCachingSuite extends SparkFunSuite {
77101

78102
}
79103

104+
/**
105+
* An expression that's non-deterministic and doesn't support codegen.
106+
*/
107+
case class NondeterministicExpression()
108+
extends LeafExpression with Nondeterministic with CodegenFallback {
109+
override protected def initInternal(): Unit = {}
110+
override protected def evalInternal(input: InternalRow): Any = false
111+
override def nullable: Boolean = false
112+
override def dataType: DataType = BooleanType
113+
}
114+
80115

81116
/**
82117
* An expression with mutable state so we can change it freely in our test suite.

0 commit comments

Comments
 (0)