Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e1b5dee
Aggregator should be able to use Option of Product encoder.
viirya Jul 9, 2018
80506f4
Enable top-level Option of Product encoders.
viirya Jul 13, 2018
ed3d5cb
Remove topLevel parameter.
viirya Aug 24, 2018
9fc3f61
Merge remote-tracking branch 'upstream/master' into SPARK-24762
viirya Aug 24, 2018
5f95bd0
Remove useless change.
viirya Aug 24, 2018
a4f0405
Add more tests.
viirya Aug 24, 2018
c1f798f
Add test.
viirya Aug 25, 2018
80e11d2
Merge remote-tracking branch 'upstream/master' into SPARK-24762
viirya Oct 6, 2018
0f029b0
Improve code comments.
viirya Oct 6, 2018
84f3ce0
Refactoring ExpressionEncoder.
viirya Oct 15, 2018
6a6fa45
Fix Malformed class name.
viirya Oct 17, 2018
25a6162
Fix error message.
viirya Oct 17, 2018
295ecde
Fix test.
viirya Oct 18, 2018
85a9122
Merge remote-tracking branch 'upstream/master' into SPARK-24762-refactor
viirya Oct 18, 2018
35700f4
Fix rebase error.
viirya Oct 19, 2018
b211ed0
Fix unintentional style change.
viirya Oct 19, 2018
0c78b73
Address comments.
viirya Oct 19, 2018
5b9abb6
Address ComplexTypeMergingExpression issue.
viirya Oct 20, 2018
7432344
Try more reasonable solution.
viirya Oct 20, 2018
400f878
Address comment.
viirya Oct 22, 2018
552e8dd
Merge remote-tracking branch 'upstream/master' into SPARK-24762-refactor
viirya Oct 24, 2018
ed4f4c9
Merge remote-tracking branch 'upstream/master' into SPARK-24762-refactor
viirya Oct 24, 2018
8cb710b
Address comments.
viirya Oct 24, 2018
682fa4b
Make comment more precise.
viirya Oct 24, 2018
078a071
Simplify test change.
viirya Oct 24, 2018
c00d5e4
Address comment.
viirya Oct 24, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,10 @@ object Encoders {
validatePublicClass[T]()

ExpressionEncoder[T](
schema = new StructType().add("value", BinaryType),
flat = true,
serializer = Seq(
objSerializer =
EncodeUsingSerializer(
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
deserializer =
BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo),
objDeserializer =
DecodeUsingSerializer[T](
Cast(GetColumnByOrdinal(0, BinaryType), BinaryType),
classTag[T],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,26 +187,23 @@ object JavaTypeInference {
}

/**
* Returns an expression that can be used to deserialize an internal row to an object of java bean
* `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes
* of the same name as the constructor arguments. Nested classes will have their fields accessed
* using UnresolvedExtractValue.
* Returns an expression that can be used to deserialize a Spark SQL representation to an object
* of Java bean `T` with a compatible schema. The Spark SQL representation is located at ordinal
* 0 of a row, i.e., `GetColumnByOrdinal(0, _)`. Nested classes will have their fields accessed
* using `UnresolvedExtractValue`.
*/
def deserializerFor(beanClass: Class[_]): Expression = {
deserializerFor(TypeToken.of(beanClass), None)
val typeToken = TypeToken.of(beanClass)
deserializerFor(typeToken, GetColumnByOrdinal(0, inferDataType(typeToken)._1))
}

private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = {
private def deserializerFor(typeToken: TypeToken[_], path: Expression): Expression = {
/** Returns the current path with a sub-field extracted. */
def addToPath(part: String): Expression = path
.map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
.getOrElse(UnresolvedAttribute(part))

/** Returns the current path or `GetColumnByOrdinal`. */
def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
def addToPath(part: String): Expression = UnresolvedExtractValue(path,
expressions.Literal(part))

typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path

case c if c == classOf[java.lang.Short] ||
c == classOf[java.lang.Integer] ||
Expand All @@ -219,30 +216,30 @@ object JavaTypeInference {
c,
ObjectType(c),
"valueOf",
getPath :: Nil,
path :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(c),
"toJavaDate",
getPath :: Nil,
path :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
ObjectType(c),
"toJavaTimestamp",
getPath :: Nil,
path :: Nil,
returnNullable = false)

case c if c == classOf[java.lang.String] =>
Invoke(getPath, "toString", ObjectType(classOf[String]))
Invoke(path, "toString", ObjectType(classOf[String]))

case c if c == classOf[java.math.BigDecimal] =>
Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))
Invoke(path, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal]))

case c if c.isArray =>
val elementType = c.getComponentType
Expand All @@ -258,12 +255,12 @@ object JavaTypeInference {
}

primitiveMethod.map { method =>
Invoke(getPath, method, ObjectType(c))
Invoke(path, method, ObjectType(c))
}.getOrElse {
Invoke(
MapObjects(
p => deserializerFor(typeToken.getComponentType, Some(p)),
getPath,
p => deserializerFor(typeToken.getComponentType, p),
path,
inferDataType(elementType)._1),
"array",
ObjectType(c))
Expand All @@ -272,8 +269,8 @@ object JavaTypeInference {
case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
UnresolvedMapObjects(
p => deserializerFor(et, Some(p)),
getPath,
p => deserializerFor(et, p),
path,
customCollectionCls = Some(c))

case _ if mapType.isAssignableFrom(typeToken) =>
Expand All @@ -282,16 +279,16 @@ object JavaTypeInference {
val keyData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(keyType, Some(p)),
GetKeyArrayFromMap(getPath)),
p => deserializerFor(keyType, p),
GetKeyArrayFromMap(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(valueType, Some(p)),
GetValueArrayFromMap(getPath)),
p => deserializerFor(valueType, p),
GetValueArrayFromMap(path)),
"array",
ObjectType(classOf[Array[Any]]))

Expand All @@ -307,7 +304,7 @@ object JavaTypeInference {
other,
ObjectType(other),
"valueOf",
Invoke(getPath, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)

case other =>
Expand All @@ -316,7 +313,7 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (_, nullable) = inferDataType(fieldType)
val constructor = deserializerFor(fieldType, Some(addToPath(fieldName)))
val constructor = deserializerFor(fieldType, addToPath(fieldName))
val setter = if (nullable) {
constructor
} else {
Expand All @@ -328,28 +325,23 @@ object JavaTypeInference {
val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false)
val result = InitializeJavaBean(newInstance, setters)

if (path.nonEmpty) {
expressions.If(
IsNull(getPath),
expressions.Literal.create(null, ObjectType(other)),
result
)
} else {
expressions.If(
IsNull(path),
expressions.Literal.create(null, ObjectType(other)),
result
}
)
}
}

/**
* Returns an expression for serializing an object of the given type to an internal row.
* Returns an expression for serializing an object of the given type to a Spark SQL
* representation. The input object is located at ordinal 0 of a row, i.e.,
* `BoundReference(0, _)`.
*/
def serializerFor(beanClass: Class[_]): CreateNamedStruct = {
def serializerFor(beanClass: Class[_]): Expression = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
serializerFor(nullSafeInput, TypeToken.of(beanClass)) match {
case expressions.If(_, _, s: CreateNamedStruct) => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
serializerFor(nullSafeInput, TypeToken.of(beanClass))
}

private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
Expand Down
Loading