diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index e55c25c4b0c54..701e4e3483c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -161,6 +161,10 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr + case _: DecimalType => + // For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale. + // Here we use the `DecimalType` object to indicate it. + UpCast(expr, DecimalType, walkedTypePath.getPaths) case _ => UpCast(expr, expected, walkedTypePath.getPaths) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3505dcccbfd8e..2f6ffd5f6b908 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3071,15 +3071,29 @@ class Analyzer( case p => p transformExpressions { case u @ UpCast(child, _, _) if !child.resolved => u - case UpCast(child, dt: AtomicType, _) + case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => + throw new AnalysisException( + s"UpCast only support DecimalType as AbstractDataType yet, but got: $target") + + case UpCast(child, target, walkedTypePath) if target == DecimalType + && child.dataType.isInstanceOf[DecimalType] => + assert(walkedTypePath.nonEmpty, + "object DecimalType should only be used inside ExpressionEncoder") + + // SPARK-31750: if we want to upcast to the general decimal type, and the `child` is + // already decimal type, we can remove the `Upcast` and accept any precision/scale. + // This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`. + child + + case UpCast(child, target: AtomicType, _) if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && child.dataType == StringType => - Cast(child, dt.asNullable) + Cast(child, target.asNullable) - case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) => - fail(child, dataType, walkedTypePath) + case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => + fail(child, u.dataType, walkedTypePath) - case UpCast(child, dataType, _) => Cast(child, dataType.asNullable) + case u @ UpCast(child, _, _) => Cast(child, u.dataType.asNullable) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index fa615d71a61a0..a56e95c1ef617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1735,8 +1735,16 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St /** * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. + * + * Note: `target` is `AbstractDataType`, so that we can put `object DecimalType`, which means + * we accept `DecimalType` with any valid precision/scale. */ -case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil) +case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with Unevaluable { override lazy val resolved = false + + def dataType: DataType = target match { + case DecimalType => DecimalType.SYSTEM_DEFAULT + case _ => target.asInstanceOf[DataType] + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 48f4ef5051fb3..577814b9c6696 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -22,9 +22,9 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -247,6 +247,13 @@ class EncoderResolutionSuite extends PlanTest { """.stripMargin.trim + " of the field in the target object") } + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { + val encoder = ExpressionEncoder[Seq[BigDecimal]] + val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))()) + // Before SPARK-31750, it will fail because Decimal(38, 0) can not be casted to Decimal(38, 18) + testFromRow(encoder, attr, InternalRow(ArrayData.toArrayData(Array(Decimal(1.0))))) + } + // test for leaf types castSuccess[Int, Long] castSuccess[java.sql.Date, java.sql.Timestamp] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4e91a7c7bb0f4..954a4bd9331ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2439,6 +2439,17 @@ class DataFrameSuite extends QueryTest val nestedDecArray = Array(decSpark) checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) } + + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { + withTempPath { f => + sql("select cast(1 as decimal(38, 0)) as d") + .write.mode("overwrite") + .parquet(f.getAbsolutePath) + + val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal] + assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0)))) + } + } } case class GroupByKey(a: Int, b: Int)