Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -178,19 +178,19 @@ object JavaTypeInference {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath

case c if c == classOf[java.lang.Short] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Integer] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Long] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Double] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Byte] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Float] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))
case c if c == classOf[java.lang.Boolean] =>
NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c))
NewInstance(c, getPath :: Nil, ObjectType(c))

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
Expand Down Expand Up @@ -298,7 +298,7 @@ object JavaTypeInference {
p.getWriteMethod.getName -> setter
}.toMap

val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other))
val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters)

if (path.nonEmpty) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -189,37 +189,37 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< localTypeOf[java.lang.Integer] =>
val boxedType = classOf[java.lang.Integer]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.lang.Long] =>
val boxedType = classOf[java.lang.Long]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.lang.Double] =>
val boxedType = classOf[java.lang.Double]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.lang.Float] =>
val boxedType = classOf[java.lang.Float]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.lang.Short] =>
val boxedType = classOf[java.lang.Short]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.lang.Byte] =>
val boxedType = classOf[java.lang.Byte]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.lang.Boolean] =>
val boxedType = classOf[java.lang.Boolean]
val objectType = ObjectType(boxedType)
NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType)
NewInstance(boxedType, getPath :: Nil, objectType)

case t if t <:< localTypeOf[java.sql.Date] =>
StaticInvoke(
Expand Down Expand Up @@ -349,7 +349,7 @@ object ScalaReflection extends ScalaReflection {
}
}

val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls))
val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false)

if (path.nonEmpty) {
expressions.If(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ object ExpressionEncoder {
}

val fromRowExpression =
NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls))
NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false)

new ExpressionEncoder[Any](
schema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ object RowEncoder {
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)

Expand Down Expand Up @@ -166,7 +165,6 @@ object RowEncoder {
val obj = NewInstance(
udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
Nil,
false,
dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ case class Invoke(
${obj.code}
${argGen.map(_.code).mkString("\n")}

boolean ${ev.isNull} = ${obj.value} == null;
boolean ${ev.isNull} = ${obj.isNull};
$javaType ${ev.value} =
${ev.isNull} ?
${ctx.defaultValue(dataType)} : ($javaType) $value;
Expand All @@ -178,8 +178,8 @@ object NewInstance {
def apply(
cls: Class[_],
arguments: Seq[Expression],
propagateNull: Boolean = false,
dataType: DataType): NewInstance =
dataType: DataType,
propagateNull: Boolean = true): NewInstance =
new NewInstance(cls, arguments, propagateNull, dataType, None)
}

Expand Down Expand Up @@ -231,7 +231,7 @@ case class NewInstance(
s"new $className($argString)"
}

if (propagateNull) {
if (propagateNull && argGen.nonEmpty) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a small fix: we should not go into this branch if argGen is empty, or the argsNonNull will be !(), which is wrong.

I found this when I leave the propagateNull true for NewInstance of UDT.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"

s"""
Expand All @@ -248,8 +248,8 @@ case class NewInstance(
s"""
$setup

$javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = ${ev.value} == null;
final $javaType ${ev.value} = $constructorCall;
final boolean ${ev.isNull} = false;
"""
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class EncoderResolutionSuite extends PlanTest {
toExternalString('a.string),
AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long")
),
false,
ObjectType(cls))
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}

Expand All @@ -60,8 +60,8 @@ class EncoderResolutionSuite extends PlanTest {
toExternalString('a.int.cast(StringType)),
AssertNotNull('b.long, cls.getName, "b", "Long")
),
false,
ObjectType(cls))
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}
}
Expand All @@ -88,11 +88,11 @@ class EncoderResolutionSuite extends PlanTest {
AssertNotNull(
GetStructField('b.struct('a.int, 'b.long), 1, Some("b")),
innerCls.getName, "b", "Long")),
false,
ObjectType(innerCls))
ObjectType(innerCls),
propagateNull = false)
)),
false,
ObjectType(cls))
ObjectType(cls),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}

Expand All @@ -114,11 +114,11 @@ class EncoderResolutionSuite extends PlanTest {
AssertNotNull(
GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType),
cls.getName, "b", "Long")),
false,
ObjectType(cls)),
ObjectType(cls),
propagateNull = false),
'b.int.cast(LongType)),
false,
ObjectType(classOf[Tuple2[_, _]]))
ObjectType(classOf[Tuple2[_, _]]),
propagateNull = false)
compareExpressions(fromRowExpr, expected)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ class ExpressionEncoderSuite extends SparkFunSuite {
encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null")
encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map")

encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")

// Kryo encoders
encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
encodeDecodeTest(new KryoSerializable(15), "kryo object")(
Expand Down