diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 8ad5cb70d248..e84f51877325 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -47,7 +47,8 @@ private[hive] case class HiveSimpleUDF( with HiveInspectors with CodegenFallback with Logging - with UserDefinedExpression { + with UserDefinedExpression + with ImplicitCastInputTypes { override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) @@ -69,6 +70,12 @@ private[hive] case class HiveSimpleUDF( udfType != null && udfType.deterministic() && !udfType.stateful() } + override def inputTypes: Seq[AbstractDataType] = { + method.getGenericParameterTypes.map(javaTypeToDataType).map { dt => + if (dt.existsRecursively(_.isInstanceOf[NullType])) AnyDataType else dt + } + } + override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) // Create parameter converters diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 057f2f4ce01b..189c483f0b55 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -658,6 +658,24 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } + test("SPARK-32877: Fix Hive UDF not support decimal type in complex type") { + withUserDefinedFunction("testArraySum" -> false) { + sql(s"CREATE FUNCTION testArraySum AS '${classOf[ArraySumUDF].getName}'") + checkAnswer( + sql("SELECT testArraySum(array(1, 1.1, 1.2))"), + Seq(Row(3.3))) + + val msg = intercept[AnalysisException] { + sql("SELECT testArraySum(1)") + }.getMessage + assert(msg.contains(s"No handler for UDF/UDAF/UDTF '${classOf[ArraySumUDF].getName}'")) + + val msg2 = intercept[AnalysisException] { + sql("SELECT testArraySum(1, 2)") + }.getMessage + assert(msg2.contains(s"No handler for UDF/UDAF/UDTF '${classOf[ArraySumUDF].getName}'")) + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { @@ -741,3 +759,14 @@ class StatelessUDF extends UDF { result } } + +class ArraySumUDF extends UDF { + import scala.collection.JavaConverters._ + def evaluate(values: java.util.List[java.lang.Double]): java.lang.Double = { + var r = 0d + for (v <- values.asScala) { + r += v + } + r + } +}