From 427d112f3fcff00076f2895dc3a47a1ba9e035a7 Mon Sep 17 00:00:00 2001
From: "yi.wu"
Date: Wed, 15 Jul 2020 00:24:38 +0800
Subject: [PATCH] fix

---
 .../spark/sql/catalyst/expressions/ScalaUDF.scala  |  6 ++++++
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 12 ++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 44ee06ae011af..6e2bd96784b94 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -59,6 +59,12 @@ case class ScalaUDF(
 
   override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"
 
+  override lazy val canonicalized: Expression = {
+    // SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
+    // need it to identify a `ScalaUDF`.
+    Canonicalize.execute(copy(children = children.map(_.canonicalized), inputEncoders = Nil))
+  }
+
   /**
    * The analyzer should be aware of Scala primitive types so as to make the
    * UDF return null if there is any null input value of these types. On the
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 05a33f9aa17bb..f0d5a61ad8006 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -775,4 +775,16 @@ class UDFSuite extends QueryTest with SharedSparkSession {
     }
     assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction"))
   }
+
+  test("SPARK-32307: Aggregation that uses map type input UDF as group expression") {
+    spark.udf.register("key", udf((m: Map[String, String]) => m.keys.head.toInt))
+    Seq(Map("1" -> "one", "2" -> "two")).toDF("a").createOrReplaceTempView("t")
+    checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
+  }
+
+  test("SPARK-32307: Aggregation that uses array type input UDF as group expression") {
+    spark.udf.register("key", udf((m: Array[Int]) => m.head))
+    Seq(Array(1)).toDF("a").createOrReplaceTempView("t")
+    checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
+  }
 }
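
Context for the change above: when a query groups by a UDF call, the analyzer matches the UDF in the SELECT list against the one in GROUP BY by comparing canonicalized expressions. `ExpressionEncoder` instances do not compare equal, so leaving `inputEncoders` inside the canonical form made the two calls look different and the aggregation failed; the override drops the encoders, which are an execution detail rather than part of the UDF's identity. Below is a minimal, self-contained Scala sketch of that idea under assumed, hypothetical names (`FakeUdf`, `encoderHint`, `semanticEquals`); it is not Spark code or Catalyst API.

// Sketch only: a case class whose canonical form ignores a field that does not
// affect semantic identity, mirroring `inputEncoders = Nil` in the patch.
case class FakeUdf(name: String, children: Seq[String], encoderHint: Option[String]) {
  // Canonical form keeps only what identifies the expression; the encoder hint
  // is cleared because it is an execution detail, not part of the identity.
  lazy val canonical: FakeUdf = copy(encoderHint = None)
  def semanticEquals(other: FakeUdf): Boolean = canonical == other.canonical
}

object CanonicalizationDemo extends App {
  val inSelect  = FakeUdf("key", Seq("a"), Some("encoder@1"))
  val inGroupBy = FakeUdf("key", Seq("a"), Some("encoder@2"))
  println(inSelect == inGroupBy)              // false: encoder hints differ
  println(inSelect.semanticEquals(inGroupBy)) // true: canonical forms match
}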