From d8d0b769272d9a333b927e3ad78e6cfef4d49797 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Tue, 26 Aug 2014 02:53:10 -0700
Subject: [PATCH 1/2] Fix udfs that return structs

---
 .../spark/sql/catalyst/expressions/ScalaUdf.scala  | 11 ++++++++++-
 .../test/scala/org/apache/spark/sql/UDFSuite.scala | 12 ++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 95633dd0c987..7311e34aab3d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -27,6 +27,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
   def references = children.flatMap(_.references).toSet
   def nullable = true
 
+  override def toString = s"scalaUDF(${children.mkString(",")})"
+
   /** This method has been generated by this script
 
     (1 to 22).map { x =>
@@ -42,9 +44,14 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   */
 
+  def convert(a: Any): Any = a match {
+    case p: Product => Row.fromSeq(p.productIterator.map(convert).toSeq)
+    case other => other
+  }
+
   // scalastyle:off
   override def eval(input: Row): Any = {
-    children.size match {
+    val result = children.size match {
       case 0 => function.asInstanceOf[() => Any]()
       case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
       case 2 =>
@@ -343,5 +350,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
           children(21).eval(input))
     }
     // scalastyle:on
+
+    convert(result)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 76aa9b0081d7..ef9b76b1e251 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.test._
 /* Implicits */
 import TestSQLContext._
 
+case class FunctionResult(f1: String, f2: String)
+
 class UDFSuite extends QueryTest {
 
   test("Simple UDF") {
@@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
     registerFunction("strLenScala", (_: String).length + (_:Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
   }
+
+
+  test("struct UDF") {
+    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+
+    val result=
+      sql("SELECT returnStruct('test', 'test2') as ret")
+        .select("ret.f1".attr).first().getString(0)
+    assert(result == "test")
+  }
 }

From 8e29b1c703986303262e3a8a1f05acf62db7b115 Mon Sep 17 00:00:00 2001
From: Michael Armbrust
Date: Wed, 27 Aug 2014 15:23:24 -0700
Subject: [PATCH 2/2] Use existing function

---
 .../apache/spark/sql/catalyst/ScalaReflection.scala | 11 +++++++++--
 .../spark/sql/catalyst/expressions/ScalaUdf.scala   |  8 ++------
 .../apache/spark/sql/execution/basicOperators.scala | 10 ++--------
 3 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 0d26b52a8469..9c11e894a095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.Timestamp
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
 
@@ -32,6 +31,14 @@ object ScalaReflection {
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
+  /** Converts Scala objects to catalyst rows / types */
+  def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
+    case s: Seq[Any] => s.map(convertToCatalyst)
+    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case other => other
+  }
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 7311e34aab3d..a5932263d3d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.DataType
 
 case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
@@ -44,11 +45,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   */
 
-  def convert(a: Any): Any = a match {
-    case p: Product => Row.fromSeq(p.productIterator.map(convert).toSeq)
-    case other => other
-  }
-
   // scalastyle:off
   override def eval(input: Row): Any = {
     val result = children.size match {
@@ -351,6 +347,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
     }
     // scalastyle:on
 
-    convert(result)
+    ScalaReflection.convertToCatalyst(result)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index f9dfa3c92f1e..4abda21ffec9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -204,13 +204,6 @@ case class Sort(
  */
 @DeveloperApi
 object ExistingRdd {
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.orNull
-    case s: Seq[Any] => s.map(convertToCatalyst)
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case other => other
-  }
-
  def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
    data.mapPartitions { iterator =>
      if (iterator.isEmpty) {
@@ -222,7 +215,7 @@ object ExistingRdd {
      bufferedIterator.map { r =>
        var i = 0
        while (i < mutableRow.length) {
-          mutableRow(i) = convertToCatalyst(r.productElement(i))
+          mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
          i += 1
        }
 
@@ -244,6 +237,7 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
+
 /**
  * :: DeveloperApi ::
  * Computes the set of distinct input rows using a HashSet.
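
Below is a minimal end-to-end sketch of what these two patches enable, mirroring the new "struct UDF" test in UDFSuite. It is not part of either commit: the object name, the local SparkContext wiring, and the final assertion are illustrative assumptions for a Spark 1.1-era standalone application, while registerFunction, the "col".attr DSL, and FunctionResult come from the patch's own test code.

// Usage sketch only, not part of the patch. Assumes a Spark 1.1-era local setup;
// StructUdfExample and the SparkContext wiring are illustrative.
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Defined at the top level (not nested) so ScalaReflection can inspect it,
// just like FunctionResult in the new UDFSuite test.
case class FunctionResult(f1: String, f2: String)

object StructUdfExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("struct-udf").setMaster("local"))
    val sqlContext = new SQLContext(sc)
    import sqlContext._

    // Register a UDF that returns a case class. With these patches the Product
    // result is converted to a Catalyst row via ScalaReflection.convertToCatalyst,
    // so it behaves like a struct column.
    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))

    // The struct's fields are addressable with dotted attribute names.
    val f1 = sql("SELECT returnStruct('test', 'test2') AS ret")
      .select("ret.f1".attr)
      .first()
      .getString(0)
    assert(f1 == "test")

    sc.stop()
  }
}

Before PATCH 1/2, a case class returned by a Scala UDF was handed to Catalyst unconverted, so struct-returning UDFs did not work; PATCH 2/2 then drops the ad hoc convert helper and shares the existing conversion path with ExistingRdd.productToRowRdd by moving convertToCatalyst into ScalaReflection.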