Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ object ScalaReflection extends ScalaReflection {
case _ => false
}

def isValueClass(tpe: `Type`): Boolean = {
tpe.typeSymbol.asClass.isDerivedValueClass
}

/** Returns the name and type of the underlying parameter of value class `tpe`. */
def getUnderlyingParameterOf(tpe: `Type`): (String, Type) = {
getConstructorParameters(tpe).head
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a more official way to get the value class field name?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, I can't find any

}

/**
* When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff
* and lost the required data type, which may lead to runtime error if the real type doesn't
Expand Down Expand Up @@ -165,7 +174,7 @@ object ScalaReflection extends ScalaReflection {
val input = upCastToExpectedType(
GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)

val expr = deserializerFor(tpe, input, walkedTypePath)
val expr = deserializerFor(tpe, input, walkedTypePath, instantiateValueClass = true)
if (nullable) {
expr
} else {
Expand All @@ -180,11 +189,16 @@ object ScalaReflection extends ScalaReflection {
* @param tpe The `Type` of deserialized object.
* @param path The expression which can be used to extract serialized value.
* @param walkedTypePath The paths from top to bottom to access current field when deserializing.
* @param instantiateValueClass If `true`, create an instance for Scala value class.
* This is needed in case value class is top-level or it is
* the type of collection elements. Please refer to the comment in
* value class case for more details.
*/
private def deserializerFor(
tpe: `Type`,
path: Expression,
walkedTypePath: Seq[String]): Expression = cleanUpReflectionObjects {
walkedTypePath: Seq[String],
instantiateValueClass: Boolean = false): Expression = cleanUpReflectionObjects {

/** Returns the current path with a sub-field extracted. */
def addToPath(part: String, dataType: DataType, walkedTypePath: Seq[String]): Expression = {
Expand Down Expand Up @@ -288,7 +302,8 @@ object ScalaReflection extends ScalaReflection {
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
val casted = upCastToExpectedType(element, dataType, newTypePath)
val converter = deserializerFor(elementType, casted, newTypePath)
val converter = deserializerFor(elementType, casted, newTypePath,
instantiateValueClass = true)
if (elementNullable) {
converter
} else {
Expand All @@ -299,7 +314,7 @@ object ScalaReflection extends ScalaReflection {
val arrayData = UnresolvedMapObjects(mapFunction, path)
val arrayCls = arrayClassFor(elementType)

if (elementNullable) {
if (elementNullable || isValueClass(elementType)) {
Invoke(arrayData, "array", arrayCls, returnNullable = false)
} else {
val primitiveMethod = elementType match {
Expand Down Expand Up @@ -328,7 +343,8 @@ object ScalaReflection extends ScalaReflection {
val mapFunction: Expression => Expression = element => {
// upcast the array element to the data type the encoder expected.
val casted = upCastToExpectedType(element, dataType, newTypePath)
val converter = deserializerFor(elementType, casted, newTypePath)
val converter = deserializerFor(elementType, casted, newTypePath,
instantiateValueClass = true)
if (elementNullable) {
converter
} else {
Expand All @@ -351,8 +367,8 @@ object ScalaReflection extends ScalaReflection {

UnresolvedCatalystToExternalMap(
path,
p => deserializerFor(keyType, p, walkedTypePath),
p => deserializerFor(valueType, p, walkedTypePath),
p => deserializerFor(keyType, p, walkedTypePath, instantiateValueClass = true),
p => deserializerFor(valueType, p, walkedTypePath, instantiateValueClass = true),
mirror.runtimeClass(t.typeSymbol.asClass)
)

Expand All @@ -373,6 +389,29 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.getClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)

case t if isValueClass(t) =>
val (_, underlyingType) = getUnderlyingParameterOf(t)
val underlyingClsName = getClassNameFromType(underlyingType)
val clsName = getUnerasedClassNameFromType(t)
val newTypePath = s"""- Scala value class: $clsName($underlyingClsName)""" +:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you have an example of this message?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, here is the message when running with value class Id above
- Scala value class: org.apache.spark.sql.catalyst.encoders.Id(scala.Int)

walkedTypePath

// Nested value class is treated as its underlying type
// because the compiler will convert value class in the schema to
// its underlying type.
// However, for value class that is top-level or collection element or
// if it is used as another type (e.g. as its parent trait or generic),
// the compiler keeps the class so we must provide an instance of the
// class too. In other cases, the compiler will handle wrapping/unwrapping
// for us automatically.
val arg = deserializerFor(underlyingType, path, newTypePath)
if (instantiateValueClass) {
val cls = getClassFromType(t)
NewInstance(cls, Seq(arg), ObjectType(cls), propagateNull = false)
} else {
arg
}

case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)

Expand Down Expand Up @@ -617,6 +656,14 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.getClass))
Invoke(obj, "serialize", udt, inputObject :: Nil)

case t if isValueClass(t) =>
val (name, underlyingType) = getUnderlyingParameterOf(t)
val underlyingClsName = getClassNameFromType(underlyingType)
val clsName = getUnerasedClassNameFromType(t)
val newPath = s"""- Scala value class: $clsName($underlyingClsName)""" +: walkedTypePath
val getArg = Invoke(inputObject, name, dataTypeFor(underlyingType))
serializerFor(getArg, underlyingType, newPath)

case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
throw new UnsupportedOperationException(
Expand All @@ -630,13 +677,21 @@ object ScalaReflection extends ScalaReflection {
"cannot be used as field name\n" + walkedTypePath.mkString("\n"))
}

// as a field, value class is represented by its underlying type
val trueFieldType = if (isValueClass(fieldType)) {
val (_, underlyingType) = getUnderlyingParameterOf(fieldType)
underlyingType
} else {
fieldType
}

val fieldValue = Invoke(
AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(fieldType),
returnNullable = !fieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(fieldType)
AssertNotNull(inputObject, walkedTypePath), fieldName, dataTypeFor(trueFieldType),
returnNullable = !trueFieldType.typeSymbol.asClass.isPrimitive)
val clsName = getClassNameFromType(trueFieldType)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need such special handling? There is new serialization handling for value class added above, can't we simple get the object of value class here and let recursively call of serializerFor to handle it?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried moving the special logic to the value class case but have a concern I don't know how to resolve yet. I need to change dataTypeFor to return ObjectType for top level value class and dataTypeFor(underlyingType) otherwise (see my comment). I'm going with something like this:

private def dataTypeFor(tpe: `Type`, isTopLevelValueClass: Boolean = true)

but this isn't right because:

  • the default value true doesn't make sense for other types
  • if default is false or there is no default value, many places that call this method need to be changed
  • it also feels clunky because dataTypeFor now has to be aware of the context of its parameter

Do you have any suggestion on this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining! In order to cover both value class instantiated and not instantiated cases, I think we may need this special handling.

val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
expressions.Literal(fieldName) ::
serializerFor(fieldValue, fieldType, newPath, seenTypeSet + t) :: Nil
serializerFor(fieldValue, trueFieldType, newPath, seenTypeSet + t) :: Nil
})
val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
Expand Down Expand Up @@ -773,6 +828,9 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
case t if isValueClass(t) =>
val (_, underlyingType) = getUnderlyingParameterOf(t)
schemaFor(underlyingType)
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
Schema(StructType(
Expand Down Expand Up @@ -930,6 +988,13 @@ trait ScalaReflection extends Logging {
tpe.dealias.erasure.typeSymbol.asClass.fullName
}

/**
* Same as `getClassNameFromType` but returns the class name before erasure.
*/
def getUnerasedClassNameFromType(tpe: `Type`): String = {
tpe.dealias.typeSymbol.asClass.fullName
}

/**
* Returns the nullability of the input parameter types of the scala function object.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,20 @@ object TestingUDT {
}
}

object TestingValueClass {
class IntWrapper(val i: Int) extends AnyVal
case class StrWrapper(s: String) extends AnyVal

case class ValueClassData(
intField: Int,
wrappedInt: IntWrapper, // an int column
strField: String,
wrappedStr: StrWrapper) // a string column
}

class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
import TestingValueClass._

// A helper method used to test `ScalaReflection.serializerForType`.
private def serializerFor[T: TypeTag]: Expression =
Expand Down Expand Up @@ -362,4 +373,34 @@ class ScalaReflectionSuite extends SparkFunSuite {
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
}

test("schema for case class that is a value class") {
val schema = schemaFor[IntWrapper]
assert(schema === Schema(IntegerType, nullable = false))
}

test("schema for case class that contains value class fields") {
val schema = schemaFor[ValueClassData]
assert(schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = false),
StructField("wrappedInt", IntegerType, nullable = false),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to confirm, scala value class for primitive type can't be null?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yup, value class in general cannot be null since it is a subtype of AnyVal

StructField("strField", StringType, nullable = true),
StructField("wrappedStr", StringType, nullable = true))),
nullable = true))
}

test("schema for array of value class") {
val schema = schemaFor[Array[IntWrapper]]
assert(schema === Schema(
ArrayType(IntegerType, containsNull = false),
nullable = true))
}

test("schema for map of value class") {
val schema = schemaFor[Map[IntWrapper, StrWrapper]]
assert(schema === Schema(
MapType(IntegerType, StringType, valueContainsNull = true),
nullable = true))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,16 @@ object ReferenceValueClass {
case class Container(data: Int)
}

case class StringWrapper(s: String) extends AnyVal
case class ValueContainer(
a: Int,
b: StringWrapper) // a string column
class IntWrapper(val i: Int) extends AnyVal // child column doesn't need to be case class
case class ComplexValueClassContainer(
a: Int,
b: ValueContainer,
c: IntWrapper) // an int column

class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)

Expand Down Expand Up @@ -297,11 +307,28 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}

// test for Scala value class
encodeDecodeTest(
PrimitiveValueClass(42), "primitive value class")

encodeDecodeTest(
ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
encodeDecodeTest(StringWrapper("a"), "string value class")
encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
encodeDecodeTest(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also test with null values?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure, I added a test with StringWrapper(null)

ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")), new IntWrapper(3)),
"complex value class")
encodeDecodeTest(
Array(new IntWrapper(1), new IntWrapper(2), new IntWrapper(3)),
"array of value class")
encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
encodeDecodeTest(
Seq(new IntWrapper(1), new IntWrapper(2), new IntWrapper(3)),
"seq of value class")
encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
encodeDecodeTest(
Map(new IntWrapper(1) -> StringWrapper("a"), new IntWrapper(2) -> StringWrapper("b")),
"map with value class")

encodeDecodeTest(Option(31), "option of int")
encodeDecodeTest(Option.empty[Int], "empty option of int")
Expand Down