Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -994,20 +994,15 @@ case class ScalaUDF(
ctx: CodegenContext,
ev: ExprCode): ExprCode = {

ctx.references += this

val scalaUDFClassName = classOf[ScalaUDF].getName
val scalaUDF = ctx.addReferenceObj("scalaUDF", this)
val converterClassName = classOf[Any => Any].getName
val typeConvertersClassName = CatalystTypeConverters.getClass.getName + ".MODULE$"
val expressionClassName = classOf[Expression].getName

// Generate codes used to convert the returned value of user-defined functions to Catalyst type
val catalystConverterTerm = ctx.freshName("catalystConverter")
val catalystConverterTermIdx = ctx.references.size - 1
ctx.addMutableState(converterClassName, catalystConverterTerm,
s"this.$catalystConverterTerm = ($converterClassName)$typeConvertersClassName" +
s".createToCatalystConverter((($scalaUDFClassName)references" +
s"[$catalystConverterTermIdx]).dataType());")
s".createToCatalystConverter($scalaUDF.dataType());")

val resultTerm = ctx.freshName("result")

Expand All @@ -1019,10 +1014,8 @@ case class ScalaUDF(
val funcClassName = s"scala.Function${children.size}"

val funcTerm = ctx.freshName("udf")
val funcExpressionIdx = ctx.references.size - 1
ctx.addMutableState(funcClassName, funcTerm,
s"this.$funcTerm = ($funcClassName)((($scalaUDFClassName)references" +
s"[$funcExpressionIdx]).userDefinedFunc());")
s"this.$funcTerm = ($funcClassName)$scalaUDF.userDefinedFunc();")

// codegen for children expressions
val evals = children.map(_.genCode(ctx))
Expand All @@ -1039,9 +1032,18 @@ case class ScalaUDF(
(convert, argTerm)
}.unzip

val callFunc = s"${ctx.boxedType(dataType)} $resultTerm = " +
s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" +
s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));"
val getFuncResult = s"$funcTerm.apply(${funcArguments.mkString(", ")})"
val callFunc =
s"""
${ctx.boxedType(dataType)} $resultTerm = null;
try {
$resultTerm = (${ctx.boxedType(dataType)})$catalystConverterTerm.apply($getFuncResult);
} catch (NullPointerException e) {
NullPointerException npe = new NullPointerException($scalaUDF.npeErrorMessage());
npe.setStackTrace(e.getStackTrace());
throw npe;
}
""".stripMargin

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is margin?


ev.copy(code = s"""
$evalCode
Expand All @@ -1057,5 +1059,19 @@ case class ScalaUDF(

private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType)

override def eval(input: InternalRow): Any = converter(f(input))
val npeErrorMessage = "Given UDF throws NPE during execution, please check the UDF " +
"to make sure it handles null parameters correctly."

override def eval(input: InternalRow): Any = {
val result = try {
f(input)
} catch {
case e: NullPointerException =>

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is a bit hacky to set stack trace like this.
npe.setStackTrace(e.getStackTrace)

If user search the code line reported in the stack trace, user may not able to find the code that matches the error message.

@clockfly clockfly Sep 1, 2016

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. For this code branch eval(input: InternalRow), the existing NPE message should be clear enough if there is a full stacktrace, and the stack contains method of the UDF.
  2. The error message you provided can be totally wrong.
    "Given UDF throws NPE during execution, please check the UDF to make sure it handles null parameters correctly".

What if NPE is not caused by null parameter? prompting this message is misleading.

val npe = new NullPointerException(npeErrorMessage)
npe.setStackTrace(e.getStackTrace)
throw npe
}

converter(result)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types.{IntegerType, StringType}

class ScalaUDFSuite extends SparkFunSuite with ExpressionEvalHelper {

test("basic") {
val intUdf = ScalaUDF((i: Int) => i + 1, IntegerType, Literal(1) :: Nil)
checkEvaluation(intUdf, 2)

val stringUdf = ScalaUDF((s: String) => s + "x", StringType, Literal("a") :: Nil)
checkEvaluation(stringUdf, "ax")
}

test("better error message for NPE") {
val udf = ScalaUDF(
(s: String) => s.toLowerCase,
StringType,
Literal.create(null, StringType) :: Nil)

val e1 = intercept[NullPointerException](udf.eval())
assert(e1.getMessage.contains("Given UDF throws NPE during execution"))

val e2 = intercept[NullPointerException] {
checkEvalutionWithUnsafeProjection(udf, null)
}
assert(e2.getMessage.contains("Given UDF throws NPE during execution"))
}

}