diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 49303978d1ce8..d08a6382f738b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2871,20 +2871,15 @@ class Analyzer(
         case udf: ScalaUDF if udf.inputEncoders.nonEmpty =>
           val boundEncoders = udf.inputEncoders.zipWithIndex.map { case (encOpt, i) =>
             val dataType = udf.children(i).dataType
-            if (dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]])) {
-              // for UDT, we use `CatalystTypeConverters`
-              None
-            } else {
-              encOpt.map { enc =>
-                val attrs = if (enc.isSerializedAsStructForTopLevel) {
-                  dataType.asInstanceOf[StructType].toAttributes
-                } else {
-                  // the field name doesn't matter here, so we use
-                  // a simple literal to avoid any overhead
-                  new StructType().add("input", dataType).toAttributes
-                }
-                enc.resolveAndBind(attrs)
+            encOpt.map { enc =>
+              val attrs = if (enc.isSerializedAsStructForTopLevel) {
+                dataType.asInstanceOf[StructType].toAttributes
+              } else {
+                // the field name doesn't matter here, so we use
+                // a simple literal to avoid any overhead
+                new StructType().add("input", dataType).toAttributes
               }
+              enc.resolveAndBind(attrs)
             }
           }
           udf.copy(inputEncoders = boundEncoders)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 3ce284d5518a8..e27c021556377 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -93,8 +93,7 @@ object Cast {
                 toField.nullable)
         }
 
-    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass =>
-      true
+    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true
 
     case _ => false
   }
@@ -157,6 +156,8 @@ object Cast {
           resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType)
         }
 
+    case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true
+
     case _ => false
   }
 
@@ -810,8 +811,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit
         castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
       case map: MapType => castMap(from.asInstanceOf[MapType], map)
      case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
-      case udt: UserDefinedType[_]
-        if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass =>
+      case udt: UserDefinedType[_] if udt.acceptsType(from) =>
        identity[Any]
      case _: UserDefinedType[_] =>
        throw new SparkException(s"Cannot cast $from to $to.")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index 1e3e6d90b8501..3d10b084a8db1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -108,7 +108,6 @@ case class ScalaUDF(
    *   - UDF which doesn't provide inputEncoders, e.g., untyped Scala UDF and Java UDF
    *   - type which isn't supported by `ExpressionEncoder`, e.g., Any
    *   - primitive types, in order to use `identity` for better performance
-   *   - UserDefinedType which isn't fully supported by `ExpressionEncoder`
    * For other cases like case class, Option[T], we use `ExpressionEncoder` instead since
    * `CatalystTypeConverters` doesn't support these data types.
    *
@@ -121,8 +120,7 @@ case class ScalaUDF(
     val useEncoder =
       !(inputEncoders.isEmpty || // for untyped Scala UDF and Java UDF
         inputEncoders(i).isEmpty || // for types aren't supported by encoder, e.g. Any
-        inputPrimitives(i) || // for primitive types
-        dataType.existsRecursively(_.isInstanceOf[UserDefinedType[_]]))
+        inputPrimitives(i)) // for primitive types
 
     if (useEncoder) {
       val enc = inputEncoders(i).get
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index ed8ab1cb3a603..3fd5cc72cb95e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -275,11 +275,11 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
 
     // this worked already before the fix SPARK-19311:
     // return type of doUDF equals parameter type of doOtherUDF
-    sql("SELECT doOtherUDF(doUDF(41))")
+    checkAnswer(sql("SELECT doOtherUDF(doUDF(41))"), Row(41) :: Nil)
 
     // this one passes only with the fix SPARK-19311:
     // return type of doSubUDF is a subtype of the parameter type of doOtherUDF
-    sql("SELECT doOtherUDF(doSubTypeUDF(42))")
+    checkAnswer(sql("SELECT doOtherUDF(doSubTypeUDF(42))"), Row(42) :: Nil)
   }
 
   test("except on UDT") {
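
For readers of the patch, a minimal plain-Scala sketch (hypothetical names, not part of the change) of why replacing the userClass equality check with an acceptsType-style check matters: strict class equality rejects a value whose user class is a subtype of the expected user class, while a subtype-aware check accepts it, which is what the new UserDefinedTypeSuite assertions exercise.

object AcceptsTypeSketch {
  class Base
  class Sub extends Base

  // Old behaviour sketched: only an exact user-class match is accepted.
  def acceptsByEquality(expected: Class[_], actual: Class[_]): Boolean =
    expected == actual

  // New behaviour sketched: an acceptsType-style check that also admits subclasses.
  def acceptsBySubtyping(expected: Class[_], actual: Class[_]): Boolean =
    expected.isAssignableFrom(actual)

  def main(args: Array[String]): Unit = {
    println(acceptsByEquality(classOf[Base], classOf[Sub]))  // false: the old check rejected subtypes
    println(acceptsBySubtyping(classOf[Base], classOf[Sub])) // true: the subtype is now accepted
  }
}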