diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 7abc3498c54cd..565774a5561c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -990,11 +990,7 @@ trait ScalaReflection extends Logging { } private def isValueClass(tpe: Type): Boolean = { - tpe.typeSymbol.asClass.isDerivedValueClass - } - - private def isTypeParameter(tpe: Type): Boolean = { - tpe.typeSymbol.isParameter + tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isDerivedValueClass } /** Returns the name and type of the underlying parameter of value class `tpe`. */ @@ -1015,15 +1011,11 @@ trait ScalaReflection extends Logging { val params = constructParams(dealiasedTpe) params.map { p => val paramTpe = p.typeSignature - if (isTypeParameter(paramTpe)) { - // if there are type variables to fill in, do the substitution - // (SomeClass[T] -> SomeClass[Int]) - p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs) - } else if (isValueClass(paramTpe)) { + if (isValueClass(paramTpe)) { // Replace value class with underlying type p.name.decodedName.toString -> getUnderlyingTypeOfValueClass(paramTpe) } else { - p.name.decodedName.toString -> paramTpe + p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index d63fbd8785a96..2c0cb7f640b2c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -81,6 +81,13 @@ object GenericData { type IntData = GenericData[Int] } +case class NestedGeneric[T]( + generic: GenericData[T]) + +case class SeqNestedGeneric[T]( + generic: Seq[T]) + + case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } @@ -295,6 +302,40 @@ class ScalaReflectionSuite extends SparkFunSuite { nullable = true)) } + test("SPARK-38681: Nested generic data") { + val schema = schemaFor[NestedGeneric[Int]] + assert(schema === Schema( + StructType(Seq( + StructField( + "generic", + StructType(Seq( + StructField("genericField", IntegerType, nullable = false))), + nullable = true))), + nullable = true)) + } + + test("SPARK-38681: List nested generic") { + val schema = schemaFor[SeqNestedGeneric[Int]] + assert(schema === Schema( + StructType(Seq( + StructField( + "generic", + ArrayType(IntegerType, false), + nullable = true))), + nullable = true)) + } + + test("SPARK-38681: List nested generic with value class") { + val schema = schemaFor[SeqNestedGeneric[IntWrapper]] + assert(schema === Schema( + StructType(Seq( + StructField( + "generic", + ArrayType(StructType(Seq(StructField("i", IntegerType, false))), true), + nullable = true))), + nullable = true)) + } + test("tuple data") { val schema = schemaFor[(Int, String)] assert(schema === Schema( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index bb4b58dc28a8b..e2eafb7370d18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -131,6 +131,11 @@ case class MapOfValueClassKey(m: Map[IntWrapper, String]) case class MapOfValueClassValue(m: Map[String, StringWrapper]) case class OptionOfValueClassValue(o: Option[StringWrapper]) case class CaseClassWithGeneric[T](generic: T, value: IntWrapper) +case class NestedGeneric[T](generic: CaseClassWithGeneric[T]) +case class SeqNestedGeneric[T](list: Seq[T]) +case class OptionNestedGeneric[T](list: Option[T]) +case class MapNestedGenericKey[T](list: Map[T, Int]) +case class MapNestedGenericValue[T](list: Map[Int, T]) class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest { OuterScopes.addOuterScope(this) @@ -454,6 +459,18 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes "nested tuple._2 of class value") encodeDecodeTest(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2)), "case class with value class in generic parameter") + encodeDecodeTest(NestedGeneric(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2))), + "case class with nested generic parameter") + encodeDecodeTest(SeqNestedGeneric(List(2)), + "case class with nested generic parameter seq") + encodeDecodeTest(SeqNestedGeneric(List(IntWrapper(2))), + "case class with value class and nested generic parameter seq") + encodeDecodeTest(OptionNestedGeneric(Some(2)), + "case class with nested generic option") + encodeDecodeTest(MapNestedGenericKey(Map(1 -> 2)), + "case class with nested generic map key ") + encodeDecodeTest(MapNestedGenericValue(Map(1 -> 2)), + "case class with nested generic map value") encodeDecodeTest(Option(31), "option of int") encodeDecodeTest(Option.empty[Int], "empty option of int")