From 9bb6947d6516255f38361cde1d7a3413da0b10d8 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 8 Sep 2022 14:26:42 +0800 Subject: [PATCH 1/7] [SPARK-40387][SQL] Improve the implementation of Spark Decimal --- .../org/apache/spark/sql/types/Decimal.scala | 132 ++++++++---------- 1 file changed, 57 insertions(+), 75 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 57e8fc060a291..afac8c4e02f8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -184,68 +184,56 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = { - if (decimalVal.ne(null)) { - decimalVal - } else { - BigDecimal(longVal, _scale) - } + def toBigDecimal: BigDecimal = if (decimalVal.eq(null)) { + BigDecimal(longVal, _scale) + } else { + decimalVal } - def toJavaBigDecimal: java.math.BigDecimal = { - if (decimalVal.ne(null)) { - decimalVal.underlying() - } else { - java.math.BigDecimal.valueOf(longVal, _scale) - } + def toJavaBigDecimal: java.math.BigDecimal = if (decimalVal.eq(null)) { + java.math.BigDecimal.valueOf(longVal, _scale) + } else { + decimalVal.underlying() } - def toScalaBigInt: BigInt = { - if (decimalVal.ne(null)) { - decimalVal.toBigInt - } else { - BigInt(toLong) - } + def toScalaBigInt: BigInt = if (decimalVal.eq(null)) { + BigInt(actualLongVal) + } else { + decimalVal.toBigInt } - def toJavaBigInteger: java.math.BigInteger = { - if (decimalVal.ne(null)) { - decimalVal.underlying().toBigInteger() - } else { - java.math.BigInteger.valueOf(toLong) - } + def toJavaBigInteger: java.math.BigInteger = if (decimalVal.eq(null)) { + java.math.BigInteger.valueOf(actualLongVal) + } else { + decimalVal.underlying().toBigInteger() } - def toUnscaledLong: Long = { - if (decimalVal.ne(null)) { - decimalVal.underlying().unscaledValue().longValueExact() - } else { - longVal - } + def toUnscaledLong: Long = if (decimalVal.eq(null)) { + longVal + } else { + decimalVal.underlying().unscaledValue().longValueExact() } override def toString: String = toBigDecimal.toString() - def toPlainString: String = toBigDecimal.bigDecimal.toPlainString + def toPlainString: String = toJavaBigDecimal.toPlainString - def toDebugString: String = { - if (decimalVal.ne(null)) { - s"Decimal(expanded, $decimalVal, $precision, $scale)" - } else { - s"Decimal(compact, $longVal, $precision, $scale)" - } + def toDebugString: String = if (decimalVal.eq(null)) { + s"Decimal(compact, $longVal, $precision, $scale)" + } else { + s"Decimal(expanded, $decimalVal, $precision, $scale)" } def toDouble: Double = toBigDecimal.doubleValue def toFloat: Float = toBigDecimal.floatValue - def toLong: Long = { - if (decimalVal.eq(null)) { - longVal / POW_10(_scale) - } else { - decimalVal.longValue - } + private def actualLongVal: Long = longVal / POW_10(_scale) + + def toLong: Long = if (decimalVal.eq(null)) { + actualLongVal + } else { + decimalVal.longValue } def toInt: Int = toLong.toInt @@ -278,7 +266,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) (f1: Long => T) (f2: Double => T): T = { if (decimalVal.eq(null)) { - val actualLongVal = longVal / POW_10(_scale) val numericVal = f1(actualLongVal) if (actualLongVal == numericVal) { numericVal @@ -303,7 +290,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ private[sql] def roundToLong(): Long = { if (decimalVal.eq(null)) { - longVal / POW_10(_scale) + actualLongVal } else { try { // We cannot store Long.MAX_VALUE as a Double without losing precision. @@ -414,20 +401,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { // In both cases, we will check whether our precision is okay below } - if (dv.ne(null)) { - // We get here if either we started with a BigDecimal, or we switched to one because we would - // have overflowed our Long; in either case we must rescale dv to the new scale. - dv = dv.setScale(scale, roundMode) - if (dv.precision > precision) { - return false - } - } else { + if (dv.eq(null)) { // We're still using Longs, but we should check whether we match the new precision val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) if (lv <= -p || lv >= p) { // Note that we shouldn't have been able to fix this by switching to BigDecimal return false } + } else { + // We get here if either we started with a BigDecimal, or we switched to one because we would + // have overflowed our Long; in either case we must rescale dv to the new scale. + dv = dv.setScale(scale, roundMode) + if (dv.precision > precision) { + return false + } } decimalVal = dv longVal = lv @@ -455,7 +442,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def hashCode(): Int = toBigDecimal.hashCode() - def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 + def isZero: Boolean = if (decimalVal.eq(null)) longVal == 0 else decimalVal.signum == 0 // We should follow DecimalPrecision promote if use longVal for add and subtract: // Operation Result Precision Result Scale @@ -487,24 +474,24 @@ final class Decimal extends Ordered[Decimal] with Serializable { DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode)) def % (that: Decimal): Decimal = - if (that.isZero) null - else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) + if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, + MATH_CONTEXT)) def quot(that: Decimal): Decimal = - if (that.isZero) null - else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, MATH_CONTEXT)) + if (that.isZero) null else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, + MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that def unary_- : Decimal = { - if (decimalVal.ne(null)) { - Decimal(-decimalVal, precision, scale) - } else { + if (decimalVal.eq(null)) { Decimal(-longVal, precision, scale) + } else { + Decimal(-decimalVal, precision, scale) } } - def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this + def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision @@ -532,8 +519,6 @@ object Decimal { val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO = BigDecimal(0) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) private[sql] val ZERO = Decimal(0) @@ -565,19 +550,16 @@ object Decimal { def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) // This is used for RowEncoder to handle Decimal inside external row. - def fromDecimal(value: Any): Decimal = { - value match { - case j: java.math.BigDecimal => apply(j) - case d: BigDecimal => apply(d) - case k: scala.math.BigInt => apply(k) - case l: java.math.BigInteger => apply(l) - case d: Decimal => d - } + def fromDecimal(value: Any): Decimal = value match { + case j: java.math.BigDecimal => apply(j) + case d: BigDecimal => apply(d) + case k: scala.math.BigInt => apply(k) + case l: java.math.BigInteger => apply(l) + case d: Decimal => d } - private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = { - bigDecimal.precision - bigDecimal.scale - } + private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = + bigDecimal.precision - bigDecimal.scale private def stringToJavaBigDecimal(str: UTF8String): JavaBigDecimal = { // According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`. From 61d748973fac5d44af81ba13dec4bcea2ff3acd4 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 8 Sep 2022 14:35:19 +0800 Subject: [PATCH 2/7] Update code --- .../src/main/scala/org/apache/spark/sql/types/Decimal.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index afac8c4e02f8f..699bbac17b579 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -453,7 +453,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal + that.longVal, Math.max(precision, that.precision) + 1, scale) } else { - Decimal(toBigDecimal.bigDecimal.add(that.toBigDecimal.bigDecimal)) + Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal)) } } @@ -461,7 +461,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal - that.longVal, Math.max(precision, that.precision) + 1, scale) } else { - Decimal(toBigDecimal.bigDecimal.subtract(that.toBigDecimal.bigDecimal)) + Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal)) } } From c5f743f5ffd79a434b463c6c8c727b3f65d8754b Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 8 Sep 2022 14:45:13 +0800 Subject: [PATCH 3/7] Update code --- .../scala/org/apache/spark/sql/types/Decimal.scala | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 699bbac17b579..557003aa72fd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -483,12 +483,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { def remainder(that: Decimal): Decimal = this % that - def unary_- : Decimal = { - if (decimalVal.eq(null)) { - Decimal(-longVal, precision, scale) - } else { - Decimal(-decimalVal, precision, scale) - } + def unary_- : Decimal = if (decimalVal.eq(null)) { + Decimal(-longVal, precision, scale) + } else { + Decimal(-decimalVal, precision, scale) } def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this From 0ae7dff32902589e72b8c4930263ed3c33ed0f2f Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 9 Sep 2022 12:28:39 +0800 Subject: [PATCH 4/7] Update code --- .../org/apache/spark/sql/types/Decimal.scala | 60 +++++++++---------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 557003aa72fd0..5b05b8a8d1ffb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -184,44 +184,44 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = if (decimalVal.eq(null)) { - BigDecimal(longVal, _scale) - } else { + def toBigDecimal: BigDecimal = if (decimalVal.ne(null)) { decimalVal + } else { + BigDecimal(longVal, _scale) } - def toJavaBigDecimal: java.math.BigDecimal = if (decimalVal.eq(null)) { - java.math.BigDecimal.valueOf(longVal, _scale) - } else { + def toJavaBigDecimal: java.math.BigDecimal = if (decimalVal.ne(null)) { decimalVal.underlying() + } else { + java.math.BigDecimal.valueOf(longVal, _scale) } - def toScalaBigInt: BigInt = if (decimalVal.eq(null)) { - BigInt(actualLongVal) - } else { + def toScalaBigInt: BigInt = if (decimalVal.ne(null)) { decimalVal.toBigInt + } else { + BigInt(actualLongVal) } - def toJavaBigInteger: java.math.BigInteger = if (decimalVal.eq(null)) { - java.math.BigInteger.valueOf(actualLongVal) - } else { + def toJavaBigInteger: java.math.BigInteger = if (decimalVal.ne(null)) { decimalVal.underlying().toBigInteger() + } else { + java.math.BigInteger.valueOf(toLong) } - def toUnscaledLong: Long = if (decimalVal.eq(null)) { - longVal - } else { + def toUnscaledLong: Long = if (decimalVal.ne(null)) { decimalVal.underlying().unscaledValue().longValueExact() + } else { + longVal } override def toString: String = toBigDecimal.toString() def toPlainString: String = toJavaBigDecimal.toPlainString - def toDebugString: String = if (decimalVal.eq(null)) { - s"Decimal(compact, $longVal, $precision, $scale)" - } else { + def toDebugString: String = if (decimalVal.ne(null)) { s"Decimal(expanded, $decimalVal, $precision, $scale)" + } else { + s"Decimal(compact, $longVal, $precision, $scale)" } def toDouble: Double = toBigDecimal.doubleValue @@ -401,20 +401,20 @@ final class Decimal extends Ordered[Decimal] with Serializable { // In both cases, we will check whether our precision is okay below } - if (dv.eq(null)) { - // We're still using Longs, but we should check whether we match the new precision - val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) - if (lv <= -p || lv >= p) { - // Note that we shouldn't have been able to fix this by switching to BigDecimal - return false - } - } else { + if (dv.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale dv to the new scale. dv = dv.setScale(scale, roundMode) if (dv.precision > precision) { return false } + } else { + // We're still using Longs, but we should check whether we match the new precision + val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) + if (lv <= -p || lv >= p) { + // Note that we shouldn't have been able to fix this by switching to BigDecimal + return false + } } decimalVal = dv longVal = lv @@ -442,7 +442,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def hashCode(): Int = toBigDecimal.hashCode() - def isZero: Boolean = if (decimalVal.eq(null)) longVal == 0 else decimalVal.signum == 0 + def isZero: Boolean = if (decimalVal.ne(null)) decimalVal.signum == 0 else longVal == 0 // We should follow DecimalPrecision promote if use longVal for add and subtract: // Operation Result Precision Result Scale @@ -483,10 +483,10 @@ final class Decimal extends Ordered[Decimal] with Serializable { def remainder(that: Decimal): Decimal = this % that - def unary_- : Decimal = if (decimalVal.eq(null)) { - Decimal(-longVal, precision, scale) - } else { + def unary_- : Decimal = if (decimalVal.ne(null)) { Decimal(-decimalVal, precision, scale) + } else { + Decimal(-longVal, precision, scale) } def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this From 0efecd54773272c67426187fa2c74297a272efad Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 9 Sep 2022 16:33:34 +0800 Subject: [PATCH 5/7] Update code --- .../org/apache/spark/sql/types/Decimal.scala | 94 +++++++++++-------- 1 file changed, 56 insertions(+), 38 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 5b05b8a8d1ffb..15d5e188de43f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -184,44 +184,56 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = if (decimalVal.ne(null)) { - decimalVal - } else { - BigDecimal(longVal, _scale) + def toBigDecimal: BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal + } else { + BigDecimal(longVal, _scale) + } } - def toJavaBigDecimal: java.math.BigDecimal = if (decimalVal.ne(null)) { - decimalVal.underlying() - } else { - java.math.BigDecimal.valueOf(longVal, _scale) + def toJavaBigDecimal: java.math.BigDecimal = { + if (decimalVal.ne(null)) { + decimalVal.underlying() + } else { + java.math.BigDecimal.valueOf(longVal, _scale) + } } - def toScalaBigInt: BigInt = if (decimalVal.ne(null)) { - decimalVal.toBigInt - } else { - BigInt(actualLongVal) + def toScalaBigInt: BigInt = { + if (decimalVal.ne(null)) { + decimalVal.toBigInt + } else { + BigInt(actualLongVal) + } } - def toJavaBigInteger: java.math.BigInteger = if (decimalVal.ne(null)) { - decimalVal.underlying().toBigInteger() - } else { - java.math.BigInteger.valueOf(toLong) + def toJavaBigInteger: java.math.BigInteger = { + if (decimalVal.ne(null)) { + decimalVal.underlying().toBigInteger() + } else { + java.math.BigInteger.valueOf(actualLongVal) + } } - def toUnscaledLong: Long = if (decimalVal.ne(null)) { - decimalVal.underlying().unscaledValue().longValueExact() - } else { - longVal + def toUnscaledLong: Long = { + if (decimalVal.ne(null)) { + decimalVal.underlying().unscaledValue().longValueExact() + } else { + longVal + } } override def toString: String = toBigDecimal.toString() def toPlainString: String = toJavaBigDecimal.toPlainString - def toDebugString: String = if (decimalVal.ne(null)) { - s"Decimal(expanded, $decimalVal, $precision, $scale)" - } else { - s"Decimal(compact, $longVal, $precision, $scale)" + def toDebugString: String = { + if (decimalVal.ne(null)) { + s"Decimal(expanded, $decimalVal, $precision, $scale)" + } else { + s"Decimal(compact, $longVal, $precision, $scale)" + } } def toDouble: Double = toBigDecimal.doubleValue @@ -230,10 +242,12 @@ final class Decimal extends Ordered[Decimal] with Serializable { private def actualLongVal: Long = longVal / POW_10(_scale) - def toLong: Long = if (decimalVal.eq(null)) { - actualLongVal - } else { - decimalVal.longValue + def toLong: Long = { + if (decimalVal.eq(null)) { + actualLongVal + } else { + decimalVal.longValue + } } def toInt: Int = toLong.toInt @@ -483,10 +497,12 @@ final class Decimal extends Ordered[Decimal] with Serializable { def remainder(that: Decimal): Decimal = this % that - def unary_- : Decimal = if (decimalVal.ne(null)) { - Decimal(-decimalVal, precision, scale) - } else { - Decimal(-longVal, precision, scale) + def unary_- : Decimal = { + if (decimalVal.ne(null)) { + Decimal(-decimalVal, precision, scale) + } else { + Decimal(-longVal, precision, scale) + } } def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this @@ -548,12 +564,14 @@ object Decimal { def apply(value: String): Decimal = new Decimal().set(BigDecimal(value)) // This is used for RowEncoder to handle Decimal inside external row. - def fromDecimal(value: Any): Decimal = value match { - case j: java.math.BigDecimal => apply(j) - case d: BigDecimal => apply(d) - case k: scala.math.BigInt => apply(k) - case l: java.math.BigInteger => apply(l) - case d: Decimal => d + def fromDecimal(value: Any): Decimal = { + value match { + case j: java.math.BigDecimal => apply(j) + case d: BigDecimal => apply(d) + case k: scala.math.BigInt => apply(k) + case l: java.math.BigInteger => apply(l) + case d: Decimal => d + } } private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = From aa60e0eb775705d2e434e71995c15967e7c52616 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Fri, 9 Sep 2022 19:14:38 +0800 Subject: [PATCH 6/7] Update code --- .../org/apache/spark/sql/types/Decimal.scala | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 15d5e188de43f..68a45b2813901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -204,7 +204,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { decimalVal.toBigInt } else { - BigInt(actualLongVal) + BigInt(rawLongValue) } } @@ -212,7 +212,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { decimalVal.underlying().toBigInteger() } else { - java.math.BigInteger.valueOf(actualLongVal) + java.math.BigInteger.valueOf(rawLongValue) } } @@ -240,11 +240,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toFloat: Float = toBigDecimal.floatValue - private def actualLongVal: Long = longVal / POW_10(_scale) + private def rawLongValue: Long = longVal / POW_10(_scale) def toLong: Long = { if (decimalVal.eq(null)) { - actualLongVal + rawLongValue } else { decimalVal.longValue } @@ -280,8 +280,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) (f1: Long => T) (f2: Double => T): T = { if (decimalVal.eq(null)) { - val numericVal = f1(actualLongVal) - if (actualLongVal == numericVal) { + val numericVal = f1(rawLongValue) + if (rawLongValue == numericVal) { numericVal } else { throw QueryExecutionErrors.castingCauseOverflowError( @@ -304,7 +304,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ private[sql] def roundToLong(): Long = { if (decimalVal.eq(null)) { - actualLongVal + rawLongValue } else { try { // We cannot store Long.MAX_VALUE as a Double without losing precision. @@ -488,12 +488,12 @@ final class Decimal extends Ordered[Decimal] with Serializable { DecimalType.MAX_SCALE, MATH_CONTEXT.getRoundingMode)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, - MATH_CONTEXT)) + if (that.isZero) null + else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal, MATH_CONTEXT)) def quot(that: Decimal): Decimal = - if (that.isZero) null else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, - MATH_CONTEXT)) + if (that.isZero) null + else Decimal(toJavaBigDecimal.divideToIntegralValue(that.toJavaBigDecimal, MATH_CONTEXT)) def remainder(that: Decimal): Decimal = this % that From 9af453d18034b95f54d4b58d2adf07edcce0eb57 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 14 Sep 2022 10:25:35 +0800 Subject: [PATCH 7/7] Update code --- .../scala/org/apache/spark/sql/types/Decimal.scala | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 68a45b2813901..cedf4440aabf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -204,7 +204,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { decimalVal.toBigInt } else { - BigInt(rawLongValue) + BigInt(actualLongVal) } } @@ -212,7 +212,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { decimalVal.underlying().toBigInteger() } else { - java.math.BigInteger.valueOf(rawLongValue) + java.math.BigInteger.valueOf(actualLongVal) } } @@ -240,11 +240,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toFloat: Float = toBigDecimal.floatValue - private def rawLongValue: Long = longVal / POW_10(_scale) + private def actualLongVal: Long = longVal / POW_10(_scale) def toLong: Long = { if (decimalVal.eq(null)) { - rawLongValue + actualLongVal } else { decimalVal.longValue } @@ -280,8 +280,8 @@ final class Decimal extends Ordered[Decimal] with Serializable { private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) (f1: Long => T) (f2: Double => T): T = { if (decimalVal.eq(null)) { - val numericVal = f1(rawLongValue) - if (rawLongValue == numericVal) { + val numericVal = f1(actualLongVal) + if (actualLongVal == numericVal) { numericVal } else { throw QueryExecutionErrors.castingCauseOverflowError( @@ -304,7 +304,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ private[sql] def roundToLong(): Long = { if (decimalVal.eq(null)) { - rawLongValue + actualLongVal } else { try { // We cannot store Long.MAX_VALUE as a Double without losing precision.