From 4eaca3e4176eb8107f0e9f23d2d4116cd9cba04b Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Tue, 5 Dec 2017 14:59:49 +0100 Subject: [PATCH 1/3] [SPARK-22695][SQL] ScalaUDF should not use global variables --- .../sql/catalyst/expressions/ScalaUDF.scala | 90 ++++++++++--------- .../catalyst/expressions/ScalaUDFSuite.scala | 7 ++ 2 files changed, 54 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 179853032035e..f8ca35f9d827e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -982,35 +982,30 @@ case class ScalaUDF( // scalastyle:on line.size.limit - // Generate codes used to convert the arguments to Scala type for user-defined functions - private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = { - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - val expressionClassName = classOf[Expression].getName - val scalaUDFClassName = classOf[ScalaUDF].getName + private val converterClassName = classOf[Any => Any].getName + private val expressionClassName = classOf[Expression].getName + private val scalaUDFClassName = classOf[ScalaUDF].getName + private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" + // Generate codes used to convert the arguments to Scala type for user-defined functions + private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = { val converterTerm = ctx.freshName("converter") val expressionIdx = ctx.references.size - 1 - ctx.addMutableState(converterClassName, converterTerm, - s"$converterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" + - s"references[$expressionIdx]).getChildren().apply($index))).dataType());") - converterTerm + (converterTerm, + s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" + + s".createToScalaConverter((($expressionClassName)((($scalaUDFClassName)" + + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") } override def doGenCode( ctx: CodegenContext, ev: ExprCode): ExprCode = { + val thisClassName = this.getClass.getName + val scalaUDF = ctx.freshName("scalaUDF") + val scalaUDFRef = ctx.addReferenceMinorObj(this, thisClassName) - val scalaUDF = ctx.addReferenceObj("scalaUDF", this) - val converterClassName = classOf[Any => Any].getName - val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" - - // Generate codes used to convert the returned value of user-defined functions to Catalyst type + // Object to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") - ctx.addMutableState(converterClassName, catalystConverterTerm, - s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToCatalystConverter($scalaUDF.dataType());") val resultTerm = ctx.freshName("result") @@ -1022,8 +1017,6 @@ case class ScalaUDF( val funcClassName = s"scala.Function${children.size}" val funcTerm = ctx.freshName("udf") - ctx.addMutableState(funcClassName, funcTerm, - s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();") // codegen for children expressions val evals = children.map(_.genCode(ctx)) @@ -1033,34 +1026,45 @@ case class ScalaUDF( // such as IntegerType, its javaType is `int` and the returned type of user-defined // function is Object. Trying to convert an Object to `int` will cause casting exception. val evalCode = evals.map(_.code).mkString - val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) => - val eval = evals(i) - val argTerm = ctx.freshName("arg") - val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});" - (convert, argTerm) + val (converters, funcArguments) = converterTerms.zipWithIndex.map { + case ((convName, convInit), i) => + val eval = evals(i) + val argTerm = ctx.freshName("arg") + val convert = + s""" + |$convInit + |Object $argTerm = ${eval.isNull} ? null : $convName.apply(${eval.value}); + """.stripMargin + (convert, argTerm) }.unzip val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})" val callFunc = s""" - ${ctx.boxedType(dataType)} $resultTerm = null; - try { - $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); - } catch (Exception e) { - throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); - } - """ + |${ctx.boxedType(dataType)} $resultTerm = null; + |$thisClassName $scalaUDF = $scalaUDFRef; + |try { + | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc(); + | $converterClassName $catalystConverterTerm = ($converterClassName) + | $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType()); + | $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult); + |} catch (Exception e) { + | throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e); + |} + """.stripMargin - ev.copy(code = s""" - $evalCode - ${converters.mkString("\n")} - $callFunc - - boolean ${ev.isNull} = $resultTerm == null; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $resultTerm; - }""") + ev.copy(code = + s""" + |$evalCode + |${converters.mkString("\n")} + |$callFunc + | + |boolean ${ev.isNull} = $resultTerm == null; + |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + |if (!${ev.isNull}) { + | ${ev.value} = $resultTerm; + |} + """.stripMargin) } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index 13bd363c8b692..bc66b8bcb29e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.util.Locale import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.{IntegerType, StringType} class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -47,4 +48,10 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { assert(e2.getMessage.contains("Failed to execute user defined function")) } + test("SPARK-22695: ScalaUDF should not use global variables") { + val ctx = new CodegenContext + ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) + // we have one variable (globalIsNull) introduced by reduceCodeSize + assert(ctx.mutableStates.length == 1) + } } From eef803639ba350eaf6b7c930b06d58d4d7a0a1f9 Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 6 Dec 2017 14:28:29 +0100 Subject: [PATCH 2/3] address review comments --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 3 +-- .../apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index f8ca35f9d827e..cd7d93fe25b14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -983,7 +983,6 @@ case class ScalaUDF( // scalastyle:on line.size.limit private val converterClassName = classOf[Any => Any].getName - private val expressionClassName = classOf[Expression].getName private val scalaUDFClassName = classOf[ScalaUDF].getName private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$" @@ -993,7 +992,7 @@ case class ScalaUDF( val expressionIdx = ctx.references.size - 1 (converterTerm, s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" + - s".createToScalaConverter((($expressionClassName)((($scalaUDFClassName)" + + s".createToScalaConverter(((Expression)((($scalaUDFClassName)" + s"references[$expressionIdx]).getChildren().apply($index))).dataType());") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala index bc66b8bcb29e3..70dea4b39d55d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDFSuite.scala @@ -51,7 +51,6 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-22695: ScalaUDF should not use global variables") { val ctx = new CodegenContext ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx) - // we have one variable (globalIsNull) introduced by reduceCodeSize - assert(ctx.mutableStates.length == 1) + assert(ctx.mutableStates.isEmpty) } } From f188d55083c20f85583929c04bf916ef494b744a Mon Sep 17 00:00:00 2001 From: Marco Gaido Date: Wed, 6 Dec 2017 14:48:12 +0100 Subject: [PATCH 3/3] remove useless thisClassName --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index cd7d93fe25b14..4d26d9819321b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -999,9 +999,8 @@ case class ScalaUDF( override def doGenCode( ctx: CodegenContext, ev: ExprCode): ExprCode = { - val thisClassName = this.getClass.getName val scalaUDF = ctx.freshName("scalaUDF") - val scalaUDFRef = ctx.addReferenceMinorObj(this, thisClassName) + val scalaUDFRef = ctx.addReferenceMinorObj(this, scalaUDFClassName) // Object to convert the returned value of user-defined functions to Catalyst type val catalystConverterTerm = ctx.freshName("catalystConverter") @@ -1041,7 +1040,7 @@ case class ScalaUDF( val callFunc = s""" |${ctx.boxedType(dataType)} $resultTerm = null; - |$thisClassName $scalaUDF = $scalaUDFRef; + |$scalaUDFClassName $scalaUDF = $scalaUDFRef; |try { | $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc(); | $converterClassName $catalystConverterTerm = ($converterClassName)