[SPARK-27001][SQL] Refactor "serializerFor" method between ScalaReflection and JavaTypeInference #23908
Changes from 9 commits
DeserializerBuildHelper.scala
@@ -29,7 +29,7 @@ object DeserializerBuildHelper {
       path: Expression,
       part: String,
       dataType: DataType,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     val newPath = UnresolvedExtractValue(path, expressions.Literal(part))
     upCastToExpectedType(newPath, dataType, walkedTypePath)
   }
@@ -39,40 +39,30 @@ object DeserializerBuildHelper {
       path: Expression,
       ordinal: Int,
       dataType: DataType,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     val newPath = GetStructField(path, ordinal)
     upCastToExpectedType(newPath, dataType, walkedTypePath)
   }

-  def deserializerForWithNullSafety(
-      expr: Expression,
-      dataType: DataType,
-      nullable: Boolean,
-      walkedTypePath: Seq[String],
-      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
-    val newExpr = funcForCreatingNewExpr(expr, walkedTypePath)
-    expressionWithNullSafety(newExpr, nullable, walkedTypePath)
-  }
-
   def deserializerForWithNullSafetyAndUpcast(
       expr: Expression,
       dataType: DataType,
       nullable: Boolean,
-      walkedTypePath: Seq[String],
-      funcForCreatingNewExpr: (Expression, Seq[String]) => Expression): Expression = {
+      walkedTypePath: WalkedTypePath,
+      funcForCreatingDeserializer: (Expression, WalkedTypePath) => Expression): Expression = {
     val casted = upCastToExpectedType(expr, dataType, walkedTypePath)
-    deserializerForWithNullSafety(casted, dataType, nullable, walkedTypePath,
-      funcForCreatingNewExpr)
+    expressionWithNullSafety(funcForCreatingDeserializer(casted, walkedTypePath),
+      nullable, walkedTypePath)
   }

-  private def expressionWithNullSafety(
+  def expressionWithNullSafety(
       expr: Expression,
       nullable: Boolean,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     if (nullable) {
       expr
     } else {
-      AssertNotNull(expr, walkedTypePath)
+      AssertNotNull(expr, walkedTypePath.copy())
     }
   }
@@ -167,10 +157,10 @@ object DeserializerBuildHelper {
   private def upCastToExpectedType(
       expr: Expression,
       expected: DataType,
-      walkedTypePath: Seq[String]): Expression = expected match {
+      walkedTypePath: WalkedTypePath): Expression = expected match {
     case _: StructType => expr
     case _: ArrayType => expr
     case _: MapType => expr
-    case _ => UpCast(expr, expected, walkedTypePath)
+    case _ => UpCast(expr, expected, walkedTypePath.copy())
   }
 }
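Note on WalkedTypePath: the class itself is added in a separate file that is not part of the hunks shown in this view, so the following is only a sketch inferred from the call sites above (recordRoot/recordArray/recordMap/recordField append human-readable path entries, and copy() snapshots the path before it is captured by AssertNotNull or UpCast); the real implementation may differ.

// Sketch only, inferred from usage in this diff, not the actual Spark source.
class WalkedTypePath(private var paths: Seq[String] = Nil) {
  def recordRoot(className: String): Unit =
    record(s"""- root class: "$className"""")
  def recordArray(elementClassName: String): Unit =
    record(s"""- array element class: "$elementClassName"""")
  def recordMap(keyClassName: String, valueClassName: String): Unit =
    record(s"""- map key class: "$keyClassName", value class: "$valueClassName"""")
  def recordField(className: String, fieldName: String): Unit =
    record(s"""- field (class: "$className", name: "$fieldName")""")

  // Snapshot the current path so that later record* calls on this (mutable)
  // instance do not leak into expressions that already captured it.
  def copy(): WalkedTypePath = new WalkedTypePath(paths)

  override def toString: String = paths.mkString("\n")
  private def record(entry: String): Unit = paths = entry +: paths
}

object WalkedTypePath {
  def apply(): WalkedTypePath = new WalkedTypePath()
}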
JavaTypeInference.scala
@@ -27,12 +27,12 @@ import scala.language.existentials
 import com.google.common.reflect.TypeToken

 import org.apache.spark.sql.catalyst.DeserializerBuildHelper._
+import org.apache.spark.sql.catalyst.SerializerBuildHelper._
 import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String

 /**
  * Type-inference utilities for POJOs and Java collections.
@@ -195,7 +195,8 @@ object JavaTypeInference {
    */
   def deserializerFor(beanClass: Class[_]): Expression = {
     val typeToken = TypeToken.of(beanClass)
-    val walkedTypePath = s"""- root class: "${beanClass.getCanonicalName}"""" :: Nil
+    val walkedTypePath = WalkedTypePath()
+    walkedTypePath.recordRoot(beanClass.getCanonicalName)
     val (dataType, nullable) = inferDataType(typeToken)

     // Assumes we are deserializing the first column of a row.
@@ -208,7 +209,7 @@ object JavaTypeInference {
   private def deserializerFor(
       typeToken: TypeToken[_],
       path: Expression,
-      walkedTypePath: Seq[String]): Expression = {
+      walkedTypePath: WalkedTypePath): Expression = {
     typeToken.getRawType match {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => path
@@ -244,16 +245,15 @@ object JavaTypeInference {

       case c if c.isArray =>
         val elementType = c.getComponentType
-        val newTypePath = s"""- array element class: "${elementType.getCanonicalName}"""" +:
-          walkedTypePath
+        walkedTypePath.recordArray(elementType.getCanonicalName)
         val (dataType, elementNullable) = inferDataType(elementType)
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
           deserializerForWithNullSafetyAndUpcast(
             element,
             dataType,
             nullable = elementNullable,
-            newTypePath,
+            walkedTypePath,
             (casted, typePath) => deserializerFor(typeToken.getComponentType, casted, typePath))
         }
@@ -274,38 +274,40 @@ object JavaTypeInference {

       case c if listType.isAssignableFrom(typeToken) =>
         val et = elementType(typeToken)
-        val newTypePath = s"""- array element class: "${et.getType.getTypeName}"""" +:
-          walkedTypePath
+        walkedTypePath.recordArray(et.getType.getTypeName)
         val (dataType, elementNullable) = inferDataType(et)
         val mapFunction: Expression => Expression = element => {
           // upcast the array element to the data type the encoder expected.
           deserializerForWithNullSafetyAndUpcast(
             element,
             dataType,
             nullable = elementNullable,
-            newTypePath,
+            walkedTypePath,
             (casted, typePath) => deserializerFor(et, casted, typePath))
         }

         UnresolvedMapObjects(mapFunction, path, customCollectionCls = Some(c))

       case _ if mapType.isAssignableFrom(typeToken) =>
         val (keyType, valueType) = mapKeyValueType(typeToken)
-        val newTypePath = (s"""- map key class: "${keyType.getType.getTypeName}"""" +
-          s""", value class: "${valueType.getType.getTypeName}"""") +: walkedTypePath
+        walkedTypePath.recordMap(keyType.getType.getTypeName,
+          valueType.getType.getTypeName)
+
+        val newTypePathForKey = walkedTypePath.copy()
+        val newTypePathForValue = walkedTypePath.copy()

         val keyData =
           Invoke(
             UnresolvedMapObjects(
-              p => deserializerFor(keyType, p, newTypePath),
+              p => deserializerFor(keyType, p, newTypePathForKey),
               MapKeys(path)),
             "array",
             ObjectType(classOf[Array[Any]]))

         val valueData =
           Invoke(
             UnresolvedMapObjects(
-              p => deserializerFor(valueType, p, newTypePath),
+              p => deserializerFor(valueType, p, newTypePathForValue),
               MapValues(path)),
             "array",
             ObjectType(classOf[Array[Any]]))
@@ -328,15 +330,13 @@ object JavaTypeInference {
           val fieldName = p.getName
           val fieldType = typeToken.method(p.getReadMethod).getReturnType
           val (dataType, nullable) = inferDataType(fieldType)
-          val newTypePath = (s"""- field (class: "${fieldType.getType.getTypeName}"""" +
-            s""", name: "$fieldName")""") +: walkedTypePath
-          val setter = deserializerForWithNullSafety(
-            path,
-            dataType,
+          val newTypePathForField = walkedTypePath.copy()
+          newTypePathForField.recordField(fieldType.getType.getTypeName, fieldName)
+          val setter = expressionWithNullSafety(
+            deserializerFor(fieldType, addToPath(path, fieldName, dataType, newTypePathForField),
+              newTypePathForField),
             nullable = nullable,
-            newTypePath,
-            (expr, typePath) => deserializerFor(fieldType,
-              addToPath(expr, fieldName, dataType, typePath), typePath))
+            newTypePathForField)
           p.getWriteMethod.getName -> setter
         }.toMap
@@ -358,7 +358,10 @@ object JavaTypeInference {
    */
   def serializerFor(beanClass: Class[_]): Expression = {
     val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true)
-    val nullSafeInput = AssertNotNull(inputObject, Seq("top level input bean"))
+    val walkedTypePath = WalkedTypePath()
+    walkedTypePath.recordRoot("top level input bean")
+    // not copying walkedTypePath since the instance will be only used here
+    val nullSafeInput = AssertNotNull(inputObject, walkedTypePath)
     serializerFor(nullSafeInput, TypeToken.of(beanClass))
   }
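As a reminder of where these entry points are exercised, here is a minimal illustration rather than anything from this PR; the Point bean is hypothetical, and Encoders.bean is the public API that internally drives JavaTypeInference.serializerFor / deserializerFor.

import scala.beans.BeanProperty
import org.apache.spark.sql.Encoders

// Hypothetical JavaBean used only for illustration.
class Point(@BeanProperty var x: Int, @BeanProperty var name: String) {
  def this() = this(0, "")
}

// Encoders.bean introspects the getters/setters and builds the serializer and
// deserializer expressions through the code paths refactored in this diff.
val pointEncoder = Encoders.bean(classOf[Point])
println(pointEncoder.schema)  // schema inferred from the bean's getters, e.g. name: string, x: int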
@@ -367,73 +370,37 @@ object JavaTypeInference {
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
       val (dataType, nullable) = inferDataType(elementType)
       if (ScalaReflection.isNativeType(dataType)) {
-        NewInstance(
-          classOf[GenericArrayData],
-          input :: Nil,
-          dataType = ArrayType(dataType, nullable))
+        createSerializerForGenericArray(input, dataType, nullable = nullable)
       } else {
-        MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType))
+        createSerializerForMapObjects(input, ObjectType(elementType.getRawType),
+          serializerFor(_, elementType))
       }
     }

     if (!inputObject.dataType.isInstanceOf[ObjectType]) {
       inputObject
     } else {
       typeToken.getRawType match {
-        case c if c == classOf[String] =>
-          StaticInvoke(
-            classOf[UTF8String],
-            StringType,
-            "fromString",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.sql.Timestamp] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            TimestampType,
-            "fromJavaTimestamp",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.time.LocalDate] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            DateType,
-            "localDateToDays",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.sql.Date] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            DateType,
-            "fromJavaDate",
-            inputObject :: Nil,
-            returnNullable = false)
+        case c if c == classOf[String] => createSerializerForString(inputObject)
+
+        case c if c == classOf[java.time.Instant] => createSerializerForJavaInstant(inputObject)
+
+        case c if c == classOf[java.sql.Timestamp] => createSerializerForSqlTimestamp(inputObject)
+
+        case c if c == classOf[java.time.LocalDate] => createSerializerForJavaLocalDate(inputObject)
+
+        case c if c == classOf[java.sql.Date] => createSerializerForSqlDate(inputObject)

         case c if c == classOf[java.math.BigDecimal] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.SYSTEM_DEFAULT,
-            "apply",
-            inputObject :: Nil,
-            returnNullable = false)
-
-        case c if c == classOf[java.lang.Boolean] =>
-          Invoke(inputObject, "booleanValue", BooleanType)
-        case c if c == classOf[java.lang.Byte] =>
-          Invoke(inputObject, "byteValue", ByteType)
-        case c if c == classOf[java.lang.Short] =>
-          Invoke(inputObject, "shortValue", ShortType)
-        case c if c == classOf[java.lang.Integer] =>
-          Invoke(inputObject, "intValue", IntegerType)
-        case c if c == classOf[java.lang.Long] =>
-          Invoke(inputObject, "longValue", LongType)
-        case c if c == classOf[java.lang.Float] =>
-          Invoke(inputObject, "floatValue", FloatType)
-        case c if c == classOf[java.lang.Double] =>
-          Invoke(inputObject, "doubleValue", DoubleType)
+          createSerializerForJavaBigDecimal(inputObject)
+
+        case c if c == classOf[java.lang.Boolean] => createSerializerForBoolean(inputObject)
+        case c if c == classOf[java.lang.Byte] => createSerializerForByte(inputObject)
+        case c if c == classOf[java.lang.Short] => createSerializerForShort(inputObject)
+        case c if c == classOf[java.lang.Integer] => createSerializerForInteger(inputObject)
+        case c if c == classOf[java.lang.Long] => createSerializerForLong(inputObject)
+        case c if c == classOf[java.lang.Float] => createSerializerForFloat(inputObject)
+        case c if c == classOf[java.lang.Double] => createSerializerForDouble(inputObject)

         case _ if typeToken.isArray =>
           toCatalystArray(inputObject, typeToken.getComponentType)
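The createSerializerFor* helpers referenced above now live in SerializerBuildHelper, which is not part of the hunks shown in this view; presumably each helper wraps the same catalyst expression the old inline code constructed. A sketch of two of them, mirroring the removed inline code, not the actual helper source:

import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, StaticInvoke}
import org.apache.spark.sql.types.{BooleanType, StringType}
import org.apache.spark.unsafe.types.UTF8String

// Converts an external String into a Catalyst UTF8String expression.
def createSerializerForString(inputObject: Expression): Expression =
  StaticInvoke(
    classOf[UTF8String],
    StringType,
    "fromString",
    inputObject :: Nil,
    returnNullable = false)

// Unboxes a java.lang.Boolean into a Catalyst boolean expression.
def createSerializerForBoolean(inputObject: Expression): Expression =
  Invoke(inputObject, "booleanValue", BooleanType)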
@@ -444,38 +411,34 @@ object JavaTypeInference {
         case _ if mapType.isAssignableFrom(typeToken) =>
           val (keyType, valueType) = mapKeyValueType(typeToken)

-          ExternalMapToCatalyst(
+          createSerializerForMap(
             inputObject,
-            ObjectType(keyType.getRawType),
-            serializerFor(_, keyType),
-            keyNullable = true,
-            ObjectType(valueType.getRawType),
-            serializerFor(_, valueType),
-            valueNullable = true
+            MapElementInformation(
+              ObjectType(keyType.getRawType),
+              nullable = true,
+              serializerFor(_, keyType)),
+            MapElementInformation(
+              ObjectType(valueType.getRawType),
+              nullable = true,
+              serializerFor(_, valueType))
           )

         case other if other.isEnum =>
-          StaticInvoke(
-            classOf[UTF8String],
-            StringType,
-            "fromString",
-            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false) :: Nil,
-            returnNullable = false)
+          createSerializerForString(
+            Invoke(inputObject, "name", ObjectType(classOf[String]), returnNullable = false))

         case other =>
           val properties = getJavaBeanReadableAndWritableProperties(other)
-          val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>
+          val fields = properties.map { p =>
             val fieldName = p.getName
             val fieldType = typeToken.method(p.getReadMethod).getReturnType
             val fieldValue = Invoke(
               inputObject,
               p.getReadMethod.getName,
               inferExternalType(fieldType.getRawType))
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
-          })
-
-          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+            (fieldName, serializerFor(fieldValue, fieldType))
+          }
+          createSerializerForObject(inputObject, fields)
       }
     }
   }
Review comment: We can let AssertNotNull take a Seq[String], to force us to copy the WalkedTypePath when creating AssertNotNull.
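A sketch of what that suggestion would look like at the call site, not code from this PR; getPaths is a hypothetical accessor on WalkedTypePath returning the recorded entries as a Seq[String]:

// If AssertNotNull keeps accepting Seq[String], every caller has to
// materialize the current path, which snapshots it, instead of having to
// remember to call walkedTypePath.copy() before handing it over.
def expressionWithNullSafety(
    expr: Expression,
    nullable: Boolean,
    walkedTypePath: WalkedTypePath): Expression = {
  if (nullable) {
    expr
  } else {
    AssertNotNull(expr, walkedTypePath.getPaths)  // hypothetical getPaths: Seq[String]
  }
}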