diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 912744eab6a3a..aec821aca86d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -128,6 +128,15 @@ object ScalaReflection extends ScalaReflection { case _ => false } + /** Whether `tpe` is a value class; `false` for non-class symbols (e.g. type parameters). */ + def isValueClass(tpe: `Type`): Boolean = + tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isDerivedValueClass + + /** Returns the name and type of the underlying parameter of value class `tpe`. */ + def getUnderlyingParameterOf(tpe: `Type`): (String, Type) = { + getConstructorParameters(tpe).head + } + /** * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff * and lost the required data type, which may lead to runtime error if the real type doesn't @@ -165,7 +174,7 @@ object ScalaReflection extends ScalaReflection { val input = upCastToExpectedType( GetColumnByOrdinal(0, dataType), dataType, walkedTypePath) - val expr = deserializerFor(tpe, input, walkedTypePath) + val expr = deserializerFor(tpe, input, walkedTypePath, instantiateValueClass = true) if (nullable) { expr } else { @@ -180,11 +189,16 @@ object ScalaReflection extends ScalaReflection { * @param tpe The `Type` of deserialized object. * @param path The expression which can be used to extract serialized value. * @param walkedTypePath The paths from top to bottom to access current field when deserializing. + * @param instantiateValueClass If `true`, create an instance for Scala value class. + * This is needed in case value class is top-level or it is + * the type of collection elements. Please refer to the comment in + * value class case for more details. 
*/ private def deserializerFor( tpe: `Type`, path: Expression, - walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects { + walkedTypePath: Seq[String], + instantiateValueClass: Boolean = false): Expression = cleanUpReflectionObjects { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = { @@ -288,7 +302,8 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. - val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, casted, newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath, + instantiateValueClass = true) if (elementNullable) { converter } else { @@ -299,7 +314,7 @@ object ScalaReflection extends ScalaReflection { val arrayData = UnresolvedMapObjects(mapFunction, path) val arrayCls = arrayClassFor(elementType) - if (elementNullable) { + if (elementNullable || isValueClass(elementType)) { Invoke(arrayData, "array", arrayCls, returnNullable = false) } else { val primitiveMethod = elementType match { @@ -328,7 +343,8 @@ object ScalaReflection extends ScalaReflection { val mapFunction: Expression => Expression = element => { // upcast the array element to the data type the encoder expected. 
val casted = upCastToExpectedType(element, dataType, newTypePath) - val converter = deserializerFor(elementType, casted, newTypePath) + val converter = deserializerFor(elementType, casted, newTypePath, + instantiateValueClass = true) if (elementNullable) { converter } else { @@ -351,8 +367,8 @@ object ScalaReflection extends ScalaReflection { UnresolvedCatalystToExternalMap( path, - p => deserializerFor(keyType, p, walkedTypePath), - p => deserializerFor(valueType, p, walkedTypePath), + p => deserializerFor(keyType, p, walkedTypePath, instantiateValueClass = true), + p => deserializerFor(valueType, p, walkedTypePath, instantiateValueClass = true), mirror.runtimeClass(t.typeSymbol.asClass) ) @@ -373,6 +389,29 @@ object ScalaReflection extends ScalaReflection { dataType = ObjectType(udt.getClass)) Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil) + case t if isValueClass(t) => + val (_, underlyingType) = getUnderlyingParameterOf(t) + val underlyingClsName = getClassNameFromType(underlyingType) + val clsName = getUnerasedClassNameFromType(t) + val newTypePath = s"""- Scala value class: $clsName($underlyingClsName)""" +: + walkedTypePath + + // Nested value class is treated as its underlying type + // because the compiler will convert value class in the schema to + // its underlying type. + // However, for value class that is top-level or collection element or + // if it is used as another type (e.g. as its parent trait or generic), + // the compiler keeps the class so we must provide an instance of the + // class too. In other cases, the compiler will handle wrapping/unwrapping + // for us automatically. 
+ val arg = deserializerFor(underlyingType, path, newTypePath) + if (instantiateValueClass) { + val cls = getClassFromType(t) + NewInstance(cls, Seq(arg), ObjectType(cls), propagateNull = false) + } else { + arg + } + case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) @@ -617,6 +656,14 @@ object ScalaReflection extends ScalaReflection { dataType = ObjectType(udt.getClass)) Invoke(obj, "serialize", udt, inputObject :: Nil) + case t if isValueClass(t) => + val (name, underlyingType) = getUnderlyingParameterOf(t) + val underlyingClsName = getClassNameFromType(underlyingType) + val clsName = getUnerasedClassNameFromType(t) + val newPath = s"""- Scala value class: $clsName($underlyingClsName)""" +: walkedTypePath + val getArg = Invoke(inputObject, name, dataTypeFor(underlyingType)) + serializerFor(getArg, underlyingType, newPath) + case t if definedByConstructorParams(t) => if (seenTypeSet.contains(t)) { throw new UnsupportedOperationException( @@ -630,13 +677,21 @@ object ScalaReflection extends ScalaReflection { "cannot be used as field name\n" + walkedTypePath.mkString("\n")) } + // as a field, value class is represented by its underlying type + val trueFieldType = if (isValueClass(fieldType)) { + val (_, underlyingType) = getUnderlyingParameterOf(fieldType) + underlyingType + } else { + fieldType + } + val fieldValue = Invoke( - AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType), - returnNullable = !fieldType.typeSymbol.asClass.isPrimitive) - val clsName = getClassNameFromType(fieldType) + AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(trueFieldType), + returnNullable = !trueFieldType.typeSymbol.asClass.isPrimitive) + val clsName = getClassNameFromType(trueFieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath expressions.Literal(fieldName) :: - serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil + serializerFor(fieldValue, 
trueFieldType, newPath, seenTypeSet + t) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) @@ -773,6 +828,9 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if isValueClass(t) => + val (_, underlyingType) = getUnderlyingParameterOf(t) + schemaFor(underlyingType) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) Schema(StructType( @@ -930,6 +988,13 @@ trait ScalaReflection extends Logging { tpe.dealias.erasure.typeSymbol.asClass.fullName } + /** + * Same as `getClassNameFromType` but returns the class name before erasure. + */ + def getUnerasedClassNameFromType(tpe: `Type`): String = { + tpe.dealias.typeSymbol.asClass.fullName + } + /** * Returns the nullability of the input parameter types of the scala function object. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d98589db323cc..54f5d2379b3cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -109,9 +109,20 @@ object TestingUDT { } } +object TestingValueClass { + class IntWrapper(val i: Int) extends AnyVal + case class StrWrapper(s: String) extends AnyVal + + case class ValueClassData( + intField: Int, + wrappedInt: IntWrapper, // an int column + strField: String, + wrappedStr: StrWrapper) // a string column +} class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + import TestingValueClass._ // A helper method used to test `ScalaReflection.serializerForType`. private def serializerFor[T: TypeTag]: Expression = @@ -362,4 +373,34 @@ class ScalaReflectionSuite extends SparkFunSuite { assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1) assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0) } + + test("schema for case class that is a value class") { + val schema = schemaFor[IntWrapper] + assert(schema === Schema(IntegerType, nullable = false)) + } + + test("schema for case class that contains value class fields") { + val schema = schemaFor[ValueClassData] + assert(schema === Schema( + StructType(Seq( + StructField("intField", IntegerType, nullable = false), + StructField("wrappedInt", IntegerType, nullable = false), + StructField("strField", StringType, nullable = true), + StructField("wrappedStr", StringType, nullable = true))), + nullable = true)) + } + + test("schema for array of value class") { + val schema = schemaFor[Array[IntWrapper]] + assert(schema === Schema( + ArrayType(IntegerType, containsNull = false), + nullable = 
true)) + } + + test("schema for map of value class") { + val schema = schemaFor[Map[IntWrapper, StrWrapper]] + assert(schema === Schema( + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index e9b100b3b30db..5f9278e036324 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -112,6 +112,16 @@ object ReferenceValueClass { case class Container(data: Int) } +case class StringWrapper(s: String) extends AnyVal +case class ValueContainer( + a: Int, + b: StringWrapper) // a string column +class IntWrapper(val i: Int) extends AnyVal // child column doesn't need to be case class +case class ComplexValueClassContainer( + a: Int, + b: ValueContainer, + c: IntWrapper) // an int column + class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -297,11 +307,28 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + // test for Scala value class encodeDecodeTest( PrimitiveValueClass(42), "primitive value class") - encodeDecodeTest( ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class") + encodeDecodeTest(StringWrapper("a"), "string value class") + encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class") + encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null") + encodeDecodeTest( + ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")), new IntWrapper(3)), + "complex value class") + encodeDecodeTest( + Array(new IntWrapper(1), 
new IntWrapper(2), new IntWrapper(3)), + "array of value class") + encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class") + encodeDecodeTest( + Seq(new IntWrapper(1), new IntWrapper(2), new IntWrapper(3)), + "seq of value class") + encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class") + encodeDecodeTest( + Map(new IntWrapper(1) -> StringWrapper("a"), new IntWrapper(2) -> StringWrapper("b")), + "map with value class") encodeDecodeTest(Option(31), "option of int") encodeDecodeTest(Option.empty[Int], "empty option of int")