From 46f34955cd076c376b43bb262be1cff9754c8285 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 27 May 2016 15:35:37 -0700 Subject: [PATCH] SPARK-15557] expression ((cast(99 as decimal) + '3') * '2.3' ) return null --- .../sql/catalyst/analysis/TypeCoercion.scala | 5 ----- .../sql/catalyst/analysis/AnalysisSuite.scala | 3 +-- .../datasources/json/JsonSuite.scala | 4 ++-- .../sql/hive/execution/SQLQuerySuite.scala | 19 +++++++++++++++++++ 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 387e5552549e..a5b5b91e4ab3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -290,11 +290,6 @@ object TypeCoercion { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - case a @ BinaryArithmetic(left @ StringType(), right @ DecimalType.Expression(_, _)) => - a.makeCopy(Array(Cast(left, DecimalType.SYSTEM_DEFAULT), right)) - case a @ BinaryArithmetic(left @ DecimalType.Expression(_, _), right @ StringType()) => - a.makeCopy(Array(left, Cast(right, DecimalType.SYSTEM_DEFAULT))) - case a @ BinaryArithmetic(left @ StringType(), right) => a.makeCopy(Array(Cast(left, DoubleType), right)) case a @ BinaryArithmetic(left, right @ StringType()) => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index a63d1770f325..77ea29ead92c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -182,8 +182,7 @@ class AnalysisSuite extends AnalysisTest { assert(pl(0).dataType == DoubleType) assert(pl(1).dataType == DoubleType) assert(pl(2).dataType == DoubleType) - // StringType will be promoted into Decimal(38, 18) - assert(pl(3).dataType == DecimalType(38, 22)) + assert(pl(3).dataType == DoubleType) assert(pl(4).dataType == DoubleType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 500d8ff55a9a..9f35c02d4876 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -446,13 +446,13 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str > 14"), - Row(BigDecimal("92233720368547758071.2")) + Row(92233720368547758071.2) ) // Number and String conflict: resolve the type as number in this query. checkAnswer( sql("select num_str + 1.2 from jsonTable where num_str >= 92233720368547758060"), - Row(new java.math.BigDecimal("92233720368547758071.2")) + Row(new java.math.BigDecimal("92233720368547758071.2").doubleValue) ) // String and Boolean conflict: resolve the type as string. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 4b51f021bfa0..2a9b06b75efa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1560,4 +1560,23 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SELECT * FROM tbl"), Row(1, "a")) } } + + test("spark-15557 promote string test") { + withTable("tbl") { + sql("CREATE TABLE tbl(c1 string, c2 string)") + sql("insert into tbl values ('3', '2.3')") + checkAnswer( + sql("select (cast (99 as decimal(19,6)) + cast('3' as decimal)) * cast('2.3' as decimal)"), + Row(204.0) + ) + checkAnswer( + sql("select (cast(99 as decimal(19,6)) + '3') *'2.3' from tbl"), + Row(234.6) + ) + checkAnswer( + sql("select (cast(99 as decimal(19,6)) + c1) * c2 from tbl"), + Row(234.6) + ) + } + } }