Commit c5378f9

addressed review comments
1 parent b5473e3 commit c5378f9

4 files changed: 22 additions, 32 deletions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala

Lines changed: 1 addition & 7 deletions

@@ -444,13 +444,7 @@ object ScalaReflection extends ScalaReflection {
       case dt @ (BooleanType | ByteType | ShortType | IntegerType | LongType |
           FloatType | DoubleType) =>
         val cls = input.dataType.asInstanceOf[ObjectType].cls
-        if (cls.isAssignableFrom(classOf[Array[Boolean]]) ||
-          cls.isAssignableFrom(classOf[Array[Byte]]) ||
-          cls.isAssignableFrom(classOf[Array[Short]]) ||
-          cls.isAssignableFrom(classOf[Array[Int]]) ||
-          cls.isAssignableFrom(classOf[Array[Long]]) ||
-          cls.isAssignableFrom(classOf[Array[Float]]) ||
-          cls.isAssignableFrom(classOf[Array[Double]])) {
+        if (cls.isArray && cls.getComponentType.isPrimitive) {
           StaticInvoke(
             classOf[UnsafeArrayData],
             ArrayType(dt, false),
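
For context, the replaced condition can be exercised in isolation. Below is a minimal standalone sketch; the object and method names are made up for illustration, and only the isArray/getComponentType.isPrimitive check comes from the diff:

object PrimitiveArrayCheck {
  // Mirrors the new condition in ScalaReflection: a class qualifies when it is
  // a JVM array class whose component type is a primitive.
  def isPrimitiveArray(cls: Class[_]): Boolean =
    cls.isArray && cls.getComponentType.isPrimitive

  def main(args: Array[String]): Unit = {
    println(isPrimitiveArray(classOf[Array[Int]]))      // true
    println(isPrimitiveArray(classOf[Array[Double]]))   // true
    println(isPrimitiveArray(classOf[Array[String]]))   // false: component type is not primitive
    println(isPrimitiveArray(classOf[Seq[_]]))          // false: not an array class
  }
}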

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala

Lines changed: 6 additions & 23 deletions

@@ -23,7 +23,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.SparkException
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.objects._
@@ -122,28 +122,12 @@ object RowEncoder {
     case t @ ArrayType(et, cn) =>
       val cls = inputObject.dataType.asInstanceOf[ObjectType].cls
       et match {
-        /*
-        case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType
-          if !cn && (
-            cls.isAssignableFrom(classOf[Array[Boolean]]) ||
-            cls.isAssignableFrom(classOf[Array[Byte]]) ||
-            cls.isAssignableFrom(classOf[Array[Short]]) ||
-            cls.isAssignableFrom(classOf[Array[Int]]) ||
-            cls.isAssignableFrom(classOf[Array[Long]]) ||
-            cls.isAssignableFrom(classOf[Array[Float]]) ||
-            cls.isAssignableFrom(classOf[Array[Double]])) =>
-          print(s"1@ET: $et, $cn, $cls\n")
+        case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
           StaticInvoke(
-            classOf[UnsafeArrayData],
-            ArrayType(et, false),
-            "fromPrimitiveArray",
+            classOf[ArrayData],
+            ObjectType(classOf[ArrayData]),
+            "toArrayData",
             inputObject :: Nil)
-        */
-        case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType =>
-          NewInstance(
-            classOf[GenericArrayData],
-            inputObject :: Nil,
-            dataType = t)
         case _ => MapObjects(
           element => serializerFor(ValidateExternalType(element, et), et),
           inputObject,
@@ -211,8 +195,7 @@ object RowEncoder {
     // as java.lang.Object.
     case _: DecimalType => ObjectType(classOf[java.lang.Object])
     // In order to support both Array and Seq in external row, we make this as java.lang.Object.
-    case a @ ArrayType(et, cn) =>
-      ObjectType(classOf[java.lang.Object])
+    case _: ArrayType => ObjectType(classOf[java.lang.Object])
     case _ => externalDataTypeFor(dt)
   }
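
For reference, a rough sketch of the expression tree the new branch builds, assuming this branch's catalyst API; the BoundReference standing in for inputObject is only an illustration and not part of the change:

import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types.ObjectType

// Stand-in for the inputObject expression that serializerFor receives for an
// external Array[Int] column.
val inputObject = BoundReference(0, ObjectType(classOf[Array[Int]]), nullable = false)

// The serializer now emits one static call to ArrayData.toArrayData instead of
// allocating a GenericArrayData through NewInstance; the choice between
// UnsafeArrayData and GenericArrayData is deferred to runtime.
val serializer = StaticInvoke(
  classOf[ArrayData],
  ObjectType(classOf[ArrayData]),
  "toArrayData",
  inputObject :: Nil)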

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala

Lines changed: 1 addition & 1 deletion

@@ -993,7 +993,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
       case _: DecimalType =>
         Seq(classOf[java.math.BigDecimal], classOf[scala.math.BigDecimal], classOf[Decimal])
           .map(cls => s"$obj instanceof ${cls.getName}").mkString(" || ")
-      case a @ ArrayType(et, cn) =>
+      case _: ArrayType =>
        s"$obj instanceof ${classOf[Seq[_]].getName} || $obj.getClass().isArray()"
       case _ =>
        s"$obj instanceof ${ctx.boxedType(dataType)}"

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala

Lines changed: 14 additions & 1 deletion

@@ -19,9 +19,22 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
 import org.apache.spark.sql.types.DataType
 
+object ArrayData {
+  def toArrayData(input: Any): ArrayData = input match {
+    case a: Array[Boolean] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Byte] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Short] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Int] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Long] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Float] => UnsafeArrayData.fromPrimitiveArray(a)
+    case a: Array[Double] => UnsafeArrayData.fromPrimitiveArray(a)
+    case other => new GenericArrayData(other)
+  }
+}
+
 abstract class ArrayData extends SpecializedGetters with Serializable {
   def numElements(): Int
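
A quick usage sketch of the new helper, assuming this branch's GenericArrayData accepts an arbitrary array or Seq argument, as the fallback case in the diff implies:

import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}

// A primitive array is wrapped without boxing via UnsafeArrayData.fromPrimitiveArray.
val unsafe: ArrayData = ArrayData.toArrayData(Array(1, 2, 3))
assert(unsafe.isInstanceOf[UnsafeArrayData])

// Anything else (object arrays, Seqs) falls through to the GenericArrayData constructor.
val generic: ArrayData = ArrayData.toArrayData(Array("a", "b"))
assert(generic.isInstanceOf[GenericArrayData])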
