
Commit 8e29b1c

Use existing function
1 parent d8d0b76 commit 8e29b1c

3 files changed: +13 −16 lines


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 9 additions & 2 deletions
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.Timestamp
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
 
@@ -32,6 +31,14 @@ object ScalaReflection {
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
+  /** Converts Scala objects to catalyst rows / types */
+  def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
+    case s: Seq[Any] => s.map(convertToCatalyst)
+    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case other => other
+  }
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
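The now-shared helper recursively normalizes Scala values into Catalyst's representation: an Option is unwrapped to its value or null, a Seq is converted element-wise, and any Product (a case class or tuple) becomes a GenericRow of converted fields. Below is a minimal standalone sketch of the same pattern; RowLike is a stand-in for Catalyst's GenericRow, introduced here only for illustration and not part of Spark's API.

// Standalone sketch of the conversion pattern; compiles without Spark.
object ConvertSketch {
  // Illustrative stand-in for Catalyst's GenericRow.
  final case class RowLike(values: Array[Any]) {
    override def toString = values.mkString("[", ",", "]")
  }

  def convertToCatalyst(a: Any): Any = a match {
    case o: Option[_] => o.orNull                  // Some(x) -> x, None -> null
    case s: Seq[_]    => s.map(convertToCatalyst)  // must precede the Product case
    case p: Product   => RowLike(p.productIterator.map(convertToCatalyst).toArray)
    case other        => other                     // primitives, Strings, etc.
  }

  case class Person(name: String, nickname: Option[String], scores: Seq[Int])

  def main(args: Array[String]): Unit = {
    println(convertToCatalyst(Person("Ada", None, Seq(1, 2))))  // [Ada,null,List(1, 2)]
  }
}

Note that the Seq case has to come before the Product case: a non-empty List is itself a Product (the :: cons cell), and matching it as a Product would flatten the list into nested rows.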

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala

Lines changed: 2 additions & 6 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.DataType
 
 case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expression])
@@ -44,11 +45,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   */
 
-  def convert(a: Any): Any = a match {
-    case p: Product => Row.fromSeq(p.productIterator.map(convert).toSeq)
-    case other => other
-  }
-
   // scalastyle:off
   override def eval(input: Row): Any = {
     val result = children.size match {
@@ -351,6 +347,6 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
     }
     // scalastyle:on
 
-    convert(result)
+    ScalaReflection.convertToCatalyst(result)
   }
 }
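The deleted ScalaUdf.convert only flattened Products, so a UDF result such as None or a List was not normalized the way ExistingRdd already normalized input rows; delegating to ScalaReflection.convertToCatalyst makes the two paths agree. A small sketch of the behavioral difference, reusing RowLike and convertToCatalyst from the sketch above (again an illustration, not Spark's Row API):

object UdfConversionSketch {
  import ConvertSketch._

  // Mirrors the shape of the deleted ScalaUdf.convert: only Products were flattened.
  def oldConvert(a: Any): Any = a match {
    case p: Product => RowLike(p.productIterator.map(oldConvert).toArray)
    case other      => other
  }

  def main(args: Array[String]): Unit = {
    // None is itself a Product, so the old code produced an empty row
    // where the shared helper produces a SQL-style null.
    println(oldConvert(None))               // []
    println(convertToCatalyst(None))        // null
    // A non-empty List is a Product (::), so the old code flattened it
    // into nested rows instead of converting it element-wise.
    println(oldConvert(List(1, 2)))         // [1,[2,[]]]
    println(convertToCatalyst(List(1, 2)))  // List(1, 2)
  }
}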

sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 2 additions & 8 deletions
@@ -204,13 +204,6 @@ case class Sort(
  */
 @DeveloperApi
 object ExistingRdd {
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.orNull
-    case s: Seq[Any] => s.map(convertToCatalyst)
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case other => other
-  }
-
   def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
@@ -222,7 +215,7 @@ object ExistingRdd {
         bufferedIterator.map { r =>
           var i = 0
           while (i < mutableRow.length) {
-            mutableRow(i) = convertToCatalyst(r.productElement(i))
+            mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
             i += 1
           }
@@ -244,6 +237,7 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
+
 /**
  * :: DeveloperApi ::
  * Computes the set of distinct input rows using a HashSet.
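productToRowRdd keeps its tight while-loop and now calls the shared helper for each field. A standalone sketch of that conversion loop over plain iterators follows; ArrayRow is an illustrative stand-in for Catalyst's mutable row, and RDDs are left out so the sketch runs without Spark.

object ProductToRowSketch {
  import ConvertSketch.convertToCatalyst  // from the first sketch above

  // Illustrative stand-in for a mutable Catalyst row.
  final class ArrayRow(val length: Int) {
    private val values = new Array[Any](length)
    def update(i: Int, v: Any): Unit = values(i) = v
    override def toString = values.mkString("[", ",", "]")
  }

  // Same shape as the patched loop: convert each Product field into a row slot.
  def productToRows[A <: Product](data: Iterator[A]): Iterator[ArrayRow] =
    data.map { r =>
      val row = new ArrayRow(r.productArity)
      var i = 0
      while (i < row.length) {
        row(i) = convertToCatalyst(r.productElement(i))
        i += 1
      }
      row
    }

  def main(args: Array[String]): Unit = {
    case class Record(id: Int, tag: Option[String])
    productToRows(Iterator(Record(1, Some("a")), Record(2, None))).foreach(println)
    // prints: [1,a] then [2,null]
  }
}

One difference worth noting: the real code allocates a single mutable row per partition and reuses it across rows, while the sketch allocates a fresh row per element for simplicity.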
