diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index a1500cbc305d..ed153d1f8894 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -178,19 +178,19 @@ object JavaTypeInference { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.sql.Date] => StaticInvoke( @@ -298,7 +298,7 @@ object JavaTypeInference { p.getWriteMethod.getName -> setter }.toMap - val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) if (path.nonEmpty) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 8a22b37d07fc..9784c969665d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( @@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection { } } - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) if (path.nonEmpty) { expressions.If( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7a4401cf5810..ad4beda9c491 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -133,7 +133,7 @@ object ExpressionEncoder { } val fromRowExpression = - NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) + NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 63bdf05ca7c2..6f3d5ba84c9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -55,7 +55,6 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) @@ -166,7 +165,6 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index d40cd9690573..fb404c12d5a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -165,7 +165,7 @@ case class Invoke( ${obj.code} ${argGen.map(_.code).mkString("\n")} - boolean ${ev.isNull} = ${obj.value} == null; + boolean ${ev.isNull} = ${obj.isNull}; $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value; @@ -178,8 +178,8 @@ object NewInstance { def apply( cls: Class[_], arguments: Seq[Expression], - propagateNull: Boolean = false, - dataType: DataType): NewInstance = + dataType: DataType, + propagateNull: Boolean = true): NewInstance = new NewInstance(cls, arguments, propagateNull, dataType, None) } @@ -231,7 +231,7 @@ case class NewInstance( s"new $className($argString)" } - if (propagateNull) { + if (propagateNull && argGen.nonEmpty) { val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" s""" @@ -248,8 +248,8 @@ case class NewInstance( s""" $setup - $javaType ${ev.value} = $constructorCall; - final boolean ${ev.isNull} = ${ev.value} == null; + final $javaType ${ev.value} = $constructorCall; + final boolean ${ev.isNull} = false; """ } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 764ffdc0947c..bc36a55ae0ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest { toExternalString('a.string), AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") ), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest { toExternalString('a.int.cast(StringType)), AssertNotNull('b.long, cls.getName, "b", "Long") ), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } } @@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest { AssertNotNull( GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), innerCls.getName, "b", "Long")), - false, - ObjectType(innerCls)) + ObjectType(innerCls), + propagateNull = false) )), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest { AssertNotNull( GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), cls.getName, "b", "Long")), - false, - ObjectType(cls)), + ObjectType(cls), + propagateNull = false), 'b.int.cast(LongType)), - false, - ObjectType(classOf[Tuple2[_, _]])) + ObjectType(classOf[Tuple2[_, _]]), + propagateNull = false) compareExpressions(fromRowExpr, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7233e0f1b5ba..666699e18d4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") + encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + // Kryo encoders encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) encodeDecodeTest(new KryoSerializable(15), "kryo object")(