diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4de7e5caa8862..bab407b4311a5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -956,6 +956,19 @@ trait ScalaReflection extends Logging {
     tag.in(mirror).tpe.dealias
   }
 
+  private def isValueClass(tpe: Type): Boolean = {
+    tpe.typeSymbol.asClass.isDerivedValueClass
+  }
+
+  private def isTypeParameter(tpe: Type): Boolean = {
+    tpe.typeSymbol.isParameter
+  }
+
+  /** Returns the name and type of the underlying parameter of value class `tpe`. */
+  private def getUnderlyingTypeOfValueClass(tpe: Type): Type = {
+    getConstructorParameters(tpe).head._2
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
@@ -967,15 +980,17 @@ trait ScalaReflection extends Logging {
     val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
     val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
     val params = constructParams(dealiasedTpe)
-    // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
-    if (actualTypeArgs.nonEmpty) {
-      params.map { p =>
-        p.name.decodedName.toString ->
-          p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
-      }
-    } else {
-      params.map { p =>
-        p.name.decodedName.toString -> p.typeSignature
+    params.map { p =>
+      val paramTpe = p.typeSignature
+      if (isTypeParameter(paramTpe)) {
+        // if there are type variables to fill in, do the substitution
+        // (SomeClass[T] -> SomeClass[Int])
+        p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs)
+      } else if (isValueClass(paramTpe)) {
+        // Replace value class with underlying type
+        p.name.decodedName.toString -> getUnderlyingTypeOfValueClass(paramTpe)
+      } else {
+        p.name.decodedName.toString -> paramTpe
       }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 164bbd7f34d04..d86f9865edca5 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -156,8 +156,19 @@ object TraitProductWithNoConstructorCompanion {}
 
 trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}
 
+object TestingValueClass {
+  case class IntWrapper(val i: Int) extends AnyVal
+  case class StrWrapper(s: String) extends AnyVal
+
+  case class ValueClassData(intField: Int,
+                            wrappedInt: IntWrapper, // an int column
+                            strField: String,
+                            wrappedStr: StrWrapper) // a string column
+}
+
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
+  import TestingValueClass._
 
   // A helper method used to test `ScalaReflection.serializerForType`.
   private def serializerFor[T: TypeTag]: Expression =
@@ -451,4 +462,87 @@ class ScalaReflectionSuite extends SparkFunSuite {
       StructField("e", StringType, true))))
     assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
   }
+
+  test("schema for case class that is a value class") {
+    val schema = schemaFor[IntWrapper]
+    assert(
+      schema === Schema(StructType(Seq(StructField("i", IntegerType, false))), nullable = true))
+  }
+
+  test("SPARK-20384: schema for case class that contains value class fields") {
+    val schema = schemaFor[ValueClassData]
+    assert(
+      schema === Schema(
+        StructType(Seq(
+          StructField("intField", IntegerType, nullable = false),
+          StructField("wrappedInt", IntegerType, nullable = false),
+          StructField("strField", StringType),
+          StructField("wrappedStr", StringType)
+        )),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for array of value class") {
+    val schema = schemaFor[Array[IntWrapper]]
+    assert(
+      schema === Schema(
+        ArrayType(StructType(Seq(StructField("i", IntegerType, false))), containsNull = true),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for map of value class") {
+    val schema = schemaFor[Map[IntWrapper, StrWrapper]]
+    assert(
+      schema === Schema(
+        MapType(
+          StructType(Seq(StructField("i", IntegerType, false))),
+          StructType(Seq(StructField("s", StringType))),
+          valueContainsNull = true),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for tuple_2 of value class") {
+    val schema = schemaFor[(IntWrapper, StrWrapper)]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(Seq(StructField("s", StringType))))
+          )
+        ),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for tuple_3 of value class") {
+    val schema = schemaFor[(IntWrapper, StrWrapper, StrWrapper)]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(Seq(StructField("s", StringType)))),
+            StructField("_3", StructType(Seq(StructField("s", StringType))))
+          )
+        ),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for nested tuple of value class") {
+    val schema = schemaFor[(IntWrapper, (StrWrapper, StrWrapper))]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(
+              Seq(
+                StructField("_1", StructType(Seq(StructField("s", StringType)))),
+                StructField("_2", StructType(Seq(StructField("s", StringType)))))
+            )
+            )
+          )
+        ),
+        nullable = true))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index bf4afac2f8be6..ae5ce6021efb2 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -116,6 +116,21 @@ object ReferenceValueClass {
   }
 }
 case class IntAndString(i: Int, s: String)
 
+case class StringWrapper(s: String) extends AnyVal
+case class ValueContainer(
+    a: Int,
+    b: StringWrapper) // a string column
+case class IntWrapper(i: Int) extends AnyVal
+case class ComplexValueClassContainer(
+    a: Int,
+    b: ValueContainer,
+    c: IntWrapper)
+case class SeqOfValueClass(s: Seq[StringWrapper])
+case class MapOfValueClassKey(m: Map[IntWrapper, String])
+case class MapOfValueClassValue(m: Map[String, StringWrapper])
+case class OptionOfValueClassValue(o: Option[StringWrapper])
+case class CaseClassWithGeneric[T](generic: T, value: IntWrapper)
+
 class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
@@ -391,12 +406,54 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  // test for value classes
   encodeDecodeTest(
     PrimitiveValueClass(42), "primitive value class")
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
+  encodeDecodeTest(StringWrapper("a"), "string value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
+  encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
+    IntWrapper(3)), "complex value class")
+  encodeDecodeTest(
+    Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "array of value class")
+  encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
+  encodeDecodeTest(
+    Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "seq of value class")
+  encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
+  encodeDecodeTest(
+    Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
+    "map with value class")
+
+  // test for nested value class collections
+  encodeDecodeTest(
+    MapOfValueClassKey(Map(IntWrapper(1) -> "a")),
+    "case class with map of value class key")
+  encodeDecodeTest(
+    MapOfValueClassValue(Map("a" -> StringWrapper("b"))),
+    "case class with map of value class value")
+  encodeDecodeTest(
+    SeqOfValueClass(Seq(StringWrapper("a"))),
+    "case class with seq of value class")
+  encodeDecodeTest(
+    OptionOfValueClassValue(Some(StringWrapper("a"))),
+    "case class with option of value class")
+  encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2")),
+    "tuple2 of value class")
+  encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2"), StringWrapper("a_3")),
+    "tuple3 of value class")
+  encodeDecodeTest(((StringWrapper("a_1"), StringWrapper("a_2")), StringWrapper("b_2")),
+    "nested tuple._1 of value class")
+  encodeDecodeTest((StringWrapper("a_1"), (StringWrapper("b_1"), StringWrapper("b_2"))),
+    "nested tuple._2 of value class")
+  encodeDecodeTest(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2)),
+    "case class with value class in generic parameter")
+
   encodeDecodeTest(Option(31), "option of int")
   encodeDecodeTest(Option.empty[Int], "empty option of int")
   encodeDecodeTest(Option("abc"), "option of string")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 1e3d219220fa6..cf2c81881f616 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
@@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-20384: Value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
+      .toDF()
+    val filtered = df.where("s = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
+  }
+
+  test("SPARK-20384: Tuple2 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2")),
+        (StringWrapper("b1"), StringWrapper("b2"))))
+      .toDF()
+    val filtered = df.where("_2.s = \"a2\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
+  }
+
+  test("SPARK-20384: Tuple3 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
+        (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
+      .toDF()
+    val filtered = df.where("_3.s = \"a3\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(
+        Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
+  }
+
+  test("SPARK-20384: Array value class filter") {
+    val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
+    val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))
+
+    val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
+    val filtered = df.where(array_contains(col("wrappers.s"), "b"))
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
+  }
+
+  test("SPARK-20384: Nested value class filter") {
+    val a = ContainerStringWrapper(StringWrapper("a"))
+    val b = ContainerStringWrapper(StringWrapper("b"))
+
+    val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
+    // the value class is flattened, so there is no `s` field in the schema
+    val filtered = df.where("wrapper = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 307c4f33b2035..21064b5afddeb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -450,4 +450,7 @@ private[sql] object SQLTestData {
   case class CourseSales(course: String, year: Int, earnings: Double)
   case class TrainingSales(training: String, sales: CourseSales)
   case class IntervalData(data: CalendarInterval)
+  case class StringWrapper(s: String) extends AnyVal
+  case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
+  case class ContainerStringWrapper(wrapper: StringWrapper)
 }
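
For reviewers, here is a minimal end-to-end sketch of the user-facing behavior this patch enables. It is an illustration only, not part of the patch: the `Meters`/`Reading` classes and the local `SparkSession` setup are hypothetical, and the flattening behavior shown follows from the new `schemaFor`/encoder tests above (a value-class field is exposed as a column of its underlying type).

```scala
import org.apache.spark.sql.SparkSession

object ValueClassDemo {
  // Value classes must be top-level or members of a top-level object.
  case class Meters(value: Double) extends AnyVal
  case class Reading(sensor: String, height: Meters)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("demo").getOrCreate()
    import spark.implicits._

    // Before this patch, schema inference failed on the value-class field;
    // with it, `height` is inferred as a plain non-nullable double column:
    // root
    //  |-- sensor: string (nullable = true)
    //  |-- height: double (nullable = false)
    val ds = Seq(Reading("a", Meters(1.5)), Reading("b", Meters(2.0))).toDS()
    ds.printSchema()

    // The flattened column can be referenced directly, mirroring the
    // DataFrameSuite filter tests above.
    ds.where($"height" > 1.8).show()

    spark.stop()
  }
}
```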