From a6db5d63b3224a4b7a40fda30158641a03b10321 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Fri, 15 May 2020 16:20:22 +0800 Subject: [PATCH 01/12] fix --- .../spark/sql/catalyst/analysis/Analyzer.scala | 6 ++++++ .../scala/org/apache/spark/sql/DataFrameSuite.scala | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3505dcccbfd8e..1d91d06909aaf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3076,6 +3076,12 @@ class Analyzer( child.dataType == StringType => Cast(child, dt.asNullable) + case UpCast(child, dataType, walkedTypePath) + if child.dataType.isInstanceOf[DecimalType] + && dataType.isInstanceOf[DecimalType] + && walkedTypePath.size == 1 => + child + case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) => fail(child, dataType, walkedTypePath) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 4e91a7c7bb0f4..9cc99b79a7ae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2439,6 +2439,18 @@ class DataFrameSuite extends QueryTest val nestedDecArray = Array(decSpark) checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) } + + test("as[BigDecimal] should not lost data precision or scale") { + withTempPath { f => + sql("select 11111111111111111111111111111111111111 as d"). 
+ selectExpr("cast (d as decimal(38, 0))") + .write.mode("overwrite") + .parquet(f.getAbsolutePath) + + val df = spark.read.parquet("/tmp/foo").as[BigDecimal] + assert(df.schema === new StructType().add(StructField("d", DecimalType(38, 0)))) + } + } } case class GroupByKey(a: Int, b: Int) From 2a9c35a6afb35059b95b9aa6f81cfa7a5fcbcf6a Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Fri, 15 May 2020 17:14:58 +0800 Subject: [PATCH 02/12] add test --- .../sql/catalyst/encoders/EncoderResolutionSuite.scala | 8 +++++++- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 3 +-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 48f4ef5051fb3..c5ae9d8031fbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -22,7 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.util.GenericArrayData import org.apache.spark.sql.types._ @@ -247,6 +247,12 @@ class EncoderResolutionSuite extends PlanTest { """.stripMargin.trim + " of the field in the target object") } + test("eliminate UpCast when the output data type of the leaf node is already decimal type") { + val encoder = ExpressionEncoder[BigDecimal] + val attr = Seq(AttributeReference("a", DecimalType(38, 0))()) + testFromRow(encoder, attr, InternalRow(Decimal(0))) + } + // test for 
leaf types castSuccess[Int, Long] castSuccess[java.sql.Date, java.sql.Timestamp] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 9cc99b79a7ae3..736c9f2185b12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2442,8 +2442,7 @@ class DataFrameSuite extends QueryTest test("as[BigDecimal] should not lost data precision or scale") { withTempPath { f => - sql("select 11111111111111111111111111111111111111 as d"). - selectExpr("cast (d as decimal(38, 0))") + sql("select cast(11111111111111111111111111111111111111 as decimal(38, 0)) as d") .write.mode("overwrite") .parquet(f.getAbsolutePath) From fae0e54e0748cee031639e828f867b9ea0fc1787 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 18 May 2020 15:20:30 +0800 Subject: [PATCH 03/12] add test --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../sql/catalyst/encoders/EncoderResolutionSuite.scala | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1d91d06909aaf..0df7ac9202afd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3079,7 +3079,7 @@ class Analyzer( case UpCast(child, dataType, walkedTypePath) if child.dataType.isInstanceOf[DecimalType] && dataType.isInstanceOf[DecimalType] - && walkedTypePath.size == 1 => + && walkedTypePath.nonEmpty => child case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index c5ae9d8031fbe..ef6758bc78a3d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -248,9 +248,9 @@ class EncoderResolutionSuite extends PlanTest { } test("eliminate UpCast when the output data type of the leaf node is already decimal type") { - val encoder = ExpressionEncoder[BigDecimal] - val attr = Seq(AttributeReference("a", DecimalType(38, 0))()) - testFromRow(encoder, attr, InternalRow(Decimal(0))) + val encoder = ExpressionEncoder[Seq[BigDecimal]] + val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))()) + testFromRow(encoder, attr, InternalRow(ArrayData.toArrayData(Array(Decimal(1.0))))) } // test for leaf types From 2dc526f84d8a8a628f4183cecf335dd0ff2395e0 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 18 May 2020 21:54:46 +0800 Subject: [PATCH 04/12] update --- .../sql/catalyst/DeserializerBuildHelper.scala | 1 + .../spark/sql/catalyst/analysis/Analyzer.scala | 15 ++++++++------- .../spark/sql/catalyst/expressions/Cast.scala | 9 ++++++++- .../encoders/EncoderResolutionSuite.scala | 1 + 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index e55c25c4b0c54..f8a6706eb8f03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -161,6 +161,7 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr + case _: DecimalType => UpCast(expr, DecimalType, walkedTypePath.getPaths) case _ => UpCast(expr, expected, walkedTypePath.getPaths) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0df7ac9202afd..e4225144b7394 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3071,21 +3071,22 @@ class Analyzer( case p => p transformExpressions { case u @ UpCast(child, _, _) if !child.resolved => u - case UpCast(child, dt: AtomicType, _) + case u @ UpCast(child, _, _) if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && + u.dataType.isInstanceOf[AtomicType] && child.dataType == StringType => - Cast(child, dt.asNullable) + Cast(child, u.dataType.asNullable) - case UpCast(child, dataType, walkedTypePath) + case u @ UpCast(child, _, walkedTypePath) if child.dataType.isInstanceOf[DecimalType] - && dataType.isInstanceOf[DecimalType] + && u.dataType.isInstanceOf[DecimalType] && walkedTypePath.nonEmpty => child - case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) => - fail(child, dataType, walkedTypePath) + case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => + fail(child, u.dataType, walkedTypePath) - case UpCast(child, dataType, _) => Cast(child, dataType.asNullable) + case u @ UpCast(child, _, 
_) => Cast(child, u.dataType.asNullable) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index fa615d71a61a0..3ff90556c6087 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1735,8 +1735,15 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St /** * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. + * + * Note UpCast will be eliminated if the child's dataType is already DecimalType. */ -case class UpCast(child: Expression, dataType: DataType, walkedTypePath: Seq[String] = Nil) +case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with Unevaluable { override lazy val resolved = false + + def dataType: DataType = target match { + case DecimalType => DecimalType.SYSTEM_DEFAULT + case _ => target.asInstanceOf[DataType] + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index ef6758bc78a3d..69d722d383d06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -250,6 +250,7 @@ class EncoderResolutionSuite extends PlanTest { test("eliminate UpCast when the output data type of the leaf node is already decimal type") { val encoder = ExpressionEncoder[Seq[BigDecimal]] val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))()) + // previously, it will fail because Decimal(38, 0) can not be casted 
to Decimal(38, 18) testFromRow(encoder, attr, InternalRow(ArrayData.toArrayData(Array(Decimal(1.0))))) } From b137ec4784b7dd56502f23282fc431fbc08f66fb Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Mon, 18 May 2020 22:38:41 +0800 Subject: [PATCH 05/12] update --- .../apache/spark/sql/catalyst/analysis/Analyzer.scala | 9 +++++++++ .../sql/catalyst/encoders/EncoderResolutionSuite.scala | 2 +- .../test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index e4225144b7394..c3f07c0f795a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3081,6 +3081,15 @@ class Analyzer( if child.dataType.isInstanceOf[DecimalType] && u.dataType.isInstanceOf[DecimalType] && walkedTypePath.nonEmpty => + // SPARK-31750: there are two cases: 1. for local BigDecimal collection, + // e.g. Seq(BigDecimal(12.34), BigDecimal(1)).toDF("a"), we will have + // UpCast(child, Decimal(38, 18)) where child's data type is always Decimal(38, 18). + // 2. for other cases where data type is explicitly known, e.g, spark.read + // .parquet("/tmp/file").as[BigDecimal]. We will have UpCast(child, Decimal(38, 18)), + // where child's data type can be, e.g. Decimal(38, 0). In this case, we actually + // should not do cast otherwise there will be precision lost. + // Thus, we eliminate the UpCast here to avoid precision lost for case 2 and do + // no hurt for case 1. 
child case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 69d722d383d06..1dbb2fa86d92e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -247,7 +247,7 @@ class EncoderResolutionSuite extends PlanTest { """.stripMargin.trim + " of the field in the target object") } - test("eliminate UpCast when the output data type of the leaf node is already decimal type") { + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { val encoder = ExpressionEncoder[Seq[BigDecimal]] val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))()) // previously, it will fail because Decimal(38, 0) can not be casted to Decimal(38, 18) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 736c9f2185b12..1ad5ed88a5828 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2440,13 +2440,13 @@ class DataFrameSuite extends QueryTest checkAnswer(Seq(nestedDecArray).toDF(), Row(Array(wrapRefArray(decJava)))) } - test("as[BigDecimal] should not lost data precision or scale") { + test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { withTempPath { f => sql("select cast(11111111111111111111111111111111111111 as decimal(38, 0)) as d") .write.mode("overwrite") .parquet(f.getAbsolutePath) - val df = spark.read.parquet("/tmp/foo").as[BigDecimal] + val df = spark.read.parquet(f.getAbsolutePath).as[BigDecimal] assert(df.schema === new 
StructType().add(StructField("d", DecimalType(38, 0)))) } } From b4eb291fa3a3baabda9dd3fad4794ade0c5afe72 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 14:23:02 +0800 Subject: [PATCH 06/12] address comment --- .../sql/catalyst/DeserializerBuildHelper.scala | 5 ++++- .../spark/sql/catalyst/analysis/Analyzer.scala | 18 +++++++----------- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala index f8a6706eb8f03..701e4e3483c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/DeserializerBuildHelper.scala @@ -161,7 +161,10 @@ object DeserializerBuildHelper { case _: StructType => expr case _: ArrayType => expr case _: MapType => expr - case _: DecimalType => UpCast(expr, DecimalType, walkedTypePath.getPaths) + case _: DecimalType => + // For Scala/Java `BigDecimal`, we accept decimal types of any valid precision/scale. + // Here we use the `DecimalType` object to indicate it. 
+ UpCast(expr, DecimalType, walkedTypePath.getPaths) case _ => UpCast(expr, expected, walkedTypePath.getPaths) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c3f07c0f795a5..5c7936fdf3f0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3077,19 +3077,15 @@ class Analyzer( child.dataType == StringType => Cast(child, u.dataType.asNullable) - case u @ UpCast(child, _, walkedTypePath) + case UpCast(child, target, walkedTypePath) if child.dataType.isInstanceOf[DecimalType] - && u.dataType.isInstanceOf[DecimalType] + && target == DecimalType && walkedTypePath.nonEmpty => - // SPARK-31750: there are two cases: 1. for local BigDecimal collection, - // e.g. Seq(BigDecimal(12.34), BigDecimal(1)).toDF("a"), we will have - // UpCast(child, Decimal(38, 18)) where child's data type is always Decimal(38, 18). - // 2. for other cases where data type is explicitly known, e.g, spark.read - // .parquet("/tmp/file").as[BigDecimal]. We will have UpCast(child, Decimal(38, 18)), - // where child's data type can be, e.g. Decimal(38, 0). In this case, we actually - // should not do cast otherwise there will be precision lost. - // Thus, we eliminate the UpCast here to avoid precision lost for case 2 and do - // no hurt for case 1. + // SPARK-31750: for the case where data type is explicitly known, e.g, spark.read + // .parquet("/tmp/file").as[BigDecimal], we will have UpCast(child, Decimal(38, 18)), + // where child's data type can be, e.g. Decimal(38, 0). In this kind of case, we + // actually should not do cast otherwise it will cause precision lost. Thus, we should + // eliminate the UpCast here to avoid precision lost. 
child case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => From 8fe049068e7a52235afb79c97db4da6492a4a22a Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 15:11:38 +0800 Subject: [PATCH 07/12] update --- .../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 3ff90556c6087..59810b142bdf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1736,7 +1736,8 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. * - * Note UpCast will be eliminated if the child's dataType is already DecimalType. + * Note that UpCast will be eliminated if the child's dataType is already DecimalType and + * target is also DecimalType. 
*/ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with Unevaluable { From 6b70e778266b90bdd8ab5b3190c07168f4a12caf Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 16:13:28 +0800 Subject: [PATCH 08/12] use assert --- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5c7936fdf3f0f..04cebee246bbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3079,8 +3079,9 @@ class Analyzer( case UpCast(child, target, walkedTypePath) if child.dataType.isInstanceOf[DecimalType] - && target == DecimalType - && walkedTypePath.nonEmpty => + && target == DecimalType => + assert(walkedTypePath.nonEmpty, + "object DecimalType should only be used inside ExpressionEncoder") // SPARK-31750: for the case where data type is explicitly known, e.g, spark.read // .parquet("/tmp/file").as[BigDecimal], we will have UpCast(child, Decimal(38, 18)), // where child's data type can be, e.g. Decimal(38, 0). 
In this kind of case, we From bc0bbeca2350d47fe9531fe03c40af55e0f4ae2c Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 17:38:26 +0800 Subject: [PATCH 09/12] update --- .../sql/catalyst/analysis/Analyzer.scala | 19 +++++++++++-------- .../spark/sql/catalyst/expressions/Cast.scala | 6 ++++++ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 04cebee246bbe..5ab30b636c156 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3071,15 +3071,12 @@ class Analyzer( case p => p transformExpressions { case u @ UpCast(child, _, _) if !child.resolved => u - case u @ UpCast(child, _, _) - if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && - u.dataType.isInstanceOf[AtomicType] && - child.dataType == StringType => - Cast(child, u.dataType.asNullable) + case UpCast(_, target, _) if target != DecimalType && !target.isInstanceOf[DataType] => + throw new AnalysisException( + s"UpCast only support DecimalType as AbstractDataType yet, but got: $target") - case UpCast(child, target, walkedTypePath) - if child.dataType.isInstanceOf[DecimalType] - && target == DecimalType => + case UpCast(child, target, walkedTypePath) if target == DecimalType + && child.dataType.isInstanceOf[DecimalType] => assert(walkedTypePath.nonEmpty, "object DecimalType should only be used inside ExpressionEncoder") // SPARK-31750: for the case where data type is explicitly known, e.g, spark.read @@ -3089,6 +3086,12 @@ class Analyzer( // eliminate the UpCast here to avoid precision lost. 
child + case u @ UpCast(child, _, _) + if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && + u.dataType.isInstanceOf[AtomicType] && + child.dataType == StringType => + Cast(child, u.dataType.asNullable) + case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => fail(child, u.dataType, walkedTypePath) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 59810b142bdf7..5d9dc8bbaf4e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1738,6 +1738,12 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St * * Note that UpCast will be eliminated if the child's dataType is already DecimalType and * target is also DecimalType. + * + * Also note we use `AbstractDataType` instead of `DataType` in order to accept both object + * DecimalType and a concrete DecimalType, e.g. DecimalType(5, 2). In this way, we can still + * keep the original semantic for `UpCast`(e.g. cast to a concrete decimal type ) but take + * object DecimalType as an exception for the `ExpressionEncoder` only(any other AbstractDataType + * will fail at analysis phase yet). 
*/ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with Unevaluable { From 75f4f657ca471db43bd0a1489c4cf1703b99f1b7 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 21:24:29 +0800 Subject: [PATCH 10/12] use 1 --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1ad5ed88a5828..954a4bd9331ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2442,7 +2442,7 @@ class DataFrameSuite extends QueryTest test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { withTempPath { f => - sql("select cast(11111111111111111111111111111111111111 as decimal(38, 0)) as d") + sql("select cast(1 as decimal(38, 0)) as d") .write.mode("overwrite") .parquet(f.getAbsolutePath) From c0345d6abce2b5aa99b46555f2779ea3fcb60217 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 21:26:28 +0800 Subject: [PATCH 11/12] before --- .../spark/sql/catalyst/encoders/EncoderResolutionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 1dbb2fa86d92e..577814b9c6696 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -250,7 +250,7 @@ class EncoderResolutionSuite extends PlanTest { test("SPARK-31750: eliminate UpCast if child's dataType is DecimalType") { val encoder = 
ExpressionEncoder[Seq[BigDecimal]] val attr = Seq(AttributeReference("a", ArrayType(DecimalType(38, 0)))()) - // previously, it will fail because Decimal(38, 0) can not be casted to Decimal(38, 18) + // Before SPARK-31750, it failed because Decimal(38, 0) could not be up-cast to Decimal(38, 18) testFromRow(encoder, attr, InternalRow(ArrayData.toArrayData(Array(Decimal(1.0))))) } From e7664a11b7c6f14df0132e25316a1878792963c6 Mon Sep 17 00:00:00 2001 From: "yi.wu" Date: Tue, 19 May 2020 21:35:05 +0800 Subject: [PATCH 12/12] address comments --- .../spark/sql/catalyst/analysis/Analyzer.scala | 14 ++++++-------- .../spark/sql/catalyst/expressions/Cast.scala | 10 ++-------- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 5ab30b636c156..2f6ffd5f6b908 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -3079,18 +3079,16 @@ class Analyzer( && child.dataType.isInstanceOf[DecimalType] => assert(walkedTypePath.nonEmpty, "object DecimalType should only be used inside ExpressionEncoder") - // SPARK-31750: for the case where data type is explicitly known, e.g, spark.read - // .parquet("/tmp/file").as[BigDecimal], we will have UpCast(child, Decimal(38, 18)), - // where child's data type can be, e.g. Decimal(38, 0). In this kind of case, we - // actually should not do cast otherwise it will cause precision lost. Thus, we should - // eliminate the UpCast here to avoid precision lost. + + // SPARK-31750: if we want to upcast to the general decimal type, and the `child` is + // already decimal type, we can remove the `UpCast` and accept any precision/scale. + // This can happen for cases like `spark.read.parquet("/tmp/file").as[BigDecimal]`. 
child - case u @ UpCast(child, _, _) + case UpCast(child, target: AtomicType, _) if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) && - u.dataType.isInstanceOf[AtomicType] && child.dataType == StringType => - Cast(child, u.dataType.asNullable) + Cast(child, target.asNullable) case u @ UpCast(child, _, walkedTypePath) if !Cast.canUpCast(child.dataType, u.dataType) => fail(child, u.dataType, walkedTypePath) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 5d9dc8bbaf4e1..a56e95c1ef617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -1736,14 +1736,8 @@ case class AnsiCast(child: Expression, dataType: DataType, timeZoneId: Option[St * Cast the child expression to the target data type, but will throw error if the cast might * truncate, e.g. long -> int, timestamp -> data. * - * Note that UpCast will be eliminated if the child's dataType is already DecimalType and - * target is also DecimalType. - * - * Also note we use `AbstractDataType` instead of `DataType` in order to accept both object - * DecimalType and a concrete DecimalType, e.g. DecimalType(5, 2). In this way, we can still - * keep the original semantic for `UpCast`(e.g. cast to a concrete decimal type ) but take - * object DecimalType as an exception for the `ExpressionEncoder` only(any other AbstractDataType - * will fail at analysis phase yet). + * Note: `target` is `AbstractDataType`, so that we can put `object DecimalType`, which means + * we accept `DecimalType` with any valid precision/scale. */ case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil) extends UnaryExpression with Unevaluable {