diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d8d268a77ca18..fa8993e8d24c8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -38,6 +38,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} trait DefinedByConstructorParams +private[catalyst] object ScalaSubtypeLock + + /** * A default version of ScalaReflection that uses the runtime universe. */ @@ -66,19 +69,32 @@ object ScalaReflection extends ScalaReflection { */ def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T]) + /** + * Synchronize to prevent concurrent usage of `<:<` operator. + * This operator is not thread safe in any current version of scala; i.e. + * (2.11.12, 2.12.8, 2.13.0-M5). + * + * See https://github.com/scala/bug/issues/10766 + */ + private[catalyst] def isSubtype(tpe1: `Type`, tpe2: `Type`): Boolean = { + ScalaSubtypeLock.synchronized { + tpe1 <:< tpe2 + } + } + private def dataTypeFor(tpe: `Type`): DataType = cleanUpReflectionObjects { tpe.dealias match { - case t if t <:< definitions.NullTpe => NullType - case t if t <:< definitions.IntTpe => IntegerType - case t if t <:< definitions.LongTpe => LongType - case t if t <:< definitions.DoubleTpe => DoubleType - case t if t <:< definitions.FloatTpe => FloatType - case t if t <:< definitions.ShortTpe => ShortType - case t if t <:< definitions.ByteTpe => ByteType - case t if t <:< definitions.BooleanTpe => BooleanType - case t if t <:< localTypeOf[Array[Byte]] => BinaryType - case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType - case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT + case t if isSubtype(t, definitions.NullTpe) => NullType + case t if isSubtype(t, definitions.IntTpe) => IntegerType + case t if isSubtype(t, definitions.LongTpe) => LongType + case t if isSubtype(t, definitions.DoubleTpe) => DoubleType + case t if isSubtype(t, definitions.FloatTpe) => FloatType + case t if isSubtype(t, definitions.ShortTpe) => ShortType + case t if isSubtype(t, definitions.ByteTpe) => ByteType + case t if isSubtype(t, definitions.BooleanTpe) => BooleanType + case t if isSubtype(t, localTypeOf[Array[Byte]]) => BinaryType + case t if isSubtype(t, localTypeOf[CalendarInterval]) => CalendarIntervalType + case t if isSubtype(t, localTypeOf[Decimal]) => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) className match { @@ -101,13 +117,13 @@ object ScalaReflection extends ScalaReflection { */ private def arrayClassFor(tpe: `Type`): ObjectType = cleanUpReflectionObjects { val cls = tpe.dealias match { - case t if t <:< definitions.IntTpe => classOf[Array[Int]] - case t if t <:< definitions.LongTpe => classOf[Array[Long]] - case t if t <:< definitions.DoubleTpe => classOf[Array[Double]] - case t if t <:< definitions.FloatTpe => classOf[Array[Float]] - case t if t <:< definitions.ShortTpe => classOf[Array[Short]] - case t if t <:< definitions.ByteTpe => classOf[Array[Byte]] - case t if t <:< definitions.BooleanTpe => classOf[Array[Boolean]] + case t if isSubtype(t, definitions.IntTpe) => classOf[Array[Int]] + case t if isSubtype(t, definitions.LongTpe) => classOf[Array[Long]] + case t if isSubtype(t, definitions.DoubleTpe) => classOf[Array[Double]] + case t if isSubtype(t, definitions.FloatTpe) => classOf[Array[Float]] + case t if isSubtype(t, definitions.ShortTpe) => classOf[Array[Short]] + case t if isSubtype(t, definitions.ByteTpe) => classOf[Array[Byte]] + case t if isSubtype(t, definitions.BooleanTpe) => classOf[Array[Boolean]] case other => // There is probably a better way to do this, but I couldn't find it... val elementType = dataTypeFor(other).asInstanceOf[ObjectType].cls @@ -161,68 +177,68 @@ object ScalaReflection extends ScalaReflection { tpe.dealias match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => path - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = walkedTypePath.recordOption(className) WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) - case t if t <:< localTypeOf[java.lang.Integer] => + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Integer]) - case t if t <:< localTypeOf[java.lang.Long] => + case t if isSubtype(t, localTypeOf[java.lang.Long]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Long]) - case t if t <:< localTypeOf[java.lang.Double] => + case t if isSubtype(t, localTypeOf[java.lang.Double]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Double]) - case t if t <:< localTypeOf[java.lang.Float] => + case t if isSubtype(t, localTypeOf[java.lang.Float]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Float]) - case t if t <:< localTypeOf[java.lang.Short] => + case t if isSubtype(t, localTypeOf[java.lang.Short]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Short]) - case t if t <:< localTypeOf[java.lang.Byte] => + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Byte]) - case t if t <:< localTypeOf[java.lang.Boolean] => + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => createDeserializerForTypesSupportValueOf(path, classOf[java.lang.Boolean]) - case t if t <:< localTypeOf[java.time.LocalDate] => + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => createDeserializerForLocalDate(path) - case t if t <:< localTypeOf[java.sql.Date] => + case t if isSubtype(t, localTypeOf[java.sql.Date]) => createDeserializerForSqlDate(path) - case t if t <:< localTypeOf[java.time.Instant] => + case t if isSubtype(t, localTypeOf[java.time.Instant]) => createDeserializerForInstant(path) - case t if t <:< localTypeOf[java.sql.Timestamp] => + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => createDeserializerForSqlTimestamp(path) - case t if t <:< localTypeOf[java.lang.String] => + case t if isSubtype(t, localTypeOf[java.lang.String]) => createDeserializerForString(path, returnNullable = false) - case t if t <:< localTypeOf[java.math.BigDecimal] => + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => createDeserializerForJavaBigDecimal(path, returnNullable = false) - case t if t <:< localTypeOf[BigDecimal] => + case t if isSubtype(t, localTypeOf[BigDecimal]) => createDeserializerForScalaBigDecimal(path, returnNullable = false) - case t if t <:< localTypeOf[java.math.BigInteger] => + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => createDeserializerForJavaBigInteger(path, returnNullable = false) - case t if t <:< localTypeOf[scala.math.BigInt] => + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => createDeserializerForScalaBigInt(path) - case t if t <:< localTypeOf[Array[_]] => + case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -242,13 +258,13 @@ object ScalaReflection extends ScalaReflection { val arrayCls = arrayClassFor(elementType) val methodName = elementType match { - case t if t <:< definitions.IntTpe => "toIntArray" - case t if t <:< definitions.LongTpe => "toLongArray" - case t if t <:< definitions.DoubleTpe => "toDoubleArray" - case t if t <:< definitions.FloatTpe => "toFloatArray" - case t if t <:< definitions.ShortTpe => "toShortArray" - case t if t <:< definitions.ByteTpe => "toByteArray" - case t if t <:< definitions.BooleanTpe => "toBooleanArray" + case t if isSubtype(t, definitions.IntTpe) => "toIntArray" + case t if isSubtype(t, definitions.LongTpe) => "toLongArray" + case t if isSubtype(t, definitions.DoubleTpe) => "toDoubleArray" + case t if isSubtype(t, definitions.FloatTpe) => "toFloatArray" + case t if isSubtype(t, definitions.ShortTpe) => "toShortArray" + case t if isSubtype(t, definitions.ByteTpe) => "toByteArray" + case t if isSubtype(t, definitions.BooleanTpe) => "toBooleanArray" // non-primitive case _ => "array" } @@ -256,8 +272,8 @@ object ScalaReflection extends ScalaReflection { // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. - case t if t <:< localTypeOf[Seq[_]] || - t <:< localTypeOf[scala.collection.Set[_]] => + case t if isSubtype(t, localTypeOf[Seq[_]]) || + isSubtype(t, localTypeOf[scala.collection.Set[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -274,14 +290,14 @@ object ScalaReflection extends ScalaReflection { val companion = t.dealias.typeSymbol.companion.typeSignature val cls = companion.member(TermName("newBuilder")) match { - case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]] - case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] => + case NoSymbol if isSubtype(t, localTypeOf[Seq[_]]) => classOf[Seq[_]] + case NoSymbol if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } UnresolvedMapObjects(mapFunction, path, Some(cls)) - case t if t <:< localTypeOf[Map[_, _]] => + case t if isSubtype(t, localTypeOf[Map[_, _]]) => val TypeRef(_, _, Seq(keyType, valueType)) = t val classNameForKey = getClassNameFromType(keyType) @@ -411,7 +427,7 @@ object ScalaReflection extends ScalaReflection { tpe.dealias match { case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newPath = walkedTypePath.recordOption(className) @@ -421,15 +437,15 @@ object ScalaReflection extends ScalaReflection { // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the // case "localTypeOf[Seq[_]]" - case t if t <:< localTypeOf[Seq[_]] => + case t if isSubtype(t, localTypeOf[Seq[_]]) => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if t <:< localTypeOf[Array[_]] => + case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t toCatalystArray(inputObject, elementType) - case t if t <:< localTypeOf[Map[_, _]] => + case t if isSubtype(t, localTypeOf[Map[_, _]]) => val TypeRef(_, _, Seq(keyType, valueType)) = t val keyClsName = getClassNameFromType(keyType) val valueClsName = getClassNameFromType(valueType) @@ -448,7 +464,7 @@ object ScalaReflection extends ScalaReflection { serializerFor(_, valueType, valuePath, seenTypeSet)) ) - case t if t <:< localTypeOf[scala.collection.Set[_]] => + case t if isSubtype(t, localTypeOf[scala.collection.Set[_]]) => val TypeRef(_, _, Seq(elementType)) = t // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. @@ -461,35 +477,41 @@ object ScalaReflection extends ScalaReflection { toCatalystArray(newInput, elementType) - case t if t <:< localTypeOf[String] => createSerializerForString(inputObject) + case t if isSubtype(t, localTypeOf[String]) => createSerializerForString(inputObject) - case t if t <:< localTypeOf[java.time.Instant] => createSerializerForJavaInstant(inputObject) + case t if isSubtype(t, localTypeOf[java.time.Instant]) => + createSerializerForJavaInstant(inputObject) - case t if t <:< localTypeOf[java.sql.Timestamp] => + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => createSerializerForSqlTimestamp(inputObject) - case t if t <:< localTypeOf[java.time.LocalDate] => + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => createSerializerForJavaLocalDate(inputObject) - case t if t <:< localTypeOf[java.sql.Date] => createSerializerForSqlDate(inputObject) + case t if isSubtype(t, localTypeOf[java.sql.Date]) => createSerializerForSqlDate(inputObject) - case t if t <:< localTypeOf[BigDecimal] => createSerializerForScalaBigDecimal(inputObject) + case t if isSubtype(t, localTypeOf[BigDecimal]) => + createSerializerForScalaBigDecimal(inputObject) - case t if t <:< localTypeOf[java.math.BigDecimal] => + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => createSerializerForJavaBigDecimal(inputObject) - case t if t <:< localTypeOf[java.math.BigInteger] => + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => createSerializerForJavaBigInteger(inputObject) - case t if t <:< localTypeOf[scala.math.BigInt] => createSerializerForScalaBigInt(inputObject) + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => + createSerializerForScalaBigInt(inputObject) - case t if t <:< localTypeOf[java.lang.Integer] => createSerializerForInteger(inputObject) - case t if t <:< localTypeOf[java.lang.Long] => createSerializerForLong(inputObject) - case t if t <:< localTypeOf[java.lang.Double] => createSerializerForDouble(inputObject) - case t if t <:< localTypeOf[java.lang.Float] => createSerializerForFloat(inputObject) - case t if t <:< localTypeOf[java.lang.Short] => createSerializerForShort(inputObject) - case t if t <:< localTypeOf[java.lang.Byte] => createSerializerForByte(inputObject) - case t if t <:< localTypeOf[java.lang.Boolean] => createSerializerForBoolean(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => + createSerializerForInteger(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Long]) => createSerializerForLong(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Double]) => + createSerializerForDouble(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Float]) => createSerializerForFloat(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Short]) => createSerializerForShort(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => createSerializerForByte(inputObject) + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => + createSerializerForBoolean(inputObject) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t) @@ -540,7 +562,7 @@ object ScalaReflection extends ScalaReflection { */ def optionOfProductType(tpe: `Type`): Boolean = cleanUpReflectionObjects { tpe.dealias match { - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t definedByConstructorParams(optType) case _ => false @@ -606,7 +628,7 @@ object ScalaReflection extends ScalaReflection { tpe.dealias match { // this must be the first case, since all objects in scala are instances of Null, therefore // Null type would wrongly match the first of them, which is Option as of now - case t if t <:< definitions.NullTpe => Schema(NullType, nullable = true) + case t if isSubtype(t, definitions.NullTpe) => Schema(NullType, nullable = true) case t if t.typeSymbol.annotations.exists(_.tree.tpe =:= typeOf[SQLUserDefinedType]) => val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt(). getConstructor().newInstance() @@ -615,54 +637,58 @@ object ScalaReflection extends ScalaReflection { val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.getConstructor(). newInstance().asInstanceOf[UserDefinedType[_]] Schema(udt, nullable = true) - case t if t <:< localTypeOf[Option[_]] => + case t if isSubtype(t, localTypeOf[Option[_]]) => val TypeRef(_, _, Seq(optType)) = t Schema(schemaFor(optType).dataType, nullable = true) - case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true) - case t if t <:< localTypeOf[Array[_]] => + case t if isSubtype(t, localTypeOf[Array[Byte]]) => Schema(BinaryType, nullable = true) + case t if isSubtype(t, localTypeOf[Array[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Seq[_]] => + case t if isSubtype(t, localTypeOf[Seq[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[Map[_, _]] => + case t if isSubtype(t, localTypeOf[Map[_, _]]) => val TypeRef(_, _, Seq(keyType, valueType)) = t val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) - case t if t <:< localTypeOf[Set[_]] => + case t if isSubtype(t, localTypeOf[Set[_]]) => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) Schema(ArrayType(dataType, containsNull = nullable), nullable = true) - case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) - case t if t <:< localTypeOf[java.time.Instant] => Schema(TimestampType, nullable = true) - case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) - case t if t <:< localTypeOf[java.time.LocalDate] => Schema(DateType, nullable = true) - case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) - case t if t <:< localTypeOf[BigDecimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if t <:< localTypeOf[java.math.BigDecimal] => + case t if isSubtype(t, localTypeOf[String]) => Schema(StringType, nullable = true) + case t if isSubtype(t, localTypeOf[java.time.Instant]) => + Schema(TimestampType, nullable = true) + case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) => + Schema(TimestampType, nullable = true) + case t if isSubtype(t, localTypeOf[java.time.LocalDate]) => Schema(DateType, nullable = true) + case t if isSubtype(t, localTypeOf[java.sql.Date]) => Schema(DateType, nullable = true) + case t if isSubtype(t, localTypeOf[BigDecimal]) => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if t <:< localTypeOf[java.math.BigInteger] => + case t if isSubtype(t, localTypeOf[java.math.BigDecimal]) => + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if isSubtype(t, localTypeOf[java.math.BigInteger]) => Schema(DecimalType.BigIntDecimal, nullable = true) - case t if t <:< localTypeOf[scala.math.BigInt] => + case t if isSubtype(t, localTypeOf[scala.math.BigInt]) => Schema(DecimalType.BigIntDecimal, nullable = true) - case t if t <:< localTypeOf[Decimal] => Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) - case t if t <:< localTypeOf[java.lang.Integer] => Schema(IntegerType, nullable = true) - case t if t <:< localTypeOf[java.lang.Long] => Schema(LongType, nullable = true) - case t if t <:< localTypeOf[java.lang.Double] => Schema(DoubleType, nullable = true) - case t if t <:< localTypeOf[java.lang.Float] => Schema(FloatType, nullable = true) - case t if t <:< localTypeOf[java.lang.Short] => Schema(ShortType, nullable = true) - case t if t <:< localTypeOf[java.lang.Byte] => Schema(ByteType, nullable = true) - case t if t <:< localTypeOf[java.lang.Boolean] => Schema(BooleanType, nullable = true) - case t if t <:< definitions.IntTpe => Schema(IntegerType, nullable = false) - case t if t <:< definitions.LongTpe => Schema(LongType, nullable = false) - case t if t <:< definitions.DoubleTpe => Schema(DoubleType, nullable = false) - case t if t <:< definitions.FloatTpe => Schema(FloatType, nullable = false) - case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false) - case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false) - case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false) + case t if isSubtype(t, localTypeOf[Decimal]) => + Schema(DecimalType.SYSTEM_DEFAULT, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Integer]) => Schema(IntegerType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Long]) => Schema(LongType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Double]) => Schema(DoubleType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Float]) => Schema(FloatType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true) + case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true) + case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false) + case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false) + case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false) + case t if isSubtype(t, definitions.FloatTpe) => Schema(FloatType, nullable = false) + case t if isSubtype(t, definitions.ShortTpe) => Schema(ShortType, nullable = false) + case t if isSubtype(t, definitions.ByteTpe) => Schema(ByteType, nullable = false) + case t if isSubtype(t, definitions.BooleanTpe) => Schema(BooleanType, nullable = false) case t if definedByConstructorParams(t) => val params = getConstructorParameters(t) Schema(StructType( @@ -715,9 +741,9 @@ object ScalaReflection extends ScalaReflection { def definedByConstructorParams(tpe: Type): Boolean = cleanUpReflectionObjects { tpe.dealias match { // `Option` is a `Product`, but we don't wanna treat `Option[Int]` as a struct type. - case t if t <:< localTypeOf[Option[_]] => definedByConstructorParams(t.typeArgs.head) - case _ => tpe.dealias <:< localTypeOf[Product] || - tpe.dealias <:< localTypeOf[DefinedByConstructorParams] + case t if isSubtype(t, localTypeOf[Option[_]]) => definedByConstructorParams(t.typeArgs.head) + case _ => isSubtype(tpe.dealias, localTypeOf[Product]) || + isSubtype(tpe.dealias, localTypeOf[DefinedByConstructorParams]) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 80824cc2a7f21..e8df031a1a32c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -145,6 +145,12 @@ class ScalaReflectionSuite extends SparkFunSuite { private def deserializerFor[T: TypeTag]: Expression = deserializerForType(ScalaReflection.localTypeOf[T]) + test("isSubtype") { + assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[_]])) + assert(isSubtype(localTypeOf[Option[Int]], localTypeOf[Option[Int]])) + assert(!isSubtype(localTypeOf[Option[_]], localTypeOf[Option[Int]])) + } + test("SQLUserDefinedType annotation on Scala structure") { val schema = schemaFor[TestingUDT.NestedStruct] assert(schema === Schema(