Skip to content

Commit 18ec598

Browse files
committed
Ensure no global variables in arguments of method split by CodegenContext.splitExpressions()
1 parent ee56fc3 commit 18ec598

File tree

5 files changed

+28
-21
lines changed

5 files changed

+28
-21
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -602,13 +602,13 @@ case class Least(children: Seq[Expression]) extends Expression {
602602

603603
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
604604
val evalChildren = children.map(_.genCode(ctx))
605-
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "leastTmpIsNull")
605+
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
606606
val evals = evalChildren.map(eval =>
607607
s"""
608608
|${eval.code}
609-
|if (!${eval.isNull} && ($tmpIsNull ||
609+
|if (!${eval.isNull} && (${ev.isNull} ||
610610
| ${ctx.genGreater(dataType, ev.value, eval.value)})) {
611-
| $tmpIsNull = false;
611+
| ${ev.isNull} = false;
612612
| ${ev.value} = ${eval.value};
613613
|}
614614
""".stripMargin
@@ -628,10 +628,9 @@ case class Least(children: Seq[Expression]) extends Expression {
628628
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
629629
ev.copy(code =
630630
s"""
631-
|$tmpIsNull = true;
631+
|${ev.isNull} = true;
632632
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
633633
|$codes
634-
|final boolean ${ev.isNull} = $tmpIsNull;
635634
""".stripMargin)
636635
}
637636
}
@@ -682,13 +681,13 @@ case class Greatest(children: Seq[Expression]) extends Expression {
682681

683682
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
684683
val evalChildren = children.map(_.genCode(ctx))
685-
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "greatestTmpIsNull")
684+
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
686685
val evals = evalChildren.map(eval =>
687686
s"""
688687
|${eval.code}
689-
|if (!${eval.isNull} && ($tmpIsNull ||
688+
|if (!${eval.isNull} && (${ev.isNull} ||
690689
| ${ctx.genGreater(dataType, eval.value, ev.value)})) {
691-
| $tmpIsNull = false;
690+
| ${ev.isNull} = false;
692691
| ${ev.value} = ${eval.value};
693692
|}
694693
""".stripMargin
@@ -708,10 +707,9 @@ case class Greatest(children: Seq[Expression]) extends Expression {
708707
foldFunctions = _.map(funcCall => s"${ev.value} = $funcCall;").mkString("\n"))
709708
ev.copy(code =
710709
s"""
711-
|$tmpIsNull = true;
710+
|${ev.isNull} = true;
712711
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
713712
|$codes
714-
|final boolean ${ev.isNull} = $tmpIsNull;
715713
""".stripMargin)
716714
}
717715
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,18 @@ class CodegenContext {
930930
// inline execution if only one block
931931
blocks.head
932932
} else {
933+
if (Utils.isTesting) {
934+
// Passing global variables to the split method is dangerous, as any mutating to it is
935+
// ignored and may lead to unexpected behavior.
936+
// We don't need to check `arrayCompactedMutableStates` here, as it results to array access
937+
// code and will raise compile error if we use it in parameter list.
938+
val mutableStateNames = inlinedMutableStates.map(_._2).toSet
939+
arguments.foreach { case (_, name) =>
940+
assert(!mutableStateNames.contains(name),
941+
s"split function argument $name cannot be a global variable.")
942+
}
943+
}
944+
933945
val func = freshName(funcName)
934946
val argString = arguments.map { case (t, name) => s"$t $name" }.mkString(", ")
935947
val functions = blocks.zipWithIndex.map { case (body, i) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ case class CaseWhen(
190190
// It is initialized to `NOT_MATCHED`, and if it's set to `HAS_NULL` or `HAS_NONNULL`,
191191
// We won't go on anymore on the computation.
192192
val resultState = ctx.freshName("caseWhenResultState")
193-
val tmpResult = ctx.addMutableState(ctx.javaType(dataType), "caseWhenTmpResult")
193+
ev.value = ctx.addMutableState(ctx.javaType(dataType), ev.value)
194194

195195
// these blocks are meant to be inside a
196196
// do {
@@ -205,7 +205,7 @@ case class CaseWhen(
205205
|if (!${cond.isNull} && ${cond.value}) {
206206
| ${res.code}
207207
| $resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
208-
| $tmpResult = ${res.value};
208+
| ${ev.value} = ${res.value};
209209
| continue;
210210
|}
211211
""".stripMargin
@@ -216,7 +216,7 @@ case class CaseWhen(
216216
s"""
217217
|${res.code}
218218
|$resultState = (byte)(${res.isNull} ? $HAS_NULL : $HAS_NONNULL);
219-
|$tmpResult = ${res.value};
219+
|${ev.value} = ${res.value};
220220
""".stripMargin
221221
}
222222

@@ -264,13 +264,11 @@ case class CaseWhen(
264264
ev.copy(code =
265265
s"""
266266
|${ctx.JAVA_BYTE} $resultState = $NOT_MATCHED;
267-
|$tmpResult = ${ctx.defaultValue(dataType)};
268267
|do {
269268
| $codes
270269
|} while (false);
271270
|// TRUE if any condition is met and the result is null, or no any condition is met.
272271
|final boolean ${ev.isNull} = ($resultState != $HAS_NONNULL);
273-
|final ${ctx.javaType(dataType)} ${ev.value} = $tmpResult;
274272
""".stripMargin)
275273
}
276274
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,15 +72,15 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
7272
}
7373

7474
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
75-
val tmpIsNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, "coalesceTmpIsNull")
75+
ev.isNull = ctx.addMutableState(ctx.JAVA_BOOLEAN, ev.isNull)
7676

7777
// all the evals are meant to be in a do { ... } while (false); loop
7878
val evals = children.map { e =>
7979
val eval = e.genCode(ctx)
8080
s"""
8181
|${eval.code}
8282
|if (!${eval.isNull}) {
83-
| $tmpIsNull = false;
83+
| ${ev.isNull} = false;
8484
| ${ev.value} = ${eval.value};
8585
| continue;
8686
|}
@@ -103,7 +103,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
103103
foldFunctions = _.map { funcCall =>
104104
s"""
105105
|${ev.value} = $funcCall;
106-
|if (!$tmpIsNull) {
106+
|if (!${ev.isNull}) {
107107
| continue;
108108
|}
109109
""".stripMargin
@@ -112,12 +112,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
112112

113113
ev.copy(code =
114114
s"""
115-
|$tmpIsNull = true;
115+
|${ev.isNull} = true;
116116
|$resultType ${ev.value} = ${ctx.defaultValue(dataType)};
117117
|do {
118118
| $codes
119119
|} while (false);
120-
|final boolean ${ev.isNull} = $tmpIsNull;
121120
""".stripMargin)
122121
}
123122
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
285285
|${valueGen.code}
286286
|byte $tmpResult = $HAS_NULL;
287287
|if (!${valueGen.isNull}) {
288-
| $tmpResult = 0;
288+
| $tmpResult = $NOT_MATCHED;
289289
| $javaDataType $valueArg = ${valueGen.value};
290290
| do {
291291
| $codes

0 commit comments

Comments
 (0)