diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 80916ee9c5379..1f1fb51addfd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -126,7 +126,15 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(decimal: BigDecimal): Decimal = { this.decimalVal = decimal this.longVal = 0L - this._precision = decimal.precision + if (decimal.precision <= decimal.scale) { + // For Decimal, we expect the precision is equal to or large than the scale, however, + // in BigDecimal, the digit count starts from the leftmost nonzero digit of the exact + // result. For example, the precision of 0.01 equals to 1 based on the definition, but + // the scale is 2. The expected precision should be 3. + this._precision = decimal.scale + 1 + } else { + this._precision = decimal.precision + } this._scale = decimal.scale this } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2624f5586fd5d..2ac11598e63d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -484,24 +484,50 @@ class TypeCoercionSuite extends PlanTest { } test("coalesce casts") { - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1.0) - :: Literal(1) - :: Literal.create(1.0, FloatType) - :: Nil), - Coalesce(Cast(Literal(1.0), DoubleType) - :: Cast(Literal(1), DoubleType) - :: Cast(Literal.create(1.0, FloatType), DoubleType) - :: Nil)) - ruleTest(TypeCoercion.FunctionArgumentConversion, - Coalesce(Literal(1L) - :: Literal(1) - :: Literal(new java.math.BigDecimal("1000000000000000000000")) - :: Nil), - Coalesce(Cast(Literal(1L), DecimalType(22, 0)) - :: Cast(Literal(1), DecimalType(22, 0)) - :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0)) - :: Nil)) + val rule = TypeCoercion.FunctionArgumentConversion + + val intLit = Literal(1) + val longLit = Literal.create(1L) + val doubleLit = Literal(1.0) + val stringLit = Literal.create("c", StringType) + val nullLit = Literal.create(null, NullType) + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) + + ruleTest(rule, + Coalesce(Seq(doubleLit, intLit, floatLit)), + Coalesce(Seq(Cast(doubleLit, DoubleType), + Cast(intLit, DoubleType), Cast(floatLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(longLit, intLit, decimalLit)), + Coalesce(Seq(Cast(longLit, DecimalType(22, 0)), + Cast(intLit, DecimalType(22, 0)), Cast(decimalLit, DecimalType(22, 0))))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit)), + Coalesce(Seq(Cast(nullLit, IntegerType), Cast(intLit, IntegerType)))) + + ruleTest(rule, + Coalesce(Seq(timestampLit, stringLit)), + Coalesce(Seq(Cast(timestampLit, StringType), Cast(stringLit, StringType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, intLit)), + Coalesce(Seq(Cast(nullLit, FloatType), Cast(floatNullLit, FloatType), + Cast(intLit, FloatType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, intLit, decimalLit, doubleLit)), + Coalesce(Seq(Cast(nullLit, DoubleType), Cast(intLit, DoubleType), + Cast(decimalLit, DoubleType), Cast(doubleLit, DoubleType)))) + + ruleTest(rule, + Coalesce(Seq(nullLit, floatNullLit, doubleLit, stringLit)), + Coalesce(Seq(Cast(nullLit, StringType), Cast(floatNullLit, StringType), + Cast(doubleLit, StringType), Cast(stringLit, StringType)))) } test("CreateArray casts") { @@ -675,6 +701,14 @@ class TypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = TypeCoercion.IfCoercion + val intLit = Literal(1) + val doubleLit = Literal(1.0) + val trueLit = Literal.create(true, BooleanType) + val falseLit = Literal.create(false, BooleanType) + val stringLit = Literal.create("c", StringType) + val floatLit = Literal.create(1.0f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal(new java.math.BigDecimal("1000000000000000000000")) ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), @@ -685,12 +719,32 @@ class TypeCoercionSuite extends PlanTest { If(Literal.create(null, BooleanType), Literal(1), Literal(1))) ruleTest(rule, - If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(AssertTrue(trueLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(trueLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(falseLit), Literal(1), Literal(2)), + If(Cast(AssertTrue(falseLit), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(trueLit, intLit, doubleLit), + If(trueLit, Cast(intLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, doubleLit), + If(trueLit, Cast(floatLit, DoubleType), doubleLit)) + + ruleTest(rule, + If(trueLit, floatLit, decimalLit), + If(trueLit, Cast(floatLit, DoubleType), Cast(decimalLit, DoubleType))) + + ruleTest(rule, + If(falseLit, stringLit, doubleLit), + If(falseLit, stringLit, Cast(doubleLit, StringType))) ruleTest(rule, - If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), - If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) + If(trueLit, timestampLit, stringLit), + If(trueLit, Cast(timestampLit, StringType), stringLit)) } test("type coercion for CaseKeyWhen") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala index 5064a1f63f83d..394c0a091e390 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NullExpressionsSuite.scala @@ -97,14 +97,30 @@ class NullExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val doubleLit = Literal.create(2.2, DoubleType) val stringLit = Literal.create("c", StringType) val nullLit = Literal.create(null, NullType) - + val floatNullLit = Literal.create(null, FloatType) + val floatLit = Literal.create(1.01f, FloatType) + val timestampLit = Literal.create("2017-04-12", TimestampType) + val decimalLit = Literal.create(10.2, DecimalType(20, 2)) + + assert(analyze(new Nvl(decimalLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, decimalLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(decimalLit, floatLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatLit, decimalLit)).dataType == DoubleType) + + assert(analyze(new Nvl(timestampLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(intLit, doubleLit)).dataType == DoubleType) assert(analyze(new Nvl(intLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(stringLit, doubleLit)).dataType == StringType) + assert(analyze(new Nvl(doubleLit, stringLit)).dataType == StringType) assert(analyze(new Nvl(nullLit, intLit)).dataType == IntegerType) assert(analyze(new Nvl(doubleLit, nullLit)).dataType == DoubleType) assert(analyze(new Nvl(nullLit, stringLit)).dataType == StringType) + + assert(analyze(new Nvl(floatLit, stringLit)).dataType == StringType) + assert(analyze(new Nvl(floatLit, doubleLit)).dataType == DoubleType) + assert(analyze(new Nvl(floatNullLit, intLit)).dataType == FloatType) } test("AtLeastNNonNulls") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 93c231e30b49b..144f3d688d402 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -32,6 +32,16 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { test("creating decimals") { checkDecimal(new Decimal(), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("0.09")), "0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("0.9")), "0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("0.90")), "0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("0.0")), "0.0", 2, 1) + checkDecimal(Decimal(BigDecimal("0")), "0", 1, 0) + checkDecimal(Decimal(BigDecimal("1.0")), "1.0", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.09")), "-0.09", 3, 2) + checkDecimal(Decimal(BigDecimal("-0.9")), "-0.9", 2, 1) + checkDecimal(Decimal(BigDecimal("-0.90")), "-0.90", 3, 2) + checkDecimal(Decimal(BigDecimal("-1.0")), "-1.0", 2, 1) checkDecimal(Decimal(BigDecimal("10.030")), "10.030", 5, 3) checkDecimal(Decimal(BigDecimal("10.030"), 4, 1), "10.0", 4, 1) checkDecimal(Decimal(BigDecimal("-9.95"), 4, 1), "-10.0", 4, 1) diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 7e3b86b76a34a..75a0256ad7239 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -65,8 +65,15 @@ select ceiling(0); select ceiling(1); select ceil(1234567890123456); select ceiling(1234567890123456); +select ceil(0.01); +select ceiling(-0.10); -- floor select floor(0); select floor(1); select floor(1234567890123456); +select floor(0.01); +select floor(-0.10); + +-- comparison operator +select 1 > 0.00001 diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 28cfb744193ec..57e8a612fab44 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 45 +-- Number of queries: 50 -- !query 0 @@ -351,24 +351,64 @@ struct -- !query 42 -select floor(0) +select ceil(0.01) -- !query 42 schema -struct +struct -- !query 42 output -0 +1 -- !query 43 -select floor(1) +select ceiling(-0.10) -- !query 43 schema -struct +struct -- !query 43 output -1 +0 -- !query 44 -select floor(1234567890123456) +select floor(0) -- !query 44 schema -struct +struct -- !query 44 output +0 + + +-- !query 45 +select floor(1) +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +select floor(1234567890123456) +-- !query 46 schema +struct +-- !query 46 output 1234567890123456 + + +-- !query 47 +select floor(0.01) +-- !query 47 schema +struct +-- !query 47 output +0 + + +-- !query 48 +select floor(-0.10) +-- !query 48 schema +struct +-- !query 48 output +-1 + + +-- !query 49 +select 1 > 0.00001 +-- !query 49 schema +struct<(CAST(1 AS BIGINT) > 0):boolean> +-- !query 49 output +true