Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -990,11 +990,7 @@ trait ScalaReflection extends Logging {
}

private def isValueClass(tpe: Type): Boolean = {
tpe.typeSymbol.asClass.isDerivedValueClass
}

private def isTypeParameter(tpe: Type): Boolean = {
tpe.typeSymbol.isParameter
tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isDerivedValueClass
}

/** Returns the name and type of the underlying parameter of value class `tpe`. */
Expand All @@ -1015,15 +1011,11 @@ trait ScalaReflection extends Logging {
val params = constructParams(dealiasedTpe)
params.map { p =>
val paramTpe = p.typeSignature
if (isTypeParameter(paramTpe)) {
// if there are type variables to fill in, do the substitution
// (SomeClass[T] -> SomeClass[Int])
p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs)
} else if (isValueClass(paramTpe)) {
if (isValueClass(paramTpe)) {
// Replace value class with underlying type
p.name.decodedName.toString -> getUnderlyingTypeOfValueClass(paramTpe)
} else {
p.name.decodedName.toString -> paramTpe
p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ object GenericData {
type IntData = GenericData[Int]
}

case class NestedGeneric[T](
generic: GenericData[T])

case class SeqNestedGeneric[T](
generic: Seq[T])


case class MultipleConstructorsData(a: Int, b: String, c: Double) {
def this(b: String, a: Int) = this(a, b, c = 1.0)
}
Expand Down Expand Up @@ -295,6 +302,40 @@ class ScalaReflectionSuite extends SparkFunSuite {
nullable = true))
}

test("SPARK-38681: Nested generic data") {
val schema = schemaFor[NestedGeneric[Int]]
assert(schema === Schema(
StructType(Seq(
StructField(
"generic",
StructType(Seq(
StructField("genericField", IntegerType, nullable = false))),
nullable = true))),
nullable = true))
}

test("SPARK-38681: List nested generic") {
val schema = schemaFor[SeqNestedGeneric[Int]]
assert(schema === Schema(
StructType(Seq(
StructField(
"generic",
ArrayType(IntegerType, false),
nullable = true))),
nullable = true))
}

test("SPARK-38681: List nested generic with value class") {
val schema = schemaFor[SeqNestedGeneric[IntWrapper]]
assert(schema === Schema(
StructType(Seq(
StructField(
"generic",
ArrayType(StructType(Seq(StructField("i", IntegerType, false))), true),
nullable = true))),
nullable = true))
}

test("tuple data") {
val schema = schemaFor[(Int, String)]
assert(schema === Schema(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ case class MapOfValueClassKey(m: Map[IntWrapper, String])
case class MapOfValueClassValue(m: Map[String, StringWrapper])
case class OptionOfValueClassValue(o: Option[StringWrapper])
case class CaseClassWithGeneric[T](generic: T, value: IntWrapper)
case class NestedGeneric[T](generic: CaseClassWithGeneric[T])
case class SeqNestedGeneric[T](list: Seq[T])
case class OptionNestedGeneric[T](list: Option[T])
case class MapNestedGenericKey[T](list: Map[T, Int])
case class MapNestedGenericValue[T](list: Map[Int, T])

class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
OuterScopes.addOuterScope(this)
Expand Down Expand Up @@ -454,6 +459,18 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
"nested tuple._2 of class value")
encodeDecodeTest(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2)),
"case class with value class in generic parameter")
encodeDecodeTest(NestedGeneric(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2))),
"case class with nested generic parameter")
encodeDecodeTest(SeqNestedGeneric(List(2)),
"case class with nested generic parameter seq")
encodeDecodeTest(SeqNestedGeneric(List(IntWrapper(2))),
"case class with value class and nested generic parameter seq")
encodeDecodeTest(OptionNestedGeneric(Some(2)),
"case class with nested generic option")
encodeDecodeTest(MapNestedGenericKey(Map(1 -> 2)),
"case class with nested generic map key ")
encodeDecodeTest(MapNestedGenericValue(Map(1 -> 2)),
"case class with nested generic map value")

encodeDecodeTest(Option(31), "option of int")
encodeDecodeTest(Option.empty[Int], "empty option of int")
Expand Down