Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ object DeserializerBuildHelper {
path: Expression,
part: String,
dataType: DataType,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
upCastToExpectedType(newPath, dataType, walkedTypePath)
}
Expand All @@ -39,40 +39,30 @@ object DeserializerBuildHelper {
path: Expression,
ordinal: Int,
dataType: DataType,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
val newPath = GetStructField(path, ordinal)
upCastToExpectedType(newPath, dataType, walkedTypePath)
}

def deserializerForWithNullSafety(
expr: Expression,
dataType: DataType,
nullable: Boolean,
walkedTypePath: Seq[String],
funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
expressionWithNullSafety(newExpr, nullable, walkedTypePath)
}

def deserializerForWithNullSafetyAndUpcast(
expr: Expression,
dataType: DataType,
nullable: Boolean,
walkedTypePath: Seq[String],
funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
walkedTypePath: WalkedTypePath,
funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
funcForCreatingNewExpr)
expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
nullable, walkedTypePath)
}

private def expressionWithNullSafety(
def expressionWithNullSafety(
expr: Expression,
nullable: Boolean,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
if (nullable) {
expr
} else {
AssertNotNull(expr, walkedTypePath)
AssertNotNull(expr, walkedTypePath.getPaths)
}
}

Expand Down Expand Up @@ -167,10 +157,10 @@ object DeserializerBuildHelper {
private def upCastToExpectedType(
expr: Expression,
expected: DataType,
walkedTypePath: Seq[String]): Expression = expected match {
walkedTypePath: WalkedTypePath): Expression = expected match {
case _: StructType => expr
case _: ArrayType => expr
case _: MapType => expr
case _ => UpCast(expr, expected, walkedTypePath)
case _ => UpCast(expr, expected, walkedTypePath.getPaths)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ import scala.language.existentials
import com.google.common.reflect.TypeToken

import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
import org.apache.spark.sql.catalyst.SerializerBuildHelper._
import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/**
* Type-inference utilities for POJOs and Java collections.
Expand Down Expand Up @@ -195,7 +195,7 @@ object JavaTypeInference {
*/
def deserializerFor(beanClass: Class[_]): Expression = {
val typeToken = TypeToken.of(beanClass)
val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
val walkedTypePath = new WalkedTypePath().recordRoot(beanClass.getCanonicalName)
val (dataType, nullable) = inferDataType(typeToken)

// Assumes we are deserializing the first column of a row.
Expand All @@ -208,7 +208,7 @@ object JavaTypeInference {
private def deserializerFor(
typeToken: TypeToken[_],
path: Expression,
walkedTypePath: Seq[String]): Expression = {
walkedTypePath: WalkedTypePath): Expression = {
typeToken.getRawType match {
case c if !inferExternalType(c).isInstanceOf[ObjectType] => path

Expand Down Expand Up @@ -244,8 +244,7 @@ object JavaTypeInference {

case c if c.isArray =>
val elementType = c.getComponentType
val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
walkedTypePath
val newTypePath = walkedTypePath.recordArray(elementType.getCanonicalName)
val (dataType, elementNullable) = inferDataType(elementType)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
Expand Down Expand Up @@ -274,8 +273,7 @@ object JavaTypeInference {

case c if listType.isAssignableFrom(typeToken) =>
val et = elementType(typeToken)
val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
walkedTypePath
val newTypePath = walkedTypePath.recordArray(et.getType.getTypeName)
val (dataType, elementNullable) = inferDataType(et)
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
Expand All @@ -291,8 +289,8 @@ object JavaTypeInference {

case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)
val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
val newTypePath = walkedTypePath.recordMap(keyType.getType.getTypeName,
valueType.getType.getTypeName)

val keyData =
Invoke(
Expand Down Expand Up @@ -328,15 +326,12 @@ object JavaTypeInference {
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(fieldType)
val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
s""", name: "$fieldName")""") +: walkedTypePath
val setter = deserializerForWithNullSafety(
path,
dataType,
val newTypePath = walkedTypePath.recordField(fieldType.getType.getTypeName, fieldName)
val setter = expressionWithNullSafety(
deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePath),
newTypePath),
nullable = nullable,
newTypePath,
(expr, typePath) => deserializerFor(fieldType,
addToPath(expr, fieldName, dataType, typePath), typePath))
newTypePath)
p.getWriteMethod.getName -> setter
}.toMap

Expand Down Expand Up @@ -367,73 +362,37 @@ object JavaTypeInference {
def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
val (dataType, nullable) = inferDataType(elementType)
if (ScalaReflection.isNativeType(dataType)) {
NewInstance(
classOf[GenericArrayData],
input :: Nil,
dataType = ArrayType(dataType, nullable))
createSerializerForGenericArray(input, dataType, nullable = nullable)
} else {
MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
serializerFor(_, elementType))
}
}

if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
typeToken.getRawType match {
case c if c == classOf[String] =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Timestamp] =>
StaticInvoke(
DateTimeUtils.getClass,
TimestampType,
"fromJavaTimestamp",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.time.LocalDate] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"localDateToDays",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.sql.Date] =>
StaticInvoke(
DateTimeUtils.getClass,
DateType,
"fromJavaDate",
inputObject :: Nil,
returnNullable = false)
case c if c == classOf[String] => createSerializerForString(inputObject)

case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)

case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)

case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)

case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

case c if c == classOf[java.math.BigDecimal] =>
StaticInvoke(
Decimal.getClass,
DecimalType.SYSTEM_DEFAULT,
"apply",
inputObject :: Nil,
returnNullable = false)

case c if c == classOf[java.lang.Boolean] =>
Invoke(inputObject, "booleanValue", BooleanType)
case c if c == classOf[java.lang.Byte] =>
Invoke(inputObject, "byteValue", ByteType)
case c if c == classOf[java.lang.Short] =>
Invoke(inputObject, "shortValue", ShortType)
case c if c == classOf[java.lang.Integer] =>
Invoke(inputObject, "intValue", IntegerType)
case c if c == classOf[java.lang.Long] =>
Invoke(inputObject, "longValue", LongType)
case c if c == classOf[java.lang.Float] =>
Invoke(inputObject, "floatValue", FloatType)
case c if c == classOf[java.lang.Double] =>
Invoke(inputObject, "doubleValue", DoubleType)
createSerializerForJavaBigDecimal(inputObject)

case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)

case _ if typeToken.isArray =>
toCatalystArray(inputObject, typeToken.getComponentType)
Expand All @@ -444,38 +403,34 @@ object JavaTypeInference {
case _ if mapType.isAssignableFrom(typeToken) =>
val (keyType, valueType) = mapKeyValueType(typeToken)

ExternalMapToCatalyst(
createSerializerForMap(
inputObject,
ObjectType(keyType.getRawType),
serializerFor(_, keyType),
keyNullable = true,
ObjectType(valueType.getRawType),
serializerFor(_, valueType),
valueNullable = true
MapElementInformation(
ObjectType(keyType.getRawType),
nullable = true,
serializerFor(_, keyType)),
MapElementInformation(
ObjectType(valueType.getRawType),
nullable = true,
serializerFor(_, valueType))
)

case other if other.isEnum =>
StaticInvoke(
classOf[UTF8String],
StringType,
"fromString",
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
returnNullable = false)
createSerializerForString(
Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

case other =>
val properties = getJavaBeanReadableAndWritableProperties(other)
val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
val fields = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
val fieldValue = Invoke(
inputObject,
p.getReadMethod.getName,
inferExternalType(fieldType.getRawType))
expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
})

val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
(fieldName, serializerFor(fieldValue, fieldType))
}
createSerializerForObject(inputObject, fields)
}
}
}
Expand Down
Loading