-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-20384][SQL] Support value class in schema of Dataset #22309
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
1c78833
7fa99a7
bc847d2
41ffac1
43fbf82
4675f11
05fa807
79be9e7
72c5963
3ddd46a
e43d62e
9c2dfdb
83552d2
ccb7927
24d5858
e8a8cba
893b6fc
c31f7e7
d2f35a9
d70cac5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -128,6 +128,15 @@ object ScalaReflection extends ScalaReflection { | |
| case _ => false | ||
| } | ||
|
|
||
| /** | ||
| * Returns true if `tpe` is a user-defined Scala value class, i.e. a class extending `AnyVal`. | ||
| * | ||
| * `asClass` throws when the type symbol is not a class symbol, so `isClass` is checked | ||
| * first; such types are reported as not being value classes instead of failing. | ||
| */ | ||
| def isValueClass(tpe: `Type`): Boolean = { | ||
| tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isDerivedValueClass | ||
| } | ||
|
|
||
| /** Returns the name and type of the underlying parameter of value class `tpe`. */ | ||
| def getUnderlyingParameterOf(tpe: `Type`): (String, Type) = { | ||
| getConstructorParameters(tpe).head | ||
| } | ||
|
|
||
| /** | ||
| * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff | ||
| * and lost the required data type, which may lead to runtime error if the real type doesn't | ||
|
|
@@ -165,7 +174,7 @@ object ScalaReflection extends ScalaReflection { | |
| val input = upCastToExpectedType( | ||
| GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) | ||
|
|
||
| val expr = deserializerFor(tpe, input, walkedTypePath) | ||
| val expr = deserializerFor(tpe, input, walkedTypePath, instantiateValueClass = true) | ||
| if (nullable) { | ||
| expr | ||
| } else { | ||
|
|
@@ -180,11 +189,16 @@ object ScalaReflection extends ScalaReflection { | |
| * @param tpe The `Type` of deserialized object. | ||
| * @param path The expression which can be used to extract serialized value. | ||
| * @param walkedTypePath The paths from top to bottom to access current field when deserializing. | ||
| * @param instantiateValueClass If `true`, create an instance of the Scala value class. | ||
| * This is needed when the value class is top-level or is | ||
| * the element type of a collection. Please refer to the comment in | ||
| * the value class case for more details. | ||
| */ | ||
| private def deserializerFor( | ||
| tpe: `Type`, | ||
| path: Expression, | ||
| walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { | ||
| walkedTypePath: Seq[String], | ||
| instantiateValueClass: Boolean = false): Expression = cleanUpReflectionObjects { | ||
|
|
||
| /** Returns the current path with a sub-field extracted. */ | ||
| def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { | ||
|
|
@@ -288,7 +302,8 @@ object ScalaReflection extends ScalaReflection { | |
| val mapFunction: Expression => Expression = element => { | ||
| // upcast the array element to the data type the encoder expected. | ||
| val casted = upCastToExpectedType(element, dataType, newTypePath) | ||
| val converter = deserializerFor(elementType, casted, newTypePath) | ||
| val converter = deserializerFor(elementType, casted, newTypePath, | ||
| instantiateValueClass = true) | ||
| if (elementNullable) { | ||
| converter | ||
| } else { | ||
|
|
@@ -299,7 +314,7 @@ object ScalaReflection extends ScalaReflection { | |
| val arrayData = UnresolvedMapObjects(mapFunction, path) | ||
| val arrayCls = arrayClassFor(elementType) | ||
|
|
||
| if (elementNullable) { | ||
| if (elementNullable || isValueClass(elementType)) { | ||
| Invoke(arrayData, "array", arrayCls, returnNullable = false) | ||
| } else { | ||
| val primitiveMethod = elementType match { | ||
|
|
@@ -328,7 +343,8 @@ object ScalaReflection extends ScalaReflection { | |
| val mapFunction: Expression => Expression = element => { | ||
| // upcast the array element to the data type the encoder expected. | ||
| val casted = upCastToExpectedType(element, dataType, newTypePath) | ||
| val converter = deserializerFor(elementType, casted, newTypePath) | ||
| val converter = deserializerFor(elementType, casted, newTypePath, | ||
| instantiateValueClass = true) | ||
| if (elementNullable) { | ||
| converter | ||
| } else { | ||
|
|
@@ -351,8 +367,8 @@ object ScalaReflection extends ScalaReflection { | |
|
|
||
| UnresolvedCatalystToExternalMap( | ||
| path, | ||
| p => deserializerFor(keyType, p, walkedTypePath), | ||
| p => deserializerFor(valueType, p, walkedTypePath), | ||
| p => deserializerFor(keyType, p, walkedTypePath, instantiateValueClass = true), | ||
| p => deserializerFor(valueType, p, walkedTypePath, instantiateValueClass = true), | ||
| mirror.runtimeClass(t.typeSymbol.asClass) | ||
| ) | ||
|
|
||
|
|
@@ -373,6 +389,29 @@ object ScalaReflection extends ScalaReflection { | |
| dataType = ObjectType(udt.getClass)) | ||
| Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) | ||
|
|
||
| case t if isValueClass(t) => | ||
| val (_, underlyingType) = getUnderlyingParameterOf(t) | ||
| val underlyingClsName = getClassNameFromType(underlyingType) | ||
| val clsName = getUnerasedClassNameFromType(t) | ||
| val newTypePath = s"""- Scala value class: $clsName($underlyingClsName)""" +: | ||
|
||
| walkedTypePath | ||
|
|
||
| // A nested value class is treated as its underlying type | ||
| // because the compiler converts a value class in the schema to | ||
| // its underlying type. | ||
| // However, when a value class is top-level, is a collection element, or | ||
| // is used as another type (e.g. as its parent trait or a generic), | ||
| // the compiler keeps the class, so we must provide an instance of the | ||
| // class too. In the other cases, the compiler handles wrapping/unwrapping | ||
| // for us automatically. | ||
| val arg = deserializerFor(underlyingType, path, newTypePath) | ||
| if (instantiateValueClass) { | ||
| val cls = getClassFromType(t) | ||
| NewInstance(cls, Seq(arg), ObjectType(cls), propagateNull = false) | ||
| } else { | ||
| arg | ||
| } | ||
|
|
||
| case t if definedByConstructorParams(t) => | ||
| val params = getConstructorParameters(t) | ||
|
|
||
|
|
@@ -617,6 +656,14 @@ object ScalaReflection extends ScalaReflection { | |
| dataType = ObjectType(udt.getClass)) | ||
| Invoke(obj, "serialize", udt, inputObject :: Nil) | ||
|
|
||
| case t if isValueClass(t) => | ||
| val (name, underlyingType) = getUnderlyingParameterOf(t) | ||
| val underlyingClsName = getClassNameFromType(underlyingType) | ||
| val clsName = getUnerasedClassNameFromType(t) | ||
| val newPath = s"""- Scala value class: $clsName($underlyingClsName)""" +: walkedTypePath | ||
| val getArg = Invoke(inputObject, name, dataTypeFor(underlyingType)) | ||
| serializerFor(getArg, underlyingType, newPath) | ||
|
|
||
| case t if definedByConstructorParams(t) => | ||
| if (seenTypeSet.contains(t)) { | ||
| throw new UnsupportedOperationException( | ||
|
|
@@ -630,13 +677,21 @@ object ScalaReflection extends ScalaReflection { | |
| "cannot be used as field name\n" + walkedTypePath.mkString("\n")) | ||
| } | ||
|
|
||
| // as a field, value class is represented by its underlying type | ||
| val trueFieldType = if (isValueClass(fieldType)) { | ||
| val (_, underlyingType) = getUnderlyingParameterOf(fieldType) | ||
| underlyingType | ||
| } else { | ||
| fieldType | ||
| } | ||
|
|
||
| val fieldValue = Invoke( | ||
| AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType), | ||
| returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) | ||
| val clsName = getClassNameFromType(fieldType) | ||
| AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(trueFieldType), | ||
| returnNullable = !trueFieldType.typeSymbol.asClass.isPrimitive) | ||
| val clsName = getClassNameFromType(trueFieldType) | ||
|
||
| val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath | ||
| expressions.Literal(fieldName) :: | ||
| serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil | ||
| serializerFor(fieldValue, trueFieldType, newPath, seenTypeSet + t) :: Nil | ||
| }) | ||
| val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) | ||
| expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) | ||
|
|
@@ -773,6 +828,9 @@ object ScalaReflection extends ScalaReflection { | |
| case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) | ||
| case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) | ||
| case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) | ||
| case t if isValueClass(t) => | ||
| val (_, underlyingType) = getUnderlyingParameterOf(t) | ||
| schemaFor(underlyingType) | ||
| case t if definedByConstructorParams(t) => | ||
| val params = getConstructorParameters(t) | ||
| Schema(StructType( | ||
|
|
@@ -930,6 +988,13 @@ trait ScalaReflection extends Logging { | |
| tpe.dealias.erasure.typeSymbol.asClass.fullName | ||
| } | ||
|
|
||
| /** | ||
| * Same as `getClassNameFromType` but returns the class name before erasure. | ||
| */ | ||
| def getUnerasedClassNameFromType(tpe: `Type`): String = { | ||
| tpe.dealias.typeSymbol.asClass.fullName | ||
| } | ||
|
|
||
| /** | ||
| * Returns the nullability of the input parameter types of the scala function object. | ||
| * | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -109,9 +109,20 @@ object TestingUDT { | |
| } | ||
| } | ||
|
|
||
| object TestingValueClass { | ||
| class IntWrapper(val i: Int) extends AnyVal | ||
| case class StrWrapper(s: String) extends AnyVal | ||
|
|
||
| case class ValueClassData( | ||
| intField: Int, | ||
| wrappedInt: IntWrapper, // an int column | ||
| strField: String, | ||
| wrappedStr: StrWrapper) // a string column | ||
| } | ||
|
|
||
| class ScalaReflectionSuite extends SparkFunSuite { | ||
| import org.apache.spark.sql.catalyst.ScalaReflection._ | ||
| import TestingValueClass._ | ||
|
|
||
| // A helper method used to test `ScalaReflection.serializerForType`. | ||
| private def serializerFor[T: TypeTag]: Expression = | ||
|
|
@@ -362,4 +373,34 @@ class ScalaReflectionSuite extends SparkFunSuite { | |
| assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) | ||
| assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) | ||
| } | ||
|
|
||
| test("schema for case class that is a value class") { | ||
| val schema = schemaFor[IntWrapper] | ||
| assert(schema === Schema(IntegerType, nullable = false)) | ||
| } | ||
|
|
||
| test("schema for case class that contains value class fields") { | ||
| val schema = schemaFor[ValueClassData] | ||
| assert(schema === Schema( | ||
| StructType(Seq( | ||
| StructField("intField", IntegerType, nullable = false), | ||
| StructField("wrappedInt", IntegerType, nullable = false), | ||
|
||
| StructField("strField", StringType, nullable = true), | ||
| StructField("wrappedStr", StringType, nullable = true))), | ||
| nullable = true)) | ||
| } | ||
|
|
||
| test("schema for array of value class") { | ||
| val schema = schemaFor[Array[IntWrapper]] | ||
| assert(schema === Schema( | ||
| ArrayType(IntegerType, containsNull = false), | ||
| nullable = true)) | ||
| } | ||
|
|
||
| test("schema for map of value class") { | ||
| val schema = schemaFor[Map[IntWrapper, StrWrapper]] | ||
| assert(schema === Schema( | ||
| MapType(IntegerType, StringType, valueContainsNull = true), | ||
| nullable = true)) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -112,6 +112,16 @@ object ReferenceValueClass { | |
| case class Container(data: Int) | ||
| } | ||
|
|
||
| case class StringWrapper(s: String) extends AnyVal | ||
| case class ValueContainer( | ||
| a: Int, | ||
| b: StringWrapper) // a string column | ||
| class IntWrapper(val i: Int) extends AnyVal // child column doesn't need to be case class | ||
| case class ComplexValueClassContainer( | ||
| a: Int, | ||
| b: ValueContainer, | ||
| c: IntWrapper) // an int column | ||
|
|
||
| class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { | ||
| OuterScopes.addOuterScope(this) | ||
|
|
||
|
|
@@ -297,11 +307,28 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes | |
| ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) | ||
| } | ||
|
|
||
| // test for Scala value class | ||
| encodeDecodeTest( | ||
| PrimitiveValueClass(42), "primitive value class") | ||
|
|
||
| encodeDecodeTest( | ||
| ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class") | ||
| encodeDecodeTest(StringWrapper("a"), "string value class") | ||
| encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class") | ||
| encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null") | ||
| encodeDecodeTest( | ||
|
||
| ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")), new IntWrapper(3)), | ||
| "complex value class") | ||
| encodeDecodeTest( | ||
| Array(new IntWrapper(1), new IntWrapper(2), new IntWrapper(3)), | ||
| "array of value class") | ||
| encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class") | ||
| encodeDecodeTest( | ||
| Seq(new IntWrapper(1), new IntWrapper(2), new IntWrapper(3)), | ||
| "seq of value class") | ||
| encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class") | ||
| encodeDecodeTest( | ||
| Map(new IntWrapper(1) -> StringWrapper("a"), new IntWrapper(2) -> StringWrapper("b")), | ||
| "map with value class") | ||
|
|
||
| encodeDecodeTest(Option(31), "option of int") | ||
| encodeDecodeTest(Option.empty[Int], "empty option of int") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a more official way to get the value class field name?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure; I couldn't find one.