diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4de7e5caa8862..fdb475f76135a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -357,6 +357,15 @@ object ScalaReflection extends ScalaReflection {
           dataType = ObjectType(udt.getClass))
         Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)
 
+      case t if isValueClass(t) =>
+        val (_, underlyingType) = getUnderlyingParameterOf(t)
+        val underlyingClsName = getClassNameFromType(underlyingType)
+        val clsName = getUnerasedClassNameFromType(t)
+        val newTypePath = walkedTypePath.recordValueClass(clsName, underlyingClsName)
+        val arg = deserializerFor(underlyingType, path, newTypePath)
+        val cls = getClassFromType(t)
+        NewInstance(cls, Seq(arg), ObjectType(cls), propagateNull = false)
+
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
 
@@ -579,6 +588,15 @@ object ScalaReflection extends ScalaReflection {
         val udtClass = udt.getClass
         createSerializerForUserDefinedType(inputObject, udt, udtClass)
 
+      case t if isValueClass(t) =>
+        val (name, underlyingType) = getUnderlyingParameterOf(t)
+        val underlyingClsName = getClassNameFromType(underlyingType)
+        val clsName = getUnerasedClassNameFromType(t)
+        val newPath = walkedTypePath.recordValueClass(clsName, underlyingClsName)
+        val getArg = Invoke(KnownNotNull(inputObject), name, dataTypeFor(underlyingType),
+          returnNullable = !underlyingType.typeSymbol.asClass.isPrimitive)
+        serializerFor(getArg, underlyingType, newPath)
+
       case t if definedByConstructorParams(t) =>
         if (seenTypeSet.contains(t)) {
           throw QueryExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString)
@@ -787,6 +805,9 @@ object ScalaReflection extends ScalaReflection {
      case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false)
      case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
      case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
+      case t if isValueClass(t) =>
+        val (_, underlyingType) = getUnderlyingParameterOf(t)
+        schemaFor(underlyingType)
      case t if definedByConstructorParams(t) =>
        val params = getConstructorParameters(t)
        Schema(StructType(
@@ -956,6 +977,22 @@ trait ScalaReflection extends Logging {
     tag.in(mirror).tpe.dealias
   }
 
+  /**
+   * Same as `getClassNameFromType` but returns the class name before erasure.
+   */
+  def getUnerasedClassNameFromType(tpe: `Type`): String = {
+    tpe.dealias.typeSymbol.asClass.fullName
+  }
+
+  def isValueClass(tpe: `Type`): Boolean = {
+    tpe.typeSymbol.asClass.isDerivedValueClass
+  }
+
+  /** Returns the name and type of the underlying parameter of value class `tpe`. */
+  def getUnderlyingParameterOf(tpe: `Type`): (String, Type) = {
+    getConstructorParameters(tpe).head
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
@@ -968,15 +1005,27 @@ trait ScalaReflection extends Logging {
     val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
     val params = constructParams(dealiasedTpe)
     // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
-    if (actualTypeArgs.nonEmpty) {
-      params.map { p =>
+    params.map { p => {
+      if (isTypeParameter(p)) {
         p.name.decodedName.toString ->
           p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
+      } else {
+        p.name.decodedName.toString -> unwrapIfValueClassType(p.typeSignature)
       }
+    }
+    }
+  }
+
+  private def isTypeParameter(sym: Symbol): Boolean = {
+    sym.typeSignature.typeSymbol.isParameter
+  }
+
+  private def unwrapIfValueClassType(tpe: Type): Type = {
+    if (isValueClass(tpe)) {
+      val (_, underlyingType) = getUnderlyingParameterOf(tpe)
+      underlyingType
     } else {
-      params.map { p =>
-        p.name.decodedName.toString -> p.typeSignature
-      }
+      tpe
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
index cbf1f01344c92..09af3c7474404 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
@@ -46,6 +46,9 @@ case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) {
   def recordField(className: String, fieldName: String): WalkedTypePath =
     newInstance(s"""- field (class: "$className", name: "$fieldName")""")
 
+  def recordValueClass(className: String, underlyingClassName: String): WalkedTypePath =
+    newInstance(s"""- Scala value class: $className($underlyingClassName)""")
+
   override def toString: String = {
     walkedPaths.mkString("\n")
   }
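The new helpers lean on plain scala-reflect: a value class is a class deriving from AnyVal, and its single primary-constructor parameter is the representation Spark actually needs to store. A minimal standalone sketch of that mechanism (Wrapped and ValueClassReflectionDemo are illustrative names, not part of this patch; it only needs scala-reflect on the classpath):

import scala.reflect.runtime.universe._

// Illustrative value class; any single-parameter subclass of AnyVal behaves the same way.
case class Wrapped(s: String) extends AnyVal

object ValueClassReflectionDemo extends App {
  val tpe = typeOf[Wrapped]

  // This is the check behind isValueClass(t).
  println(tpe.typeSymbol.asClass.isDerivedValueClass)   // true

  // getUnderlyingParameterOf(t) boils down to reading the single primary-constructor
  // parameter: its name feeds the serializer's Invoke, its type drives the schema.
  val param = tpe.member(termNames.CONSTRUCTOR).asMethod.paramLists.head.head
  println(param.name.decodedName.toString)              // s
  println(param.typeSignature)                          // String
}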
value class") { + val schema = schemaFor[Array[IntWrapper]] + assert(schema === Schema( + ArrayType(IntegerType, containsNull = false), + nullable = true)) + } + + test("SPARK-20384: schema for map of value class") { + val schema = schemaFor[Map[IntWrapper, StrWrapper]] + assert(schema === Schema( + MapType(IntegerType, StringType, valueContainsNull = true), + nullable = true)) + } + + test("SPARK-20384: schema for tuple with value class") { + val schema = schemaFor[(IntWrapper, StrWrapper)] + assert(schema === Schema( + StructType(Seq( + StructField("_1", IntegerType, nullable = false), + StructField("_2", StringType, nullable = true))), + nullable = true)) + } + + test("SPARK-20384: schema for case class with generic field") { + val schema = schemaFor[CaseClassWithGeneric[IntWrapper]] + assert(schema === Schema( + StructType(Seq( + StructField("generic", IntegerType, nullable = false))), + nullable = true)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index bf4afac2f8be6..69c5512b3c3be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -116,6 +116,16 @@ object ReferenceValueClass { } case class IntAndString(i: Int, s: String) +case class StringWrapper(s: String) extends AnyVal +case class ValueContainer(a: Int, b: StringWrapper) +case class IntWrapper(i: Int) extends AnyVal +case class ComplexValueClassContainer(a: Int, b: ValueContainer, c: IntWrapper) +case class SeqOfValueClass(s: Seq[StringWrapper]) +case class MapOfValueClassKey(m: Map[IntWrapper, String]) +case class MapOfValueClassValue(m: Map[String, StringWrapper]) +case class OptionOfValueClassValue(o: Option[StringWrapper]) +case class CaseClassWithGeneric[T, S](t: T, s: S) + class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -397,6 +407,45 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes encodeDecodeTest( ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class") + encodeDecodeTest(StringWrapper("a"), "SPARK-20384: string value class") + encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "SPARK-20384: nested value class") + encodeDecodeTest( + ValueContainer(1, StringWrapper(null)), + "SPARK-20384: nested value class with null") + encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")), + IntWrapper(3)), "SPARK-20384: complex value class") + encodeDecodeTest( + Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)), + "SPARK-20384: array of value class") + encodeDecodeTest(Array.empty[IntWrapper], "SPARK-20384: empty array of value class") + encodeDecodeTest( + Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)), + "SPARK-20384: seq of value class") + encodeDecodeTest(Seq.empty[IntWrapper], "SPARK-20384: empty seq of value class") + encodeDecodeTest( + Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")), + "SPARK-20384: map with value class") + + // test for nested value class collections + encodeDecodeTest( + MapOfValueClassKey(Map(IntWrapper(1)-> "a")), + "SPARK-20384: case class with map of value class key") + encodeDecodeTest( + MapOfValueClassValue(Map("a"-> StringWrapper("b"))), + "SPARK-20384: case class with map of 
value class value") + encodeDecodeTest( + SeqOfValueClass(Seq(StringWrapper("a"))), + "SPARK-20384: case class with seq of class value") + encodeDecodeTest( + OptionOfValueClassValue(Some(StringWrapper("a"))), + "SPARK-20384: case class with option of class value") + encodeDecodeTest( + (IntWrapper(1), StringWrapper("a")), + "SPARK-20384: Tuple with value classes") + encodeDecodeTest( + CaseClassWithGeneric(IntWrapper(1), StringWrapper("a")), + "SPARK-20384: case class with value class in generic fields") + encodeDecodeTest(Option(31), "option of int") encodeDecodeTest(Option.empty[Int], "empty option of int") encodeDecodeTest(Option("abc"), "option of string") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1e3d219220fa6..278d43697809c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession} -import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2} +import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.Utils @@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol")) } + test("SPARK-20384: Value class filter") { + val df = spark.sparkContext + .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c"))) + .toDF() + val filtered = df.where("value = \"a\"") + checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF) + } + + test("SPARK-20384: Tuple2 of value class filter") { + val df = spark.sparkContext + .parallelize(Seq( + (StringWrapper("a1"), StringWrapper("a2")), + (StringWrapper("b1"), StringWrapper("b2")))) + .toDF() + val filtered = df.where("_2 = \"a2\"") + checkAnswer(filtered, + spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF) + } + + test("SPARK-20384: Tuple3 of value class filter") { + val df = spark.sparkContext + .parallelize(Seq( + (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")), + (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3")))) + .toDF() + val filtered = df.where("_3 = \"a3\"") + checkAnswer(filtered, + spark.sparkContext.parallelize( + Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF) + } + + test("SPARK-20384: Array value class filter") { + val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b"))) + val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d"))) + + val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF + val filtered = df.where(array_contains(col("wrappers"), "b")) + checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF) + } + + test("SPARK-20384: Nested value class filter") { + val a = ContainerStringWrapper(StringWrapper("a")) + val b = ContainerStringWrapper(StringWrapper("b")) + + val df = spark.sparkContext.parallelize(Seq(a, b)).toDF + // flat value class, `s` field is not in schema + val filtered = df.where("wrapper = \"a\"") + checkAnswer(filtered, 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1e3d219220fa6..278d43697809c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
@@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-20384: Value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
+      .toDF()
+    val filtered = df.where("value = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
+  }
+
+  test("SPARK-20384: Tuple2 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2")),
+        (StringWrapper("b1"), StringWrapper("b2"))))
+      .toDF()
+    val filtered = df.where("_2 = \"a2\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
+  }
+
+  test("SPARK-20384: Tuple3 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
+        (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
+      .toDF()
+    val filtered = df.where("_3 = \"a3\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(
+        Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
+  }
+
+  test("SPARK-20384: Array value class filter") {
+    val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
+    val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))
+
+    val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
+    val filtered = df.where(array_contains(col("wrappers"), "b"))
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
+  }
+
+  test("SPARK-20384: Nested value class filter") {
+    val a = ContainerStringWrapper(StringWrapper("a"))
+    val b = ContainerStringWrapper(StringWrapper("b"))
+
+    val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
+    // the value class is flattened, so its `s` field does not appear in the schema
+    val filtered = df.where("wrapper = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0dae9a2cd2886..c8156f1f58654 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.test.SQLTestData.StringWrapper
 import org.apache.spark.sql.types._
 
 case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
@@ -2040,6 +2041,53 @@ class DatasetSuite extends QueryTest
       joined,
       (1, 1), (2, 2), (3, 3))
   }
+
+  test("SPARK-20384: Dataset with only value class") {
+    val intData = Seq(1, 2, 3)
+    val wrappedData = intData.map(IntWrapper(_))
+    val wrappedDs = wrappedData.toDS()
+    val intDs = intData.toDS()
+
+    assert(intDs.schema === wrappedDs.schema)
+
+    checkDataset(wrappedDs, wrappedData: _*)
+  }
+
+  test("SPARK-20384: Dataset with StringWrapper as field and generic field") {
+    val data = Seq(
+      ThreeStrings("a", StringWrapper("a"), StringWrapper("a")),
+      ThreeStrings("b", StringWrapper("b"), StringWrapper("b")),
+      ThreeStrings("c", StringWrapper("c"), StringWrapper("c"))
+    )
+    val ds = data.toDS()
+
+    val expectedSchema = StructType(Seq(
+      StructField("value1", StringType, nullable = true),
+      StructField("value2", StringType, nullable = true),
+      StructField("value3", StringType, nullable = true)
+    ))
+
+    assert(ds.schema === expectedSchema)
+    checkDataset(ds, data: _*)
+  }
+
+  test("SPARK-20384: Dataset with IntWrapper as field and generic field") {
+    val data = Seq(
+      ThreeInts(1, IntWrapper(1), IntWrapper(1)),
+      ThreeInts(2, IntWrapper(2), IntWrapper(2)),
+      ThreeInts(3, IntWrapper(3), IntWrapper(3))
+    )
+    val ds = data.toDS()
+
+    val expectedSchema = StructType(Seq(
+      StructField("value1", IntegerType, nullable = false),
+      StructField("value2", IntegerType, nullable = false),
+      StructField("value3", IntegerType, nullable = false)
+    ))
+
+    assert(ds.schema === expectedSchema)
+    checkDataset(ds, data: _*)
+  }
 }
 
 case class Bar(a: Int)
@@ -2066,6 +2114,9 @@ case class Generic[T](id: T, value: Double)
 case class OtherTuple(_1: String, _2: Int)
 
 case class TupleClass(data: (Int, String))
+case class IntWrapper(wrapped: Int) extends AnyVal
+case class ThreeInts[T](value1: Int, value2: IntWrapper, value3: T)
+case class ThreeStrings[T](value1: String, value2: StringWrapper, value3: T)
 
 class OuterClass extends Serializable {
   case class InnerClass(a: String)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 307c4f33b2035..21064b5afddeb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -450,4 +450,7 @@ private[sql] object SQLTestData {
   case class CourseSales(course: String, year: Int, earnings: Double)
   case class TrainingSales(training: String, sales: CourseSales)
   case class IntervalData(data: CalendarInterval)
+  case class StringWrapper(s: String) extends AnyVal
+  case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
+  case class ContainerStringWrapper(wrapper: StringWrapper)
 }
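Taken together, the user-visible effect of the patch is that a value class field becomes an ordinary column of its underlying type and can be used from SQL expressions like any other column. A hedged end-to-end sketch against a local SparkSession with this patch applied (Price, Order and ValueClassDatasetDemo are illustrative names; the commented output is what the new schemaFor rules should produce):

import org.apache.spark.sql.SparkSession

case class Price(cents: Long) extends AnyVal
case class Order(id: Int, price: Price)

object ValueClassDatasetDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("value-class-demo").getOrCreate()
  import spark.implicits._

  val orders = Seq(Order(1, Price(199)), Order(2, Price(999))).toDS()
  orders.printSchema()
  // root
  //  |-- id: integer (nullable = false)
  //  |-- price: long (nullable = false)

  // The wrapper is transparent to SQL: filter on the underlying long directly.
  orders.where("price > 500").show()
  spark.stop()
}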