Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ case class ScalaUDF(

override def toString: String = s"${udfName.getOrElse("UDF")}(${children.mkString(", ")})"

override lazy val canonicalized: Expression = {
// SPARK-32307: `ExpressionEncoder` can't be canonicalized, and technically we don't
// need it to identify a `ScalaUDF`.
Canonicalize.execute(copy(children = children.map(_.canonicalized), inputEncoders = Nil))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Ngone51 shall we do the same for outputEncoder?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense. I'll do a follow-up.

}

/**
* The analyzer should be aware of Scala primitive types so as to make the
* UDF return null if there is any null input value of these types. On the
Expand Down
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -775,4 +775,16 @@ class UDFSuite extends QueryTest with SharedSparkSession {
}
assert(e2.getMessage.contains("UDFSuite$MalformedClassObject$MalformedPrimitiveFunction"))
}

test("SPARK-32307: Aggression that use map type input UDF as group expression") {
spark.udf.register("key", udf((m: Map[String, String]) => m.keys.head.toInt))
Seq(Map("1" -> "one", "2" -> "two")).toDF("a").createOrReplaceTempView("t")
checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
}

test("SPARK-32307: Aggression that use array type input UDF as group expression") {
spark.udf.register("key", udf((m: Array[Int]) => m.head))
Seq(Array(1)).toDF("a").createOrReplaceTempView("t")
checkAnswer(sql("SELECT key(a) AS k FROM t GROUP BY key(a)"), Row(1) :: Nil)
}
}