diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 1f7634bafa420..f587245116581 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -611,10 +611,39 @@ object ScalaReflection extends ScalaReflection { } } + private def erasure(tpe: Type): Type = { + // For user-defined AnyVal classes, we should not erasure it. Otherwise, it will + // resolve to underlying type which wrapped by this class, e.g erasure + // `case class Foo(i: Int) extends AnyVal` will return type `Int` instead of `Foo`. + // But, for other types, we do need to erasure it. For example, we need to erasure + // `scala.Any` to `java.lang.Object` in order to load it from Java ClassLoader. + // Please see SPARK-17368 & SPARK-31190 for more details. + if (isSubtype(tpe, localTypeOf[AnyVal]) && !tpe.toString.startsWith("scala")) { + tpe + } else { + tpe.erasure + } + } + + /** + * Returns the full class name for a type. The returned name is the canonical + * Scala name, where each component is separated by a period. It is NOT the + * Java-equivalent runtime name (no dollar signs). + * + * In simple cases, both the Scala and Java names are the same, however when Scala + * generates constructs that do not map to a Java equivalent, such as singleton objects + * or nested classes in package objects, it uses the dollar sign ($) to create + * synthetic classes, emulating behaviour in Java bytecode. + */ + def getClassNameFromType(tpe: `Type`): String = { + erasure(tpe).dealias.typeSymbol.asClass.fullName + } + /* * Retrieves the runtime class corresponding to the provided type. */ - def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass) + def getClassFromType(tpe: Type): Class[_] = + mirror.runtimeClass(erasure(tpe).dealias.typeSymbol.asClass) case class Schema(dataType: DataType, nullable: Boolean) @@ -863,20 +892,6 @@ trait ScalaReflection extends Logging { tag.in(mirror).tpe.dealias } - /** - * Returns the full class name for a type. The returned name is the canonical - * Scala name, where each component is separated by a period. It is NOT the - * Java-equivalent runtime name (no dollar signs). - * - * In simple cases, both the Scala and Java names are the same, however when Scala - * generates constructs that do not map to a Java equivalent, such as singleton objects - * or nested classes in package objects, it uses the dollar sign ($) to create - * synthetic classes, emulating behaviour in Java bytecode. - */ - def getClassNameFromType(tpe: `Type`): String = { - tpe.dealias.erasure.typeSymbol.asClass.fullName - } - /** * Returns the parameter names and types for the primary constructor of this type. * diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index c1f1be3b30e4b..66a1bbe01f5c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -107,6 +107,8 @@ class UDTForCaseClass extends UserDefinedType[UDTCaseClass] { } } +case class Bar(i: Any) +case class Foo(i: Bar) extends AnyVal case class PrimitiveValueClass(wrapped: Int) extends AnyVal case class ReferenceValueClass(wrapped: ReferenceValueClass.Container) extends AnyVal object ReferenceValueClass { @@ -311,6 +313,13 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes productTest(("UDT", new ExamplePoint(0.1, 0.2))) + test("AnyVal class with Any fields") { + val exception = intercept[UnsupportedOperationException](implicitly[ExpressionEncoder[Foo]]) + val errorMsg = exception.getMessage + assert(errorMsg.contains("root class: \"org.apache.spark.sql.catalyst.encoders.Foo\"")) + assert(errorMsg.contains("No Encoder found for Any")) + } + test("nullable of encoder schema") { def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq)