diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 288b6358fbff..819d8487d901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -21,6 +21,7 @@ import javax.annotation.Nullable import scala.annotation.tailrec import scala.collection.mutable +import scala.math._ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ @@ -89,10 +90,10 @@ object TypeCoercion { case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) - case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => - Some(t2) - case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => - Some(t1) + case (t1: IntegralType, t2: DecimalType) => + findWiderDecimalType(DecimalType.forType(t1), t2) + case (t1: DecimalType, t2: IntegralType) => + findWiderDecimalType(t1, DecimalType.forType(t2)) // Promote numeric types to the highest of the two case (t1: NumericType, t2: NumericType) @@ -106,6 +107,22 @@ object TypeCoercion { case (t1, t2) => findTypeForComplex(t1, t2, findTightestCommonType) } + /** + * Finds a wider decimal type between the two supplied decimal types without + * any loss of precision. + */ + def findWiderDecimalType(d1: DecimalType, d2: DecimalType): Option[DecimalType] = { + val scale = max(d1.scale, d2.scale) + val range = max(d1.precision - d1.scale, d2.precision - d2.scale) + + // Check the resultant decimal type does not exceed the allowable limits. + if (range + scale <= DecimalType.MAX_PRECISION) { + Some(DecimalType(range + scale, scale)) + } else { + None + } + } + /** Promotes all the way to StringType. */ private def stringPromotion(dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match { case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 461eda4334bb..5b4aaf2659a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -366,7 +366,29 @@ class TypeCoercionSuite extends AnalysisTest { // No up-casting for fixed-precision decimal (this is handled by arithmetic rules) widenTest(DecimalType(2, 1), DecimalType(3, 2), None) widenTest(DecimalType(2, 1), DoubleType, None) - widenTest(DecimalType(2, 1), IntegerType, None) + widenTest(DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1))) + widenTest(DecimalType(2, 1), IntegerType, Some(DecimalType(11, 1))) + widenTest(DecimalType(2, 2), IntegerType, Some(DecimalType(12, 2))) + widenTest(DecimalType(4, 2), IntegerType, Some(DecimalType(12, 2))) + widenTest(DecimalType(10, 2), IntegerType, Some(DecimalType(12, 2))) + widenTest(DecimalType(38, 18), IntegerType, Some(DecimalType(38, 18))) + widenTest(DecimalType(38, 28), IntegerType, Some(DecimalType(38, 28))) + widenTest(DecimalType(38, 29), IntegerType, None) + widenTest(DecimalType(38, 38), IntegerType, None) + widenTest(DecimalType(3, -2), IntegerType, Some(DecimalType(10, 0))) + widenTest(DecimalType(11, -2), IntegerType, Some(DecimalType(13, 0))) + widenTest(DecimalType(36, -2), IntegerType, Some(DecimalType(38, 0))) + widenTest(DecimalType(37, -2), IntegerType, None) + widenTest(DecimalType(1, -38), IntegerType, None) + widenTest(DecimalType(2, 1), LongType, Some(DecimalType(21, 1))) + widenTest(DecimalType(2, 1), LongType, Some(DecimalType(21, 1))) + widenTest(DecimalType(2, 2), LongType, Some(DecimalType(22, 2))) + widenTest(DecimalType(4, 2), LongType, Some(DecimalType(22, 2))) + widenTest(DecimalType(10, 2), LongType, Some(DecimalType(22, 2))) + widenTest(DecimalType(38, 18), LongType, Some(DecimalType(38, 18))) + widenTest(DecimalType(38, 18), LongType, Some(DecimalType(38, 18))) + widenTest(DecimalType(38, 19), LongType, None) + widenTest(DecimalType(38, 38), LongType, None) widenTest(DoubleType, DecimalType(2, 1), None) // StringType