diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 087b21043b309..3dc2ee03a86e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -356,22 +356,25 @@ case class CreateNamedStruct(children: Seq[Expression]) extends CreateNamedStruc override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - ctx.addMutableState("Object[]", values, s"$values = null;") + val valCodes = valExprs.zipWithIndex.map { case (e, i) => + val eval = e.genCode(ctx) + s""" + |${eval.code} + |if (${eval.isNull}) { + | $values[$i] = null; + |} else { + | $values[$i] = ${eval.value}; + |} + """.stripMargin + } val valuesCode = ctx.splitExpressionsWithCurrentInputs( - valExprs.zipWithIndex.map { case (e, i) => - val eval = e.genCode(ctx) - s""" - ${eval.code} - if (${eval.isNull}) { - $values[$i] = null; - } else { - $values[$i] = ${eval.value}; - }""" - }) + expressions = valCodes, + funcName = "createNamedStruct", + extraArguments = "Object[]" -> values :: Nil) ev.copy(code = s""" - |$values = new Object[${valExprs.size}]; + |Object[] $values = new Object[${valExprs.size}]; |$valuesCode |final InternalRow ${ev.value} = new $rowClass($values); |$values = null; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 04e669492ec6d..a42dd7ecf57de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -344,17 +344,17 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } else { "" } - ctx.addMutableState(setName, setTerm, - s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();") - ev.copy(code = s""" - ${childGen.code} - boolean ${ev.isNull} = ${childGen.isNull}; - boolean ${ev.value} = false; - if (!${ev.isNull}) { - ${ev.value} = $setTerm.contains(${childGen.value}); - $setNull - } - """) + ev.copy(code = + s""" + |${childGen.code} + |${ctx.JAVA_BOOLEAN} ${ev.isNull} = ${childGen.isNull}; + |${ctx.JAVA_BOOLEAN} ${ev.value} = false; + |if (!${ev.isNull}) { + | $setName $setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet(); + | ${ev.value} = $setTerm.contains(${childGen.value}); + | $setNull + |} + """.stripMargin) } override def sql: String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index b0eaad1c80f89..6dfca7d73a3df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -299,4 +300,10 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("=")) .checkInputDataTypes().isFailure) } + + test("SPARK-22693: CreateNamedStruct should not use global variables") { + val ctx = new CodegenContext + CreateNamedStruct(Seq("a", "x", "b", 2.0)).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 0079e4e8d6f74..95a0dfa057563 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.RandomDataGenerator import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ @@ -429,4 +430,10 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { val infinity = Literal(Double.PositiveInfinity) checkEvaluation(EqualTo(infinity, infinity), true) } + + test("SPARK-22693: InSet should not use global variables") { + val ctx = new CodegenContext + InSet(Literal(1), Set(1, 2, 3, 4)).genCode(ctx) + assert(ctx.mutableStates.isEmpty) + } }