Skip to content

Commit f0c7e19

Browse files
committed
[SPARK-3680][SQL] Fix bug caused by eager typing of HiveGenericUDFs
Typing of UDFs should be lazy as it is often not valid to call `dataType` on an expression until after all of its children are `resolved`. Author: Michael Armbrust <[email protected]> Closes #2525 from marmbrus/concatBug and squashes the following commits: 5b8efe7 [Michael Armbrust] fix bug with eager typing of udfs
1 parent 0800881 commit f0c7e19

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,7 @@ private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq
151151
override def get(): AnyRef = wrap(func())
152152
}
153153

154-
val dataType: DataType = inspectorToDataType(returnInspector)
154+
lazy val dataType: DataType = inspectorToDataType(returnInspector)
155155

156156
override def eval(input: Row): Any = {
157157
returnInspector // Make sure initialized.

sql/hive/src/test/scala/org/apache/spark/sql/parquet/ParquetMetastoreSuite.scala

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.parquet
2020

2121
import java.io.File
2222

23+
import org.apache.spark.sql.catalyst.expressions.Row
2324
import org.scalatest.BeforeAndAfterAll
2425

2526
import org.apache.spark.sql.QueryTest
@@ -142,15 +143,21 @@ class ParquetMetastoreSuite extends QueryTest with BeforeAndAfterAll {
142143
test("sum") {
143144
checkAnswer(
144145
sql("SELECT SUM(intField) FROM partitioned_parquet WHERE intField IN (1,2,3) AND p = 1"),
145-
1 + 2 + 3
146-
)
146+
1 + 2 + 3)
147+
}
148+
149+
test("hive udfs") {
150+
checkAnswer(
151+
sql("SELECT concat(stringField, stringField) FROM partitioned_parquet"),
152+
sql("SELECT stringField FROM partitioned_parquet").map {
153+
case Row(s: String) => Row(s + s)
154+
}.collect().toSeq)
147155
}
148156

149157
test("non-part select(*)") {
150158
checkAnswer(
151159
sql("SELECT COUNT(*) FROM normal_parquet"),
152-
10
153-
)
160+
10)
154161
}
155162

156163
test("conversion is working") {

0 commit comments

Comments (0)