diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 100b9cfd70f5..005350350104 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -154,6 +154,15 @@ object Cast { fromPrecedence >= 0 && fromPrecedence < toPrecedence } + def canNullSafeCastToDecimal(from: DataType, to: DecimalType): Boolean = from match { + case from: BooleanType if to.isWiderThan(DecimalType.BooleanDecimal) => true + case from: NumericType if to.isWiderThan(from) => true + case from: DecimalType => + // truncating or precision lose + (to.precision - to.scale) > (from.precision - from.scale) + case _ => false // overflow + } + def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { case (NullType, _) => true case (_, _) if from == to => false @@ -169,7 +178,7 @@ object Cast { case (DateType, _) => true case (_, CalendarIntervalType) => true - case (_, _: DecimalType) => true // overflow + case (_, to: DecimalType) if !canNullSafeCastToDecimal(from, to) => true case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index f780ffd46a87..15004e4b9667 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -121,6 +121,7 @@ object DecimalType extends AbstractDataType { val MINIMUM_ADJUSTED_SCALE = 6 // The decimal types compatible with other numeric types + private[sql] val BooleanDecimal = DecimalType(1, 0) private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) private[sql] val IntDecimal = DecimalType(10, 0) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 2c6cb3ae1274..461eda4334bb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -502,7 +502,11 @@ class TypeCoercionSuite extends AnalysisTest { widenTestWithStringPromotion( ArrayType(IntegerType, containsNull = false), ArrayType(DecimalType.IntDecimal, containsNull = false), - Some(ArrayType(DecimalType.IntDecimal, containsNull = true))) + Some(ArrayType(DecimalType.IntDecimal, containsNull = false))) + widenTestWithStringPromotion( + ArrayType(DecimalType(36, 0), containsNull = false), + ArrayType(DecimalType(36, 35), containsNull = false), + Some(ArrayType(DecimalType(38, 35), containsNull = true))) // MapType widenTestWithStringPromotion( @@ -524,10 +528,18 @@ class TypeCoercionSuite extends AnalysisTest { widenTestWithStringPromotion( MapType(StringType, IntegerType, valueContainsNull = false), MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false), - Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = true))) + Some(MapType(StringType, DecimalType.IntDecimal, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(StringType, DecimalType(36, 0), valueContainsNull = false), + MapType(StringType, DecimalType(36, 35), valueContainsNull = false), + Some(MapType(StringType, DecimalType(38, 35), valueContainsNull = true))) widenTestWithStringPromotion( MapType(IntegerType, StringType, valueContainsNull = false), MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false), + Some(MapType(DecimalType.IntDecimal, StringType, valueContainsNull = false))) + widenTestWithStringPromotion( + MapType(DecimalType(36, 0), StringType, valueContainsNull = false), + MapType(DecimalType(36, 35), StringType, valueContainsNull = false), None) // StructType @@ -555,7 +567,11 @@ class TypeCoercionSuite extends AnalysisTest { widenTestWithStringPromotion( new StructType().add("num", IntegerType, nullable = false), new StructType().add("num", DecimalType.IntDecimal, nullable = false), - Some(new StructType().add("num", DecimalType.IntDecimal, nullable = true))) + Some(new StructType().add("num", DecimalType.IntDecimal, nullable = false))) + widenTestWithStringPromotion( + new StructType().add("num", DecimalType(36, 0), nullable = false), + new StructType().add("num", DecimalType(36, 35), nullable = false), + Some(new StructType().add("num", DecimalType(38, 35), nullable = true))) widenTestWithStringPromotion( new StructType().add("num", IntegerType), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 5b25bdf907c3..d9f32c000a88 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -399,21 +399,35 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } test("casting to fixed-precision decimals") { - // Overflow and rounding for casting to fixed-precision decimals: - // - Values should round with HALF_UP mode by default when you lower scale - // - Values that would overflow the target precision should turn into null - // - Because of this, casts to fixed-precision decimals should be nullable - - assert(cast(123, DecimalType.USER_DEFAULT).nullable === true) + assert(cast(123, DecimalType.USER_DEFAULT).nullable === false) assert(cast(10.03f, DecimalType.SYSTEM_DEFAULT).nullable === true) assert(cast(10.03, DecimalType.SYSTEM_DEFAULT).nullable === true) - assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === true) + assert(cast(Decimal(10.03), DecimalType.SYSTEM_DEFAULT).nullable === false) assert(cast(123, DecimalType(2, 1)).nullable === true) assert(cast(10.03f, DecimalType(2, 1)).nullable === true) assert(cast(10.03, DecimalType(2, 1)).nullable === true) assert(cast(Decimal(10.03), DecimalType(2, 1)).nullable === true) + assert(cast(123, DecimalType.IntDecimal).nullable === false) + assert(cast(10.03f, DecimalType.FloatDecimal).nullable === true) + assert(cast(10.03, DecimalType.DoubleDecimal).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 2)).nullable === false) + assert(cast(Decimal(10.03), DecimalType(5, 3)).nullable === false) + + assert(cast(Decimal(10.03), DecimalType(3, 1)).nullable === true) + assert(cast(Decimal(10.03), DecimalType(4, 1)).nullable === false) + assert(cast(Decimal(9.95), DecimalType(2, 1)).nullable === true) + assert(cast(Decimal(9.95), DecimalType(3, 1)).nullable === false) + + assert(cast(Decimal("1003"), DecimalType(3, -1)).nullable === true) + assert(cast(Decimal("1003"), DecimalType(4, -1)).nullable === false) + assert(cast(Decimal("995"), DecimalType(2, -1)).nullable === true) + assert(cast(Decimal("995"), DecimalType(3, -1)).nullable === false) + + assert(cast(true, DecimalType.SYSTEM_DEFAULT).nullable === false) + assert(cast(true, DecimalType(1, 1)).nullable === true) + checkEvaluation(cast(10.03, DecimalType.SYSTEM_DEFAULT), Decimal(10.03)) checkEvaluation(cast(10.03, DecimalType(4, 2)), Decimal(10.03)) @@ -451,6 +465,20 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(Decimal(-9.95), DecimalType(3, 1)), Decimal(-10.0)) checkEvaluation(cast(Decimal(-9.95), DecimalType(1, 0)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType.SYSTEM_DEFAULT), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(4, 0)), Decimal(1003)) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("1003"), DecimalType(1, -2)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("1003"), DecimalType(3, 0)), null) + + checkEvaluation(cast(Decimal("995"), DecimalType(3, 0)), Decimal(995)) + checkEvaluation(cast(Decimal("995"), DecimalType(3, -1)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -2)), Decimal(1000)) + checkEvaluation(cast(Decimal("995"), DecimalType(2, -1)), null) + checkEvaluation(cast(Decimal("995"), DecimalType(1, -2)), null) + checkEvaluation(cast(Double.NaN, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(1.0 / 0.0, DecimalType.SYSTEM_DEFAULT), null) checkEvaluation(cast(Float.NaN, DecimalType.SYSTEM_DEFAULT), null) @@ -460,6 +488,9 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(cast(1.0 / 0.0, DecimalType(2, 1)), null) checkEvaluation(cast(Float.NaN, DecimalType(2, 1)), null) checkEvaluation(cast(1.0f / 0.0f, DecimalType(2, 1)), null) + + checkEvaluation(cast(true, DecimalType(2, 1)), Decimal(1)) + checkEvaluation(cast(true, DecimalType(1, 1)), null) } test("cast from date") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql index db00a18f2e7e..99f46dd19d0e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/concat.sql @@ -148,6 +148,8 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql index 119f868cb48e..1727ee725db2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapZipWith.sql @@ -47,6 +47,18 @@ FROM various_maps; SELECT map_zip_with(decimal_map1, decimal_map2, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps; +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps; + SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps; diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql index fc26397b881b..69da67fc66fc 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/mapconcat.sql @@ -61,6 +61,7 @@ SELECT map_concat(tinyint_map1, smallint_map2) ts_map, map_concat(smallint_map1, int_map2) si_map, map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, map_concat(decimal_map1, float_map2) df_map, map_concat(string_map1, date_map2) std_map, map_concat(timestamp_map1, string_map2) tst_map, diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out index be637b66abc8..6c6d3110d7d0 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/concat.sql.out @@ -306,12 +306,14 @@ SELECT (tinyint_array1 || smallint_array2) ts_array, (smallint_array1 || int_array2) si_array, (int_array1 || bigint_array2) ib_array, + (bigint_array1 || decimal_array2) bd_array, + (decimal_array1 || double_array2) dd_array, (double_array1 || float_array2) df_array, (string_array1 || data_array2) std_array, (timestamp_array1 || string_array2) tst_array, (string_array1 || int_array2) sti_array FROM various_arrays -- !query 13 schema -struct,si_array:array,ib_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> +struct,si_array:array,ib_array:array,bd_array:array,dd_array:array,df_array:array,std_array:array,tst_array:array,sti_array:array> -- !query 13 output -[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] +[2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,9223372036854775808,9223372036854775809] [9.223372036854776E18,9.223372036854776E18,3.0,4.0] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"] diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out index 7f7e2f07b9e7..35740094ba53 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapZipWith.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 12 +-- Number of queries: 16 -- !query 0 @@ -89,54 +89,91 @@ cannot resolve 'map_zip_with(various_maps.`decimal_map1`, various_maps.`decimal_ -- !query 6 -SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 6 schema -struct>> +struct>> -- !query 6 output -{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} +{2:{"k":2,"v1":null,"v2":1},922337203685477897945456575809789456:{"k":922337203685477897945456575809789456,"v1":922337203685477897945456575809789456,"v2":null}} -- !query 7 -SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map1, double_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 7 schema -struct>> +struct>> -- !query 7 output -{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854779E35:{"k":9.223372036854779E35,"v1":922337203685477897945456575809789456,"v2":null}} -- !query 8 -SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map2, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 8 schema -struct>> +struct<> -- !query 8 output -{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} +org.apache.spark.sql.AnalysisException +cannot resolve 'map_zip_with(various_maps.`decimal_map2`, various_maps.`int_map`, lambdafunction(named_struct(NamePlaceholder(), `k`, NamePlaceholder(), `v1`, NamePlaceholder(), `v2`), `k`, `v1`, `v2`))' due to argument data type mismatch: The input to function map_zip_with should have been two maps with compatible key types, but the key types are [decimal(36,35), int].; line 1 pos 7 -- !query 9 -SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(decimal_map2, double_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 9 schema -struct>> +struct>> -- !query 9 output -{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} +{2.0:{"k":2.0,"v1":null,"v2":1.0},9.223372036854778:{"k":9.223372036854778,"v1":9.22337203685477897945456575809789456,"v2":null}} -- !query 10 -SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(string_map1, int_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 10 schema -struct,struct,v1:array,v2:array>>> +struct>> -- !query 10 output -{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} +{"2":{"k":"2","v1":"1","v2":1},"true":{"k":"true","v1":"false","v2":null}} -- !query 11 -SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +SELECT map_zip_with(string_map2, date_map, (k, v1, v2) -> struct(k, v1, v2)) m FROM various_maps -- !query 11 schema -struct,struct,v1:struct,v2:struct>>> +struct>> -- !query 11 output +{"2016-03-14":{"k":"2016-03-14","v1":"2016-03-13","v2":2016-03-13}} + + +-- !query 12 +SELECT map_zip_with(timestamp_map, string_map3, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 12 schema +struct>> +-- !query 12 output +{"2016-11-15 20:54:00":{"k":"2016-11-15 20:54:00","v1":2016-11-12 20:54:00.0,"v2":null},"2016-11-15 20:54:00.000":{"k":"2016-11-15 20:54:00.000","v1":null,"v2":"2016-11-12 20:54:00.000"}} + + +-- !query 13 +SELECT map_zip_with(decimal_map1, string_map4, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 13 schema +struct>> +-- !query 13 output +{"922337203685477897945456575809789456":{"k":"922337203685477897945456575809789456","v1":922337203685477897945456575809789456,"v2":"text"}} + + +-- !query 14 +SELECT map_zip_with(array_map1, array_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 14 schema +struct,struct,v1:array,v2:array>>> +-- !query 14 output +{[1,2]:{"k":[1,2],"v1":[1,2],"v2":[1,2]}} + + +-- !query 15 +SELECT map_zip_with(struct_map1, struct_map2, (k, v1, v2) -> struct(k, v1, v2)) m +FROM various_maps +-- !query 15 schema +struct,struct,v1:struct,v2:struct>>> +-- !query 15 output {{"col1":1,"col2":2}:{"k":{"col1":1,"col2":2},"v1":{"col1":1,"col2":2},"v2":{"col1":1,"col2":2}}} diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out index d352b7284ae8..efc88e47209a 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/mapconcat.sql.out @@ -76,6 +76,7 @@ SELECT map_concat(tinyint_map1, smallint_map2) ts_map, map_concat(smallint_map1, int_map2) si_map, map_concat(int_map1, bigint_map2) ib_map, + map_concat(bigint_map1, decimal_map2) bd_map, map_concat(decimal_map1, float_map2) df_map, map_concat(string_map1, date_map2) std_map, map_concat(timestamp_map1, string_map2) tst_map, @@ -83,9 +84,9 @@ SELECT map_concat(int_string_map1, tinyint_map2) istt_map FROM various_maps -- !query 2 schema -struct,si_map:map,ib_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> +struct,si_map:map,ib_map:map,bd_map:map,df_map:map,std_map:map,tst_map:map,sti_map:map,istt_map:map> -- !query 2 output -{1:2,3:4} {1:2,7:8} {4:6,8:9} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} +{1:2,3:4} {1:2,7:8} {4:6,8:9} {6:7,9223372036854775808:9223372036854775809} {3.0:4.0,9.223372036854776E18:9.223372036854776E18} {"2016-03-12":"2016-03-11","a":"b"} {"2016-11-15 20:54:00":"2016-11-12 20:54:00","c":"d"} {"7":"8","a":"b"} {1:"a",3:"4"} -- !query 3