Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -982,35 +982,30 @@ case class ScalaUDF(

// scalastyle:on line.size.limit

// Generate codes used to convert the arguments to Scala type for user-defined functions
private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): String = {
val converterClassName = classOf[Any => Any].getName
val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
val expressionClassName = classOf[Expression].getName
val scalaUDFClassName = classOf[ScalaUDF].getName
private val converterClassName = classOf[Any => Any].getName
private val expressionClassName = classOf[Expression].getName

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expression is pre-imported, we can just write Expression in the generated code.

private val scalaUDFClassName = classOf[ScalaUDF].getName
private val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"

// Generate codes used to convert the arguments to Scala type for user-defined functions
private[this] def genCodeForConverter(ctx: CodegenContext, index: Int): (String, String) = {
val converterTerm = ctx.freshName("converter")
val expressionIdx = ctx.references.size - 1
ctx.addMutableState(converterClassName, converterTerm,
s"$converterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToScalaConverter(((${expressionClassName})((($scalaUDFClassName)" +
s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
converterTerm
(converterTerm,
s"$converterClassName $converterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToScalaConverter((($expressionClassName)((($scalaUDFClassName)" +
s"references[$expressionIdx]).getChildren().apply($index))).dataType());")
}

override def doGenCode(
ctx: CodegenContext,
ev: ExprCode): ExprCode = {
val thisClassName = this.getClass.getName

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't it just scalaUDFClassName?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, thanks, nice catch! I am updating it. Thank you.

val scalaUDF = ctx.freshName("scalaUDF")
val scalaUDFRef = ctx.addReferenceMinorObj(this, thisClassName)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ctx.addReferenceMinorObj has a default value for class name, which is obj.getClass.getNane, so the thisClassName is redundant.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see, it's used later.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am using thisClassName also later (line 1045), that is why I passed it, despite it is not needed. What is your suggestion? Just not passing it as a parameter or getting rid of the thisClassName variable itself? Thanks,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok we can keep it.


val scalaUDF = ctx.addReferenceObj("scalaUDF", this)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We may need to revisit all the usage of ctx.addReferenceObj, I created https://issues.apache.org/jira/browse/SPARK-22716 for it. @mgaido91 do you have interests?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, sure, thanks. I would be happy to work on it.

val converterClassName = classOf[Any => Any].getName
val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"

// Generate codes used to convert the returned value of user-defined functions to Catalyst type
// Object to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
ctx.addMutableState(converterClassName, catalystConverterTerm,
s"$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToCatalystConverter($scalaUDF.dataType());")

val resultTerm = ctx.freshName("result")

Expand All @@ -1022,8 +1017,6 @@ case class ScalaUDF(
val funcClassName = s"scala.Function${children.size}"

val funcTerm = ctx.freshName("udf")
ctx.addMutableState(funcClassName, funcTerm,
s"$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")

// codegen for children expressions
val evals = children.map(_.genCode(ctx))
Expand All @@ -1033,34 +1026,45 @@ case class ScalaUDF(
// such as IntegerType, its javaType is `int` and the returned type of user-defined
// function is Object. Trying to convert an Object to `int` will cause casting exception.
val evalCode = evals.map(_.code).mkString
val (converters, funcArguments) = converterTerms.zipWithIndex.map { case (converter, i) =>
val eval = evals(i)
val argTerm = ctx.freshName("arg")
val convert = s"Object $argTerm = ${eval.isNull} ? null : $converter.apply(${eval.value});"
(convert, argTerm)
val (converters, funcArguments) = converterTerms.zipWithIndex.map {
case ((convName, convInit), i) =>
val eval = evals(i)
val argTerm = ctx.freshName("arg")
val convert =
s"""
|$convInit
|Object $argTerm = ${eval.isNull} ? null : $convName.apply(${eval.value});
""".stripMargin
(convert, argTerm)
}.unzip

val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
val callFunc =
s"""
${ctx.boxedType(dataType)} $resultTerm = null;
try {
$resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
} catch (Exception e) {
throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
}
"""
|${ctx.boxedType(dataType)} $resultTerm = null;
|$thisClassName $scalaUDF = $scalaUDFRef;
|try {
| $funcClassName $funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();
| $converterClassName $catalystConverterTerm = ($converterClassName)
| $typeConvertersClassName.createToCatalystConverter($scalaUDF.dataType());
| $resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
|} catch (Exception e) {
| throw new org.apache.spark.SparkException($scalaUDF.udfErrorMessage(), e);
|}
""".stripMargin

ev.copy(code = s"""
$evalCode
${converters.mkString("\n")}
$callFunc

boolean ${ev.isNull} = $resultTerm == null;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $resultTerm;
}""")
ev.copy(code =
s"""
|$evalCode
|${converters.mkString("\n")}
|$callFunc
|
|boolean ${ev.isNull} = $resultTerm == null;
|${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
|if (!${ev.isNull}) {
| ${ev.value} = $resultTerm;
|}
""".stripMargin)
}

private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Locale

import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -47,4 +48,10 @@ class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(e2.getMessage.contains("Failed to execute user defined function"))
}

test("SPARK-22695: ScalaUDF should not use global variables") {
val ctx = new CodegenContext
ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil).genCode(ctx)
// we have one variable (globalIsNull) introduced by reduceCodeSize

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wow this simple UDF will trigger the code splitting logic in reduceCodeSize?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

assert(ctx.mutableStates.length == 1)
}
}