Skip to content
Closed
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ object DeserializerBuildHelper {
path: Expression,
part: String,
dataType: DataType,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
Expand All @@ -39,40 +39,30 @@ object DeserializerBuildHelper {
path: Expression,
ordinal: Int,
dataType: DataType,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
val newPath = GetStructField(path, ordinal)
upCastToExpectedType(newPath, dataType, walkedTypePath)
}

def deserializerForWithNullSafety(
expr: Expression,
dataType: DataType,
nullable: Boolean,
walkedTypePath: Seq[String],
funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
expressionWithNullSafety(newExpr, nullable, walkedTypePath)
}

def deserializerForWithNullSafetyAndUpcast(
expr: Expression,
dataType: DataType,
nullable: Boolean,
walkedTypePath: Seq[String],
funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
walkedTypePath: WalkedTypePath,
funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
funcForCreatingNewExpr)
expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
nullable, walkedTypePath)
}

private def expressionWithNullSafety(
def expressionWithNullSafety(
expr: Expression,
nullable: Boolean,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
if (nullable) {
expr
} else {
AssertNotNull(expr, walkedTypePath)
AssertNotNull(expr, walkedTypePath.copy())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can let AssertNotNull take a Seq[String], to force us to copy the WalkedTypePath when creating AssertNotNull

}
}

Expand Down Expand Up @@ -167,10 +157,10 @@ object DeserializerBuildHelper {
private def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
walkedTypePath: WalkedTypePath): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
case _: MapType => expr
case _ => UpCast(expr, expected, walkedTypePath)
case _ => UpCast(expr, expected, walkedTypePath.copy())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ import scala.language.existentials
import com.google.common.reflect.TypeToken

import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Type-inference utilities for POJOs and Java collections.
Expand Down Expand Up @@ -195,7 +195,8 @@ object JavaTypeInference {
*/
def deserializerFor(beanClass: Class[_]): Expression = {
val typeToken = TypeToken.of(beanClass)
val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
val walkedTypePath = WalkedTypePath()
walkedTypePath.recordRoot(beanClass.getCanonicalName)
val (dataType, nullable) = inferDataType(typeToken)

// Assumes we are deserializing the first column of a row.
Expand All @@ -208,7 +209,7 @@ object JavaTypeInference {
private def deserializerFor(
typeToken: TypeToken[_],
path: Expression,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path

Expand Down Expand Up @@ -244,16 +245,15 @@ object JavaTypeInference {

case c if c.isArray =>
val elementType = c.getComponentType
val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
walkedTypePath
walkedTypePath.recordArray(elementType.getCanonicalName)
val (dataType, elementNullable) = inferDataType(elementType)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
deserializerForWithNullSafetyAndUpcast(
element,
dataType,
nullable = elementNullable,
newTypePath,
walkedTypePath,
(casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
}

Expand All @@ -274,38 +274,40 @@ object JavaTypeInference {

case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
walkedTypePath
walkedTypePath.recordArray(et.getType.getTypeName)
val (dataType, elementNullable) = inferDataType(et)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
deserializerForWithNullSafetyAndUpcast(
element,
dataType,
nullable = elementNullable,
newTypePath,
walkedTypePath,
(casted, typePath) => deserializerFor(et, casted, typePath))
}

UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))

case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
walkedTypePath.recordMap(keyType.getType.getTypeName,
valueType.getType.getTypeName)

val newTypePathForKey = walkedTypePath.copy()
val newTypePathForValue = walkedTypePath.copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the back and forth. But seems it's better to make WalkedTypePath immutable as there are branches. It's hard to maintain and we can easily mess it up if we forget the call copy somewhere.

Copy link
Contributor Author

@HeartSaVioR HeartSaVioR Mar 1, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, same understanding. No problem! I'll revert back to let WalkedTypePath be immutable one.


val keyData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(keyType, p, newTypePath),
p => deserializerFor(keyType, p, newTypePathForKey),
MapKeys(path)),
"array",
ObjectType(classOf[Array[Any]]))

val valueData =
Invoke(
UnresolvedMapObjects(
p => deserializerFor(valueType, p, newTypePath),
p => deserializerFor(valueType, p, newTypePathForValue),
MapValues(path)),
"array",
ObjectType(classOf[Array[Any]]))
Expand All @@ -328,15 +330,13 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(fieldType)
val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
s""", name: "$fieldName")""") +: walkedTypePath
val setter = deserializerForWithNullSafety(
path,
dataType,
val newTypePathForField = walkedTypePath.copy()
newTypePathForField.recordField(fieldType.getType.getTypeName, fieldName)
val setter = expressionWithNullSafety(
deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePathForField),
newTypePathForField),
nullable = nullable,
newTypePath,
(expr, typePath) => deserializerFor(fieldType,
addToPath(expr, fieldName, dataType, typePath), typePath))
newTypePathForField)
p.getWriteMethod.getName -> setter
}.toMap

Expand All @@ -358,7 +358,10 @@ object JavaTypeInference {
*/
def serializerFor(beanClass: Class[_]): Expression = {
val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
val walkedTypePath = WalkedTypePath()
walkedTypePath.recordRoot("top level input bean")
// not copying walkedTypePath since the instance will be only used here
val nullSafeInput = AssertNotNull(inputObject, walkedTypePath)
serializerFor(nullSafeInput, TypeToken.of(beanClass))
}

Expand All @@ -367,73 +370,37 @@ object JavaTypeInference {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
if (ScalaReflection.isNativeType(dataType)) {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(dataType, nullable))
createSerializerForGenericArray(input, dataType, nullable = nullable)
} else {
MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
serializerFor(_, elementType))
}
}

if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
typeToken.getRawType match {
case c if c == classOf[String] =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.time.LocalDate] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"localDateToDays",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil,
returnNullable = false)
case c if c == classOf[String] => createSerializerForString(inputObject)

case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)

case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)

case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)

case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
case c if c == classOf[java.lang.Byte] =>
Invoke(inputObject, "byteValue", ByteType)
case c if c == classOf[java.lang.Short] =>
Invoke(inputObject, "shortValue", ShortType)
case c if c == classOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case c if c == classOf[java.lang.Long] =>
Invoke(inputObject, "longValue", LongType)
case c if c == classOf[java.lang.Float] =>
Invoke(inputObject, "floatValue", FloatType)
case c if c == classOf[java.lang.Double] =>
Invoke(inputObject, "doubleValue", DoubleType)
createSerializerForJavaBigDecimal(inputObject)

case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)

case _ if typeToken.isArray =>
toCatalystArray(inputObject, typeToken.getComponentType)
Expand All @@ -444,38 +411,34 @@ object JavaTypeInference {
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)

ExternalMapToCatalyst(
createSerializerForMap(
inputObject,
ObjectType(keyType.getRawType),
serializerFor(_, keyType),
keyNullable = true,
ObjectType(valueType.getRawType),
serializerFor(_, valueType),
valueNullable = true
MapElementInformation(
ObjectType(keyType.getRawType),
nullable = true,
serializerFor(_, keyType)),
MapElementInformation(
ObjectType(valueType.getRawType),
nullable = true,
serializerFor(_, valueType))
)

case other if other.isEnum =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)
createSerializerForString(
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
val fields = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})

val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
(fieldName, serializerFor(fieldValue, fieldType))
}
createSerializerForObject(inputObject, fields)
}
}
}
Expand Down
Loading