diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md
index 1aea26208694..8d73e622671f 100644
--- a/docs/sql-migration-guide-upgrade.md
+++ b/docs/sql-migration-guide-upgrade.md
@@ -138,6 +138,8 @@ license: |
     need to specify a value with units like "30s" now, to avoid being interpreted as milliseconds; otherwise,
     the extremely short interval that results will likely cause applications to fail.

+  - When converting a Dataset to another Dataset, Spark up-casts the fields of the original Dataset to the types of the corresponding fields in the target Dataset. In version 2.4 and earlier, this up-cast is not very strict, e.g. `Seq("str").toDS.as[Int]` fails, but `Seq("str").toDS.as[Boolean]` works and throws an NPE during execution. In Spark 3.0, the up-cast is stricter and turning a string into something else is not allowed, i.e. `Seq("str").toDS.as[Boolean]` fails during analysis.
+
 ## Upgrading From Spark SQL 2.3 to 2.4

   - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 02d83e7e8cb6..2672583ec174 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2562,7 +2562,7 @@ class Analyzer(
         case e => e.sql
       }
       throw new AnalysisException(s"Cannot up cast $fromStr from " +
-        s"${from.dataType.catalogString} to ${to.catalogString} as it may truncate\n" +
+        s"${from.dataType.catalogString} to ${to.catalogString}.\n" +
         "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") +
         "You can either add an explicit cast to the input data or choose a higher precision " +
         "type of the field in the target object")
@@ -2575,11 +2575,15 @@ class Analyzer(
       case p => p transformExpressions {
         case u @ UpCast(child, _, _) if !child.resolved => u

-        case UpCast(child, dataType, walkedTypePath)
-          if Cast.mayTruncate(child.dataType, dataType) =>
+        case UpCast(child, dt: AtomicType, _)
+            if SQLConf.get.getConf(SQLConf.LEGACY_LOOSE_UPCAST) &&
+              child.dataType == StringType =>
+          Cast(child, dt.asNullable)
+
+        case UpCast(child, dataType, walkedTypePath) if !Cast.canUpCast(child.dataType, dataType) =>
           fail(child, dataType, walkedTypePath)

-        case UpCast(child, dataType, walkedTypePath) => Cast(child, dataType.asNullable)
+        case UpCast(child, dataType, _) => Cast(child, dataType.asNullable)
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
index 6134d54531a1..24276e11d844 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala
@@ -73,8 +73,8 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupp
       case (attr, originAttr) if !attr.semanticEquals(originAttr) =>
         // The dataType of the output attributes may be not the same with that of the view
         // output, so we should cast the attribute to the dataType of the view output attribute.
-        // Will throw an AnalysisException if the cast can't perform or might truncate.
-        if (Cast.mayTruncate(originAttr.dataType, attr.dataType)) {
+        // Will throw an AnalysisException if the cast is not an up-cast.
+        if (!Cast.canUpCast(originAttr.dataType, attr.dataType)) {
           throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " +
             s"${originAttr.dataType.catalogString} to ${attr.dataType.catalogString} as it " +
             s"may truncate\n")
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 4c5419961ee9..f8c1102953ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -120,35 +120,36 @@ object Cast {
   }

   /**
-   * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int,
-   * timestamp -> date.
+   * Returns true iff we can safely up-cast the `from` type to the `to` type without any truncation,
+   * precision loss, or possible runtime failure. For example, long -> int and string -> int are
+   * not up-casts.
    */
-  def mayTruncate(from: DataType, to: DataType): Boolean = (from, to) match {
-    case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => true
-    case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => true
-    case (from, to) if illegalNumericPrecedence(from, to) => true
-    case (TimestampType, DateType) => true
-    case (StringType, to: NumericType) => true
-    case _ => false
-  }
-
-  private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = {
-    val fromPrecedence = TypeCoercion.numericPrecedence.indexOf(from)
-    val toPrecedence = TypeCoercion.numericPrecedence.indexOf(to)
-    toPrecedence >= 0 && fromPrecedence > toPrecedence
-  }
-
-  /**
-   * Returns true iff we can safely cast the `from` type to `to` type without any truncating or
-   * precision lose, e.g. int -> long, date -> timestamp.
-   */
-  def canSafeCast(from: AtomicType, to: AtomicType): Boolean = (from, to) match {
+  def canUpCast(from: DataType, to: DataType): Boolean = (from, to) match {
     case _ if from == to => true
     case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true
     case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
-    case (from, to) if legalNumericPrecedence(from, to) => true
+    case (f, t) if legalNumericPrecedence(f, t) => true
     case (DateType, TimestampType) => true
     case (_, StringType) => true
+
+    // Spark supports casting between long and timestamp; see `longToTimestamp` and
+    // `timestampToLong` for details.
+    case (TimestampType, LongType) => true
+    case (LongType, TimestampType) => true
+
+    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
+      resolvableNullability(fn, tn) && canUpCast(fromType, toType)
+
+    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
+      resolvableNullability(fn, tn) && canUpCast(fromKey, toKey) && canUpCast(fromValue, toValue)
+
+    case (StructType(fromFields), StructType(toFields)) =>
+      fromFields.length == toFields.length &&
+        fromFields.zip(toFields).forall {
+          case (f1, f2) =>
+            resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType)
+        }
+
     case _ => false
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b7e6135513db..71c830207701 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1769,9 +1769,14 @@ object SQLConf {
     .createWithDefault(false)

   val DEFAULT_V2_CATALOG = buildConf("spark.sql.default.catalog")
-    .doc("Name of the default v2 catalog, used when a catalog is not identified in queries")
-    .stringConf
-    .createOptional
+    .doc("Name of the default v2 catalog, used when a catalog is not identified in queries")
+    .stringConf
+    .createOptional
+
+  val LEGACY_LOOSE_UPCAST = buildConf("spark.sql.legacy.looseUpcast")
+    .doc("When true, the up-cast is loose and allows casting string to atomic types.")
+    .booleanConf
+    .createWithDefault(false)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index d26a73a0d359..c987088a6238 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -443,7 +443,7 @@ object DataType {
         fieldCompatible

       case (w: AtomicType, r: AtomicType) =>
-        if (!Cast.canSafeCast(w, r)) {
+        if (!Cast.canUpCast(w, r)) {
           addError(s"Cannot safely cast '$context': $w to $r")
           false
         } else {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index dd20e6497fbb..da1b695919de 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -80,7 +80,7 @@ class EncoderResolutionSuite extends PlanTest {
     val attrs = Seq('arr.array(StringType))
     assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
       s"""
-         |Cannot up cast array element from string to bigint as it may truncate
+         |Cannot up cast array element from string to bigint.
          |The type path of the target object is:
          |- array element class: "scala.Long"
          |- field (class: "scala.Array", name: "arr")
@@ -202,7 +202,7 @@ class EncoderResolutionSuite extends PlanTest {
     }.message
     assert(msg1 ==
       s"""
-         |Cannot up cast `b` from bigint to int as it may truncate
+         |Cannot up cast `b` from bigint to int.
         |The type path of the target object is:
         |- field (class: "scala.Int", name: "b")
         |- root class: "org.apache.spark.sql.catalyst.encoders.StringIntClass"
@@ -215,7 +215,7 @@ class EncoderResolutionSuite extends PlanTest {
     }.message
     assert(msg2 ==
       s"""
-         |Cannot up cast `b`.`b` from decimal(38,18) to bigint as it may truncate
+         |Cannot up cast `b`.`b` from decimal(38,18) to bigint.
         |The type path of the target object is:
         |- field (class: "scala.Long", name: "b")
         |- field (class: "org.apache.spark.sql.catalyst.encoders.StringLongClass", name: "b")
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index f6a1d00c519c..4d667fd61ae0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -956,37 +956,50 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(ret6, "[1, [1 -> a, 2 -> b, 3 -> c]]")
   }

-  test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
-    assert(!Cast.mayTruncate(ByteType, ByteType))
-    assert(!Cast.mayTruncate(DecimalType.ByteDecimal, ByteType))
-    assert(Cast.mayTruncate(ShortType, ByteType))
-    assert(Cast.mayTruncate(IntegerType, ByteType))
-    assert(Cast.mayTruncate(LongType, ByteType))
-    assert(Cast.mayTruncate(FloatType, ByteType))
-    assert(Cast.mayTruncate(DoubleType, ByteType))
-    assert(Cast.mayTruncate(DecimalType.IntDecimal, ByteType))
-  }
-
-  test("canSafeCast and mayTruncate must be consistent for numeric types") {
-    import DataTypeTestUtils._
-
+  test("up-cast") {
     def isCastSafe(from: NumericType, to: NumericType): Boolean = (from, to) match {
       case (_, dt: DecimalType) => dt.isWiderThan(from)
       case (dt: DecimalType, _) => dt.isTighterThan(to)
       case _ => numericPrecedence.indexOf(from) <= numericPrecedence.indexOf(to)
     }

+    def makeComplexTypes(dt: NumericType, nullable: Boolean): Seq[DataType] = {
+      Seq(
+        new StructType().add("a", dt, nullable).add("b", dt, nullable),
+        ArrayType(dt, nullable),
+        MapType(dt, dt, nullable),
+        ArrayType(new StructType().add("a", dt, nullable), nullable),
+        new StructType().add("a", ArrayType(dt, nullable), nullable)
+      )
+    }
+
+    import DataTypeTestUtils.numericTypes
     numericTypes.foreach { from =>
       val (safeTargetTypes, unsafeTargetTypes) = numericTypes.partition(to => isCastSafe(from, to))

       safeTargetTypes.foreach { to =>
-        assert(Cast.canSafeCast(from, to), s"It should be possible to safely cast $from to $to")
-        assert(!Cast.mayTruncate(from, to), s"No truncation is expected when casting $from to $to")
+        assert(Cast.canUpCast(from, to), s"It should be possible to up-cast $from to $to")
+
+        // If the nullability is compatible, we can up-cast complex types too.
+        Seq(true -> true, false -> false, false -> true).foreach { case (fn, tn) =>
+          makeComplexTypes(from, fn).zip(makeComplexTypes(to, tn)).foreach {
+            case (complexFromType, complexToType) =>
+              assert(Cast.canUpCast(complexFromType, complexToType))
+          }
+        }
+
+        makeComplexTypes(from, true).zip(makeComplexTypes(to, false)).foreach {
+          case (complexFromType, complexToType) =>
+            assert(!Cast.canUpCast(complexFromType, complexToType))
+        }
       }

       unsafeTargetTypes.foreach { to =>
-        assert(!Cast.canSafeCast(from, to), s"It shouldn't be possible to safely cast $from to $to")
-        assert(Cast.mayTruncate(from, to), s"Truncation is expected when casting $from to $to")
+        assert(!Cast.canUpCast(from, to), s"It shouldn't be possible to up-cast $from to $to")
+        makeComplexTypes(from, true).zip(makeComplexTypes(to, true)).foreach {
+          case (complexFromType, complexToType) =>
+            assert(!Cast.canUpCast(complexFromType, complexToType))
+        }
       }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
index 87d1cd4d60be..9cc9894f2044 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DataTypeWriteCompatibilitySuite.scala
@@ -67,7 +67,7 @@ class DataTypeWriteCompatibilitySuite extends SparkFunSuite {
   test("Check atomic types: write allowed only when casting is safe") {
     atomicTypes.foreach { w =>
       atomicTypes.foreach { r =>
-        if (Cast.canSafeCast(w, r)) {
+        if (Cast.canUpCast(w, r)) {
           assertAllowed(w, r, "t", s"Should allow writing $w to $r because cast is safe")

         } else {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
index af5ea59429b5..18f8c5360981 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala
@@ -679,7 +679,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 {
       expr match {
         case attr: Attribute => Some(attr)
         case Cast(child @ AtomicType(), dt: AtomicType, _)
-          if Cast.canSafeCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
+          if Cast.canUpCast(child.dataType.asInstanceOf[AtomicType], dt) => unapply(child)
         case _ => None
       }
     }
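As a usage note beyond the patch itself, the sketch below illustrates how the stricter up-cast described in the migration-guide entry is expected to surface to users, and how the `spark.sql.legacy.looseUpcast` flag added above restores the 2.4 behavior. This is a hedged example, not part of the change: the `UpCastExample` object, the local-mode session, and the exact wording of the exception message are assumptions for illustration only.

```scala
// Illustrative sketch (not part of the patch), assuming a Spark 3.0 build with this change applied.
import org.apache.spark.sql.{AnalysisException, SparkSession}

object UpCastExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")          // assumption: local mode, just for the example
      .appName("upcast-example")
      .getOrCreate()
    import spark.implicits._

    // With the stricter up-cast, turning a string column into another type is rejected
    // during analysis, e.g. with a message like:
    //   Cannot up cast `value` from string to boolean.
    try {
      Seq("str").toDS.as[Boolean].collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    }

    // With the legacy flag, the 2.4 behavior is restored: analysis succeeds and any failure
    // (e.g. an NPE for non-castable values) can only happen at execution time.
    spark.conf.set("spark.sql.legacy.looseUpcast", "true")
    println(Seq("1").toDS.as[Int].collect().toSeq)  // expected: Seq(1)

    spark.stop()
  }
}
```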