diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 57e8fc060a291..cedf4440aabf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -204,7 +204,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { decimalVal.toBigInt } else { - BigInt(toLong) + BigInt(actualLongVal) } } @@ -212,7 +212,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { decimalVal.underlying().toBigInteger() } else { - java.math.BigInteger.valueOf(toLong) + java.math.BigInteger.valueOf(actualLongVal) } } @@ -226,7 +226,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def toString: String = toBigDecimal.toString() - def toPlainString: String = toBigDecimal.bigDecimal.toPlainString + def toPlainString: String = toJavaBigDecimal.toPlainString def toDebugString: String = { if (decimalVal.ne(null)) { @@ -240,9 +240,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { def toFloat: Float = toBigDecimal.floatValue + private def actualLongVal: Long = longVal / POW_10(_scale) + def toLong: Long = { if (decimalVal.eq(null)) { - longVal / POW_10(_scale) + actualLongVal } else { decimalVal.longValue } @@ -278,7 +280,6 @@ final class Decimal extends Ordered[Decimal] with Serializable { private def roundToNumeric[T <: AnyVal](integralType: IntegralType, maxValue: Int, minValue: Int) (f1: Long => T) (f2: Double => T): T = { if (decimalVal.eq(null)) { - val actualLongVal = longVal / POW_10(_scale) val numericVal = f1(actualLongVal) if (actualLongVal == numericVal) { numericVal @@ -303,7 +304,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { */ private[sql] def roundToLong(): Long = { if (decimalVal.eq(null)) { - longVal / POW_10(_scale) + actualLongVal } else { try { // We cannot store Long.MAX_VALUE as a Double without losing precision. @@ -455,7 +456,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { override def hashCode(): Int = toBigDecimal.hashCode() - def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 + def isZero: Boolean = if (decimalVal.ne(null)) decimalVal.signum == 0 else longVal == 0 // We should follow DecimalPrecision promote if use longVal for add and subtract: // Operation Result Precision Result Scale @@ -466,7 +467,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal + that.longVal, Math.max(precision, that.precision) + 1, scale) } else { - Decimal(toBigDecimal.bigDecimal.add(that.toBigDecimal.bigDecimal)) + Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal)) } } @@ -474,7 +475,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && that.decimalVal.eq(null) && scale == that.scale) { Decimal(longVal - that.longVal, Math.max(precision, that.precision) + 1, scale) } else { - Decimal(toBigDecimal.bigDecimal.subtract(that.toBigDecimal.bigDecimal)) + Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal)) } } @@ -504,7 +505,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this + def abs: Decimal = if (this < Decimal.ZERO) this.unary_- else this def floor: Decimal = if (scale == 0) this else { val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision @@ -532,8 +533,6 @@ object Decimal { val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO = BigDecimal(0) - private val MATH_CONTEXT = new MathContext(DecimalType.MAX_PRECISION, RoundingMode.HALF_UP) private[sql] val ZERO = Decimal(0) @@ -575,9 +574,8 @@ object Decimal { } } - private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = { - bigDecimal.precision - bigDecimal.scale - } + private def numDigitsInIntegralPart(bigDecimal: JavaBigDecimal): Int = + bigDecimal.precision - bigDecimal.scale private def stringToJavaBigDecimal(str: UTF8String): JavaBigDecimal = { // According the benchmark test, `s.toString.trim` is much faster than `s.trim.toString`.