@@ -956,6 +956,19 @@ trait ScalaReflection extends Logging {
tag.in(mirror).tpe.dealias
}

private def isValueClass(tpe: Type): Boolean = {
tpe.typeSymbol.asClass.isDerivedValueClass
}

private def isTypeParameter(tpe: Type): Boolean = {
tpe.typeSymbol.isParameter
}

/** Returns the name and type of the underlying parameter of value class `tpe`. */
private def getUnderlyingTypeOfValueClass(tpe: `Type`): Type = {
getConstructorParameters(tpe).head._2
}
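// Illustration (not part of this diff): for a value class such as
//   case class IntWrapper(i: Int) extends AnyVal
// the single primary-constructor parameter is ("i", Int), so
// getUnderlyingTypeOfValueClass(localTypeOf[IntWrapper]) would yield Int.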

/**
* Returns the parameter names and types for the primary constructor of this type.
*
@@ -967,15 +980,17 @@ trait ScalaReflection extends Logging {
val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
val params = constructParams(dealiasedTpe)
// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
if (actualTypeArgs.nonEmpty) {
params.map { p =>
p.name.decodedName.toString ->
p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
}
} else {
params.map { p =>
p.name.decodedName.toString -> p.typeSignature
params.map { p =>
val paramTpe = p.typeSignature
if (isTypeParameter(paramTpe)) {
// if there are type variables to fill in, do the substitution
// (SomeClass[T] -> SomeClass[Int])
p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs)
} else if (isValueClass(paramTpe)) {
// Replace value class with underlying type
p.name.decodedName.toString -> getUnderlyingTypeOfValueClass(paramTpe)
} else {
p.name.decodedName.toString -> paramTpe
}
}
}
@@ -156,8 +156,19 @@ object TraitProductWithNoConstructorCompanion {}

trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}

object TestingValueClass {
case class IntWrapper(val i: Int) extends AnyVal
case class StrWrapper(s: String) extends AnyVal

case class ValueClassData(intField: Int,
wrappedInt: IntWrapper, // an int column
strField: String,
wrappedStr: StrWrapper) // a string column
}

class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._
import TestingValueClass._

// A helper method used to test `ScalaReflection.serializerForType`.
private def serializerFor[T: TypeTag]: Expression =
@@ -451,4 +462,87 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("e", StringType, true))))
assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
}

test("schema for case class that is a value class") {
val schema = schemaFor[IntWrapper]
assert(
schema === Schema(StructType(Seq(StructField("i", IntegerType, false))), nullable = true))
}

test("SPARK-20384: schema for case class that contains value class fields") {
val schema = schemaFor[ValueClassData]
assert(
schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = false),
StructField("wrappedInt", IntegerType, nullable = false),
StructField("strField", StringType),
StructField("wrappedStr", StringType)
)),
nullable = true))
}

test("SPARK-20384: schema for array of value class") {
val schema = schemaFor[Array[IntWrapper]]
assert(
schema === Schema(
ArrayType(StructType(Seq(StructField("i", IntegerType, false))), containsNull = true),
nullable = true))
}

test("SPARK-20384: schema for map of value class") {
val schema = schemaFor[Map[IntWrapper, StrWrapper]]
assert(
schema === Schema(
MapType(
StructType(Seq(StructField("i", IntegerType, false))),
StructType(Seq(StructField("s", StringType))),
valueContainsNull = true),
nullable = true))
}

test("SPARK-20384: schema for tuple_2 of value class") {
val schema = schemaFor[(IntWrapper, StrWrapper)]
Review comment (Contributor): It's a bit weird that the schema of a case class of value classes is not consistent with the schema of a tuple of value classes, but there seems to be no better solution as we need to keep backward compatibility. (See the illustration after this test.)
assert(
schema === Schema(
StructType(
Seq(
StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
StructField("_2", StructType(Seq(StructField("s", StringType))))
)
),
nullable = true))
}
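// Illustration of the review comment above (not part of this diff), based on the tests here:
// the same value class is flattened when it appears as a case class field, but kept as a
// nested struct when it is a tuple element, to preserve backward compatibility:
//   schemaFor[ValueClassData]           // wrappedInt -> IntegerType, wrappedStr -> StringType
//   schemaFor[(IntWrapper, StrWrapper)] // _1 -> struct<i: int>, _2 -> struct<s: string>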

test("SPARK-20384: schema for tuple_3 of value class") {
val schema = schemaFor[(IntWrapper, StrWrapper, StrWrapper)]
assert(
schema === Schema(
StructType(
Seq(
StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
StructField("_2", StructType(Seq(StructField("s", StringType)))),
StructField("_3", StructType(Seq(StructField("s", StringType))))
)
),
nullable = true))
}

test("SPARK-20384: schema for nested tuple of value class") {
val schema = schemaFor[(IntWrapper, (StrWrapper, StrWrapper))]
assert(
schema === Schema(
StructType(
Seq(
StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
StructField("_2", StructType(
Seq(
StructField("_1", StructType(Seq(StructField("s", StringType)))),
StructField("_2", StructType(Seq(StructField("s", StringType)))))
)
)
)
),
nullable = true))
}
}
@@ -116,6 +116,21 @@ object ReferenceValueClass {
}
case class IntAndString(i: Int, s: String)

case class StringWrapper(s: String) extends AnyVal
case class ValueContainer(
a: Int,
b: StringWrapper) // a string column
case class IntWrapper(i: Int) extends AnyVal
case class ComplexValueClassContainer(
a: Int,
b: ValueContainer,
c: IntWrapper)
case class SeqOfValueClass(s: Seq[StringWrapper])
case class MapOfValueClassKey(m: Map[IntWrapper, String])
case class MapOfValueClassValue(m: Map[String, StringWrapper])
case class OptionOfValueClassValue(o: Option[StringWrapper])
case class CaseClassWithGeneric[T](generic: T, value: IntWrapper)

class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)

@@ -391,12 +406,54 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest
ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
}

// test for value classes
encodeDecodeTest(
PrimitiveValueClass(42), "primitive value class")

encodeDecodeTest(
ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")

encodeDecodeTest(StringWrapper("a"), "string value class")
encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
IntWrapper(3)), "complex value class")
encodeDecodeTest(
Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
"array of value class")
encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
encodeDecodeTest(
Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
"seq of value class")
encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
encodeDecodeTest(
Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
"map with value class")

// test for nested value class collections
encodeDecodeTest(
MapOfValueClassKey(Map(IntWrapper(1) -> "a")),
"case class with map of value class key")
encodeDecodeTest(
MapOfValueClassValue(Map("a"-> StringWrapper("b"))),
"case class with map of value class value")
encodeDecodeTest(
SeqOfValueClass(Seq(StringWrapper("a"))),
"case class with seq of class value")
encodeDecodeTest(
OptionOfValueClassValue(Some(StringWrapper("a"))),
"case class with option of class value")
encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2")),
"tuple2 of class value")
encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2"), StringWrapper("a_3")),
"tuple3 of class value")
encodeDecodeTest(((StringWrapper("a_1"), StringWrapper("a_2")), StringWrapper("b_2")),
"nested tuple._1 of class value")
encodeDecodeTest((StringWrapper("a_1"), (StringWrapper("b_1"), StringWrapper("b_2"))),
"nested tuple._2 of class value")
encodeDecodeTest(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2)),
"case class with value class in generic parameter")

encodeDecodeTest(Option(31), "option of int")
encodeDecodeTest(Option.empty[Int], "empty option of int")
encodeDecodeTest(Option("abc"), "option of string")
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
@@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}

test("SPARK-20384: Value class filter") {
val df = spark.sparkContext
.parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
.toDF()
val filtered = df.where("s = \"a\"")
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
}

test("SPARK-20384: Tuple2 of value class filter") {
val df = spark.sparkContext
.parallelize(Seq(
(StringWrapper("a1"), StringWrapper("a2")),
(StringWrapper("b1"), StringWrapper("b2"))))
.toDF()
val filtered = df.where("_2.s = \"a2\"")
checkAnswer(filtered,
spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
}

test("SPARK-20384: Tuple3 of value class filter") {
val df = spark.sparkContext
.parallelize(Seq(
(StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
(StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
.toDF()
val filtered = df.where("_3.s = \"a3\"")
checkAnswer(filtered,
spark.sparkContext.parallelize(
Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
}

test("SPARK-20384: Array value class filter") {
val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))

val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
val filtered = df.where(array_contains(col("wrappers.s"), "b"))
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
}

test("SPARK-20384: Nested value class filter") {
val a = ContainerStringWrapper(StringWrapper("a"))
val b = ContainerStringWrapper(StringWrapper("b"))

val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
// flat value class, `s` field is not in schema
val filtered = df.where("wrapper = \"a\"")
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
}
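// Illustration (not part of this diff): since StringWrapper is a value class, the
// inferred schema of ContainerStringWrapper is a single string column, roughly
//   StructType(Seq(StructField("wrapper", StringType, nullable = true)))
// which is why the filter above compares `wrapper` directly instead of `wrapper.s`.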

private lazy val person2: DataFrame = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
@@ -450,4 +450,7 @@ private[sql] object SQLTestData {
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
case class IntervalData(data: CalendarInterval)
case class StringWrapper(s: String) extends AnyVal
case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
case class ContainerStringWrapper(wrapper: StringWrapper)
}