sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -357,6 +357,15 @@ object ScalaReflection extends ScalaReflection {
dataType = ObjectType(udt.getClass))
Invoke(obj, "deserialize", ObjectType(udt.userClass), path :: Nil)

case t if isValueClass(t) =>
val (_, underlyingType) = getUnderlyingParameterOf(t)
val underlyingClsName = getClassNameFromType(underlyingType)
val clsName = getUnerasedClassNameFromType(t)
val newTypePath = walkedTypePath.recordValueClass(clsName, underlyingClsName)
val arg = deserializerFor(underlyingType, path, newTypePath)
val cls = getClassFromType(t)
NewInstance(cls, Seq(arg), ObjectType(cls), propagateNull = false)
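// Illustrative sketch (not part of the patch): for a value class such as
//   case class IntWrapper(i: Int) extends AnyVal
// the case above first builds a deserializer for the underlying Int at `path`,
// then wraps the result, conceptually `new IntWrapper(underlying)`.
// `propagateNull = false` so the wrapper is still constructed even when a
// nullable underlying value (e.g. a String) comes back null.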

case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)

@@ -579,6 +588,15 @@ object ScalaReflection extends ScalaReflection {
val udtClass = udt.getClass
createSerializerForUserDefinedType(inputObject, udt, udtClass)

case t if isValueClass(t) =>
val (name, underlyingType) = getUnderlyingParameterOf(t)
val underlyingClsName = getClassNameFromType(underlyingType)
val clsName = getUnerasedClassNameFromType(t)
val newPath = walkedTypePath.recordValueClass(clsName, underlyingClsName)
val getArg = Invoke(KnownNotNull(inputObject), name, dataTypeFor(underlyingType),
returnNullable = !underlyingType.typeSymbol.asClass.isPrimitive)
serializerFor(getArg, underlyingType, newPath)
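// Illustrative sketch (not part of the patch): serialization goes the other way.
// For `case class StrWrapper(s: String) extends AnyVal` the case above emits,
// conceptually, `inputObject.s` and then serializes that String; `returnNullable`
// is false only when the underlying type is a primitive (e.g. Int), which can
// never be null.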

case t if definedByConstructorParams(t) =>
if (seenTypeSet.contains(t)) {
throw QueryExecutionErrors.cannotHaveCircularReferencesInClassError(t.toString)
@@ -787,6 +805,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false)
case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false)
case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false)
case t if isValueClass(t) =>
val (_, underlyingType) = getUnderlyingParameterOf(t)
schemaFor(underlyingType)
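// For example (cf. the new ScalaReflectionSuite tests below):
//   schemaFor[IntWrapper] == Schema(IntegerType, nullable = false)
//   schemaFor[StrWrapper] == Schema(StringType, nullable = true)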
case t if definedByConstructorParams(t) =>
val params = getConstructorParameters(t)
Schema(StructType(
@@ -956,6 +977,22 @@ trait ScalaReflection extends Logging {
tag.in(mirror).tpe.dealias
}

/**
* Same as `getClassNameFromType` but returns the class name before erasure.
*/
def getUnerasedClassNameFromType(tpe: `Type`): String = {
tpe.dealias.typeSymbol.asClass.fullName
}

def isValueClass(tpe: `Type`): Boolean = {
tpe.typeSymbol.asClass.isDerivedValueClass
}

/** Returns the name and type of the underlying parameter of value class `tpe`. */
def getUnderlyingParameterOf(tpe: `Type`): (String, Type) = {
getConstructorParameters(tpe).head
}
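// Example (sketch): given `case class IntWrapper(i: Int) extends AnyVal`,
//   isValueClass(typeOf[IntWrapper])                  // true
//   getUnderlyingParameterOf(typeOf[IntWrapper])      // ("i", Int)
//   getUnerasedClassNameFromType(typeOf[IntWrapper])  // fully qualified IntWrapper
// whereas after erasure the JVM-level type would already be the underlying Int.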

/**
* Returns the parameter names and types for the primary constructor of this type.
*
@@ -968,15 +1005,27 @@
val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
val params = constructParams(dealiasedTpe)
// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
params.map { p =>
  if (isTypeParameter(p)) {
    p.name.decodedName.toString ->
      p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
  } else {
    p.name.decodedName.toString -> unwrapIfValueClassType(p.typeSignature)
  }
}
}

private def isTypeParameter(sym: Symbol): Boolean = {
  sym.typeSignature.typeSymbol.isParameter
}

private def unwrapIfValueClassType(tpe: Type): Type = {
  if (isValueClass(tpe)) {
    val (_, underlyingType) = getUnderlyingParameterOf(tpe)
    underlyingType
  } else {
    tpe
  }
}
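// Example of the unwrapping above (sketch): for
//   case class ValueContainer(a: Int, b: StringWrapper)
// getConstructorParameters returns ("a", Int) and ("b", String): the value-class
// parameter is replaced by its underlying type up front. Generic parameters take
// the isTypeParameter branch instead; once substituted (e.g. T becomes IntWrapper),
// the wrapper is handled by the isValueClass cases in schemaFor, serializerFor
// and deserializerFor above.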

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/WalkedTypePath.scala
@@ -46,6 +46,9 @@ case class WalkedTypePath(private val walkedPaths: Seq[String] = Nil) {
def recordField(className: String, fieldName: String): WalkedTypePath =
newInstance(s"""- field (class: "$className", name: "$fieldName")""")

def recordValueClass(className: String, underlyingClassName: String): WalkedTypePath =
newInstance(s"""- Scala value class: $className($underlyingClassName)""")
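// Example (sketch): recordValueClass("org.example.IntWrapper", "scala.Int"),
// using a hypothetical package name, appends the path line:
//   - Scala value class: org.example.IntWrapper(scala.Int)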

override def toString: String = {
walkedPaths.mkString("\n")
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -156,6 +156,16 @@ object TraitProductWithNoConstructorCompanion {}

trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}

class IntWrapper(val i: Int) extends AnyVal
case class StrWrapper(s: String) extends AnyVal
case class CaseClassWithGeneric[T](generic: T)

case class ValueClassData(
intField: Int,
wrappedInt: IntWrapper, // an int column
strField: String,
wrappedStr: StrWrapper) // a string column

class ScalaReflectionSuite extends SparkFunSuite {
import org.apache.spark.sql.catalyst.ScalaReflection._

@@ -451,4 +461,51 @@ class ScalaReflectionSuite extends SparkFunSuite {
StructField("e", StringType, true))))
assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
}

test("SPARK-20384: schema for case class that is a value class") {
val schema = schemaFor[IntWrapper]
assert(schema === Schema(IntegerType, nullable = false))
}

test("SPARK-20384: schema for case class that contains value class fields") {
val schema = schemaFor[ValueClassData]
assert(schema === Schema(
StructType(Seq(
StructField("intField", IntegerType, nullable = false),
StructField("wrappedInt", IntegerType, nullable = false),
StructField("strField", StringType, nullable = true),
StructField("wrappedStr", StringType, nullable = true))),
nullable = true))
}

test("SPARK-20384: schema for array of value class") {
val schema = schemaFor[Array[IntWrapper]]
assert(schema === Schema(
ArrayType(IntegerType, containsNull = false),
nullable = true))
}

test("SPARK-20384: schema for map of value class") {
val schema = schemaFor[Map[IntWrapper, StrWrapper]]
assert(schema === Schema(
MapType(IntegerType, StringType, valueContainsNull = true),
nullable = true))
}

test("SPARK-20384: schema for tuple with value class") {
val schema = schemaFor[(IntWrapper, StrWrapper)]
assert(schema === Schema(
StructType(Seq(
StructField("_1", IntegerType, nullable = false),
StructField("_2", StringType, nullable = true))),
nullable = true))
}

test("SPARK-20384: schema for case class with generic field") {
val schema = schemaFor[CaseClassWithGeneric[IntWrapper]]
assert(schema === Schema(
StructType(Seq(
StructField("generic", IntegerType, nullable = false))),
nullable = true))
}
}
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -116,6 +116,16 @@ object ReferenceValueClass {
}
case class IntAndString(i: Int, s: String)

case class StringWrapper(s: String) extends AnyVal
case class ValueContainer(a: Int, b: StringWrapper)
case class IntWrapper(i: Int) extends AnyVal
case class ComplexValueClassContainer(a: Int, b: ValueContainer, c: IntWrapper)
case class SeqOfValueClass(s: Seq[StringWrapper])
case class MapOfValueClassKey(m: Map[IntWrapper, String])
case class MapOfValueClassValue(m: Map[String, StringWrapper])
case class OptionOfValueClassValue(o: Option[StringWrapper])
case class CaseClassWithGeneric[T, S](t: T, s: S)

class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)

@@ -397,6 +407,45 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
encodeDecodeTest(
ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")

encodeDecodeTest(StringWrapper("a"), "SPARK-20384: string value class")
encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "SPARK-20384: nested value class")
encodeDecodeTest(
ValueContainer(1, StringWrapper(null)),
"SPARK-20384: nested value class with null")
encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
IntWrapper(3)), "SPARK-20384: complex value class")
encodeDecodeTest(
Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
"SPARK-20384: array of value class")
encodeDecodeTest(Array.empty[IntWrapper], "SPARK-20384: empty array of value class")
encodeDecodeTest(
Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
"SPARK-20384: seq of value class")
encodeDecodeTest(Seq.empty[IntWrapper], "SPARK-20384: empty seq of value class")
encodeDecodeTest(
Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
"SPARK-20384: map with value class")

// tests for value classes nested in collections
encodeDecodeTest(
MapOfValueClassKey(Map(IntWrapper(1) -> "a")),
"SPARK-20384: case class with map of value class key")
encodeDecodeTest(
MapOfValueClassValue(Map("a"-> StringWrapper("b"))),
"SPARK-20384: case class with map of value class value")
encodeDecodeTest(
SeqOfValueClass(Seq(StringWrapper("a"))),
"SPARK-20384: case class with seq of class value")
encodeDecodeTest(
OptionOfValueClassValue(Some(StringWrapper("a"))),
"SPARK-20384: case class with option of class value")
encodeDecodeTest(
(IntWrapper(1), StringWrapper("a")),
"SPARK-20384: Tuple with value classes")
encodeDecodeTest(
CaseClassWithGeneric(IntWrapper(1), StringWrapper("a")),
"SPARK-20384: case class with value class in generic fields")

encodeDecodeTest(Option(31), "option of int")
encodeDecodeTest(Option.empty[Int], "empty option of int")
encodeDecodeTest(Option("abc"), "option of string")
sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.Utils
@@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest
assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
}

test("SPARK-20384: Value class filter") {
val df = spark.sparkContext
.parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
.toDF()
val filtered = df.where("value = \"a\"")
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
}

test("SPARK-20384: Tuple2 of value class filter") {
val df = spark.sparkContext
.parallelize(Seq(
(StringWrapper("a1"), StringWrapper("a2")),
(StringWrapper("b1"), StringWrapper("b2"))))
.toDF()
val filtered = df.where("_2 = \"a2\"")
checkAnswer(filtered,
spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
}

test("SPARK-20384: Tuple3 of value class filter") {
val df = spark.sparkContext
.parallelize(Seq(
(StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
(StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
.toDF()
val filtered = df.where("_3 = \"a3\"")
checkAnswer(filtered,
spark.sparkContext.parallelize(
Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
}

test("SPARK-20384: Array value class filter") {
val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))

val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
val filtered = df.where(array_contains(col("wrappers"), "b"))
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
}

test("SPARK-20384: Nested value class filter") {
val a = ContainerStringWrapper(StringWrapper("a"))
val b = ContainerStringWrapper(StringWrapper("b"))

val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
// the value class is flattened, so its `s` field does not appear in the schema
val filtered = df.where("wrapper = \"a\"")
checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
}
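// (Sketch) ContainerStringWrapper(StringWrapper("a")) is stored as a single
// string column named `wrapper`, so the predicate above compares the string
// directly instead of addressing a nested struct field `wrapper.s`.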

private lazy val person2: DataFrame = Seq(
("Bob", 16, 176),
("Alice", 32, 164),
sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.test.SQLTestData.StringWrapper
import org.apache.spark.sql.types._

case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
@@ -2040,6 +2041,53 @@ class DatasetSuite extends QueryTest
joined,
(1, 1), (2, 2), (3, 3))
}

test("SPARK-20384: Dataset with only value class") {
val intData = Seq(1, 2, 3)
val wrappedData = intData.map(IntWrapper(_))
val wrappedDs = wrappedData.toDS()
val intDs = intData.toDS()

assert(intDs.schema === wrappedDs.schema)

checkDataset(wrappedDs, wrappedData: _*)
}

test("SPARK-20384: Model with StringWrapper as field and generic field") {
val data = Seq(
ThreeStrings("a", StringWrapper("a"), StringWrapper("a")),
ThreeStrings("b", StringWrapper("b"), StringWrapper("b")),
ThreeStrings("c", StringWrapper("c"), StringWrapper("c"))
)
val ds = data.toDS()

val expectedSchema = StructType(Seq(
StructField("value1", StringType, nullable = true),
StructField("value2", StringType, nullable = true),
StructField("value3", StringType, nullable = true)
))

assert(ds.schema === expectedSchema)
checkDataset(ds, data: _*)
}

test("SPARK-20384: Dataset with IntWrapper as field and generic field") {
val data = Seq(
ThreeInts(1, IntWrapper(1), IntWrapper(1)),
ThreeInts(2, IntWrapper(2), IntWrapper(2)),
ThreeInts(3, IntWrapper(3), IntWrapper(3))
)
val ds = data.toDS()

val expectedSchema = StructType(Seq(
StructField("value1", IntegerType, nullable = false),
StructField("value2", IntegerType, nullable = false),
StructField("value3", IntegerType, nullable = false)
))

assert(ds.schema === expectedSchema)
checkDataset(ds, data: _*)
}
}

case class Bar(a: Int)
@@ -2066,6 +2114,9 @@ case class Generic[T](id: T, value: Double)
case class OtherTuple(_1: String, _2: Int)

case class TupleClass(data: (Int, String))
case class IntWrapper(wrapped: Int) extends AnyVal
case class ThreeInts[T](value1: Int, value2: IntWrapper, value3: T)
case class ThreeStrings[T](value1: String, value2: StringWrapper, value3: T)

class OuterClass extends Serializable {
case class InnerClass(a: String)
sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -450,4 +450,7 @@ private[sql] object SQLTestData {
case class CourseSales(course: String, year: Int, earnings: Double)
case class TrainingSales(training: String, sales: CourseSales)
case class IntervalData(data: CalendarInterval)
case class StringWrapper(s: String) extends AnyVal
case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
case class ContainerStringWrapper(wrapper: StringWrapper)
}