diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index c1b1d5cd2dee..0aaf07204086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -407,10 +407,13 @@ object ScalaReflection extends ScalaReflection { val externalDataType = dataTypeFor(elementType) val Schema(catalystType, nullable) = silentSchemaFor(elementType) if (isNativeType(catalystType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) + expressions.If( + IsNull(input), + expressions.Literal.create(null, ArrayType(catalystType, nullable)), + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable))) } else { val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7233e0f1b5ba..4fa67981e538 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -27,6 +27,7 @@ import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} @@ -239,6 +240,24 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + test("null as array") { + val data = Seq( + (Array[Int](2, 1, 3), Array("b", "c", "a")), + (Array[Int](), Array[String]()), + (null, null) + ) + + val schema = ScalaReflection.schemaFor[Tuple2[Array[Int], Array[String]]] + .dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + val arrayDataEncoder = encoderFor[Tuple2[Array[Int], Array[String]]] + val boundEncoder = arrayDataEncoder.resolve(attributeSeq, outers).bind(attributeSeq) + data.foreach { x => + val convertedBack = boundEncoder.fromRow(boundEncoder.toRow(x)) + assert(convertedBack._1 === x._1 && convertedBack._2 === x._2) + } + } + test("nullable of encoder schema") { def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)