Skip to content

Commit 73e8544

Browse files
change the MAX_SCALE for decimal
1 parent 053d94f commit 73e8544

File tree

2 files changed

+36
-8
lines changed

2 files changed

+36
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ object DecimalType extends AbstractDataType {
106106
import scala.math.min
107107

108108
val MAX_PRECISION = 38
109-
val MAX_SCALE = 38
110-
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18)
109+
val MAX_SCALE = 18
110+
val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, MAX_SCALE)
111111
val USER_DEFAULT: DecimalType = DecimalType(10, 0)
112112

113113
@deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ class MyDialect extends DefaultParserDialect
3535

3636
class SQLQuerySuite extends QueryTest with SharedSQLContext {
3737
import testImplicits._
38+
import BigDecimal.RoundingMode._
3839

3940
setupTestData()
4041

@@ -1614,17 +1615,23 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
16141615
checkAnswer(sql("select 10.3000 * 3.0"), Row(BigDecimal("30.90000")))
16151616
checkAnswer(sql("select 10.30000 * 30.0"), Row(BigDecimal("309.000000")))
16161617
checkAnswer(sql("select 10.300000000000000000 * 3.000000000000000000"),
1617-
Row(BigDecimal("30.900000000000000000000000000000000000", new MathContext(38))))
1618-
checkAnswer(sql("select 10.300000000000000000 * 3.0000000000000000000"),
1619-
Row(null))
1618+
Row(BigDecimal("30.9", new MathContext(38))))
1619+
checkAnswer(sql("select 10.30000000000000000000 * 3.000000000000000000000"),
1620+
Row(BigDecimal("30.9", new MathContext(38))))
16201621

16211622
checkAnswer(sql("select 10.3 / 3.0"), Row(BigDecimal("3.433333")))
16221623
checkAnswer(sql("select 10.3000 / 3.0"), Row(BigDecimal("3.4333333")))
16231624
checkAnswer(sql("select 10.30000 / 30.0"), Row(BigDecimal("0.343333333")))
16241625
checkAnswer(sql("select 10.300000000000000000 / 3.00000000000000000"),
1625-
Row(BigDecimal("3.4333333333333333333333333333333333333", new MathContext(38))))
1626-
checkAnswer(sql("select 10.3000000000000000000 / 3.00000000000000000"),
1627-
Row(null))
1626+
Row(BigDecimal("3.4333333333333333333333333").setScale(DecimalType.MAX_SCALE, HALF_UP)))
1627+
checkAnswer(sql("select 10.300000000000000000000 / 3.0000000000000000000"),
1628+
Row(BigDecimal("3.4333333333333333333333333").setScale(DecimalType.MAX_SCALE, HALF_UP)))
1629+
checkAnswer(sql("select 1030000.0000000000000000 / 3.0000000000000000000"),
1630+
Row(BigDecimal("343333.333333333333333333").setScale(DecimalType.MAX_SCALE, HALF_UP)))
1631+
checkAnswer(sql("select 10300000000000000000 / 3.0000000000000000000"),
1632+
Row(BigDecimal("3433333333333333333.333333333333333333", new MathContext(38))))
1633+
checkAnswer(sql("select 1030000000000000000000000 / 0.1"),
1634+
Row(BigDecimal("10300000000000000000000000", new MathContext(38))))
16281635
}
16291636

16301637
test("external sorting updates peak execution memory") {
@@ -1686,4 +1693,25 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
16861693
checkAnswer(
16871694
sql("SELECT IF(a > 0, a, 0) FROM (SELECT key a FROM src) temp"), Seq(Row(1), Row(0)))
16881695
}
1696+
1697+
test("SPARK-10215 Div of Decimal returns null") {
1698+
val d = Decimal(1.12321)
1699+
val df = Seq((d, 1)).toDF("a", "b")
1700+
1701+
checkAnswer(
1702+
df.selectExpr("b * a / b"),
1703+
Seq(Row(d.toBigDecimal.setScale(DecimalType.MAX_SCALE, HALF_UP))))
1704+
checkAnswer(
1705+
df.selectExpr("b * a / b / b"),
1706+
Seq(Row(d.toBigDecimal.setScale(DecimalType.MAX_SCALE, HALF_UP))))
1707+
checkAnswer(
1708+
df.selectExpr("b * a + b"),
1709+
Seq(Row(BigDecimal(2.12321).setScale(DecimalType.MAX_SCALE, HALF_UP))))
1710+
checkAnswer(
1711+
df.selectExpr("b * a - b"),
1712+
Seq(Row(BigDecimal(0.12321).setScale(DecimalType.MAX_SCALE, HALF_UP))))
1713+
checkAnswer(
1714+
df.selectExpr("b * a * b"),
1715+
Seq(Row(d.toBigDecimal.setScale(DecimalType.MAX_SCALE, HALF_UP))))
1716+
}
16891717
}

0 commit comments

Comments
 (0)