Commit 76e3ba4

[SPARK-3230][SQL] Fix udfs that return structs
We need to convert the case classes into Rows.

Author: Michael Armbrust <[email protected]>

Closes #2133 from marmbrus/structUdfs and squashes the following commits:

189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs
8e29b1c [Michael Armbrust] Use existing function
d8d0b76 [Michael Armbrust] Fix udfs that return structs
1 parent 68f75dc commit 76e3ba4

4 files changed: 30 additions, 12 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 10 additions & 2 deletions

@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst

 import java.sql.Timestamp

-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._

@@ -32,6 +31,15 @@ object ScalaReflection {

   case class Schema(dataType: DataType, nullable: Boolean)

+  /** Converts Scala objects to catalyst rows / types */
+  def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
+    case s: Seq[_] => s.map(convertToCatalyst)
+    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
+    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case other => other
+  }
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
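A note on what the new helper does: it recurses structurally, unwrapping an Option to its value or null, converting Seq and Map elements one by one, and turning any Product (which includes every case class) into a GenericRow of its converted fields. Below is a minimal, dependency-free sketch of the same recursion; it is not part of the commit, and a plain Array[Any] stands in for catalyst's GenericRow so the example runs without Spark on the classpath.

// Simplified stand-in for ScalaReflection.convertToCatalyst: the same recursion,
// with Array[Any] playing the role of catalyst's GenericRow. Illustration only.
object ConvertSketch {
  def convertToCatalyst(a: Any): Any = a match {
    case Some(x)      => x                                 // unwrap the Option's value ...
    case None         => null                              // ... or map None to null
    case s: Seq[_]    => s.map(convertToCatalyst)          // convert each element
    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
    case p: Product   => p.productIterator.map(convertToCatalyst).toArray // case class -> row-like array
    case other        => other                             // primitives and strings pass through
  }

  // The case class used by the new UDFSuite test; SQL sees it as a two-field struct.
  case class FunctionResult(f1: String, f2: String)

  def main(args: Array[String]): Unit = {
    val row = convertToCatalyst(FunctionResult("test", "test2")).asInstanceOf[Array[Any]]
    println(row.mkString("[", ", ", "]"))          // [test, test2]
    println(convertToCatalyst(Seq(Some(1), None))) // List(1, null)
  }
}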

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 6 additions & 1 deletion

@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions

+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.DataType
 import org.apache.spark.util.ClosureCleaner

@@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi

   def nullable = true

+  override def toString = s"scalaUDF(${children.mkString(",")})"
+
   /** This method has been generated by this script

     (1 to 22).map { x =>
@@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi

   // scalastyle:off
   override def eval(input: Row): Any = {
-    children.size match {
+    val result = children.size match {
       case 0 => function.asInstanceOf[() => Any]()
       case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
       case 2 =>
@@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
           children(21).eval(input))
     }
     // scalastyle:on
+
+    ScalaReflection.convertToCatalyst(result)
   }
 }
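With that helper in place, the change to eval is only a re-plumbing: the user function's raw result is computed first and then routed through ScalaReflection.convertToCatalyst before being returned to the query plan. A small sketch of that wiring, again outside Spark and not part of the commit; EvalSketch, convert and evalTwoArgs are names made up for the example, and convert is the stand-in from the previous sketch reduced to the Product case.

// Sketch of the new eval shape: run the UDF, then convert its result.
object EvalSketch {
  case class FunctionResult(f1: String, f2: String)

  // Stand-in for ScalaReflection.convertToCatalyst, reduced to the case that matters
  // for struct results: a case class becomes a row-like Array[Any].
  def convert(a: Any): Any = a match {
    case p: Product => p.productIterator.map(convert).toArray
    case other      => other
  }

  // A two-argument UDF returning a struct, as in the new "struct UDF" test.
  val twoArgUdf: (Any, Any) => Any = (f1, f2) => FunctionResult(f1.toString, f2.toString)

  def evalTwoArgs(arg0: Any, arg1: Any): Any = {
    val result = twoArgUdf(arg0, arg1) // what eval used to return directly
    convert(result)                    // what it returns after this commit
  }

  def main(args: Array[String]): Unit =
    println(evalTwoArgs("test", "test2").asInstanceOf[Array[Any]].mkString("[", ", ", "]"))
}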

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 2 additions & 9 deletions

@@ -204,14 +204,6 @@ case class Sort(
  */
 @DeveloperApi
 object ExistingRdd {
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.orNull
-    case s: Seq[_] => s.map(convertToCatalyst)
-    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case other => other
-  }
-
   def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
@@ -223,7 +215,7 @@ object ExistingRdd {
         bufferedIterator.map { r =>
           var i = 0
           while (i < mutableRow.length) {
-            mutableRow(i) = convertToCatalyst(r.productElement(i))
+            mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
             i += 1
           }

@@ -245,6 +237,7 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
+
 /**
  * :: DeveloperApi ::
  * Computes the set of distinct input rows using a HashSet.

sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala

Lines changed: 12 additions & 0 deletions

@@ -22,6 +22,8 @@ import org.apache.spark.sql.test._
 /* Implicits */
 import TestSQLContext._

+case class FunctionResult(f1: String, f2: String)
+
 class UDFSuite extends QueryTest {

   test("Simple UDF") {
@@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
     registerFunction("strLenScala", (_: String).length + (_:Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
   }
+
+
+  test("struct UDF") {
+    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+
+    val result=
+      sql("SELECT returnStruct('test', 'test2') as ret")
+        .select("ret.f1".attr).first().getString(0)
+    assert(result == "test")
+  }
 }
