From 9730afbc2a303b4961474eb8d21b709a2cd7d596 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 21 Jul 2016 11:05:53 +0900
Subject: [PATCH 1/2] LEAST and GREATEST don't accept numeric arguments with
 different data types

---
 .../sql/catalyst/analysis/TypeCoercion.scala  | 41 +++++++++++++-
 .../catalyst/analysis/TypeCoercionSuite.scala |  9 ++++
 .../spark/sql/DataFrameFunctionsSuite.scala   | 54 +++++++++++++++++++
 3 files changed, 102 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 9a040f8644fb..08c539d5549a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -108,6 +108,31 @@ object TypeCoercion {
     })
   }
 
+  /**
+   * Similar to [[findTightestCommonType]], but this handles all numeric types, including
+   * fixed-precision decimals interacting with each other or with primitive types, without
+   * losing precision or scale.
+   */
+  private def findTightestCommonTypeToDecimal(left: DataType, right: DataType): Option[DataType] = {
+    findTightestCommonTypeOfTwo(left, right).orElse((left, right) match {
+      case (t1: DecimalType, t2: DecimalType) =>
+        val scale = math.max(t1.scale, t2.scale)
+        val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+        if (range + scale > 38) {
+          // DecimalType can't support precision > 38
+          None
+        } else {
+          Some(DecimalType(range + scale, scale))
+        }
+      case (t1: IntegralType, t2: DecimalType) =>
+        findTightestCommonTypeToDecimal(DecimalType.forType(t1), t2)
+      case (t1: DecimalType, t2: IntegralType) =>
+        findTightestCommonTypeToDecimal(t1, DecimalType.forType(t2))
+
+      case _ => None
+    })
+  }
+
   /**
    * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use
    * [[findTightestCommonTypeToString]] to find the TightestCommonType.
@@ -120,6 +145,18 @@ object TypeCoercion {
     })
   }
 
+  /**
+   * Similar to [[findTightestCommonType]], but finds the tightest common type of a set of types
+   * by continuously applying `findTightestCommonTypeToDecimal` on these types.
+   */
+  private def findTightestCommonTypeAndPromoteToDecimal(types: Seq[DataType]): Option[DataType] = {
+    types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+      case None => None
+      case Some(d) =>
+        findTightestCommonTypeToDecimal(d, c)
+    })
+  }
+
   /**
    * Find the tightest common type of a set of types by continuously applying
    * `findTightestCommonTypeOfTwo` on these types.
@@ -496,14 +533,14 @@ object TypeCoercion {
 
       case g @ Greatest(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findTightestCommonType(types) match {
+        findTightestCommonTypeAndPromoteToDecimal(types) match {
           case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
           case None => g
         }
 
       case l @ Least(children) if !haveSameType(children) =>
         val types = children.map(_.dataType)
-        findTightestCommonType(types) match {
+        findTightestCommonTypeAndPromoteToDecimal(types) match {
           case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
           case None => l
         }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
index 971c99b67167..2a55c467c46c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala
@@ -344,6 +344,15 @@ class TypeCoercionSuite extends PlanTest {
         :: Cast(Literal(1), DecimalType(22, 0))
         :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
         :: Nil))
+      ruleTest(TypeCoercion.FunctionArgumentConversion,
+        operator(Literal(1L)
+          :: Literal(1)
+          :: Literal(new java.math.BigDecimal("1.5"))
+          :: Nil),
+        operator(Cast(Literal(1L), DecimalType(21, 1))
+          :: Cast(Literal(1), DecimalType(21, 1))
+          :: Cast(Literal(new java.math.BigDecimal("1.5")), DecimalType(21, 1))
+          :: Nil))
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 0f6c49e75959..3efa53edf141 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -222,6 +222,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("conditional function: least - type cast") {
+    checkAnswer(
+      testData2.select(least(lit(BigDecimal("-1")), lit(0), col("a"), col("b"))).limit(1),
+      Row(BigDecimal("-1"))
+    )
+    checkAnswer(
+      sql("SELECT least(a, 1.5) as l from testData2 order by l"),
+      Seq(
+        Row(BigDecimal("1.0")),
+        Row(BigDecimal("1.0")),
+        Row(BigDecimal("1.5")),
+        Row(BigDecimal("1.5")),
+        Row(BigDecimal("1.5")),
+        Row(BigDecimal("1.5")))
+    )
+  }
+
+  test("conditional function: least - type cast failure") {
+    val message = intercept[AnalysisException] {
+      testData2.select(
+        least(lit(BigDecimal("0.000000000000000000001")), lit(0L), col("a"), col("b"))).limit(1)
+    }.message
+    assert(
+      message.contains("cannot resolve 'least(CAST(1E-21 AS DECIMAL(21,21)), 0L, `a`, `b`)'" +
+        " due to data type mismatch"))
+  }
+
   test("conditional function: greatest") {
     checkAnswer(
       testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1),
@@ -233,6 +260,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("conditional function: greatest - type cast") {
+    checkAnswer(
+      testData2.select(greatest(lit(2), lit(BigDecimal("3")), col("a"), col("b"))).limit(1),
+      Row(BigDecimal("3"))
+    )
+    checkAnswer(
+      sql("SELECT greatest(a, 2.5) as g from testData2 order by g"),
+      Seq(
+        Row(BigDecimal("2.5")),
+        Row(BigDecimal("2.5")),
+        Row(BigDecimal("2.5")),
+        Row(BigDecimal("2.5")),
+        Row(BigDecimal("3")),
+        Row(BigDecimal("3")))
+    )
+  }
+
+  test("conditional function: greatest - type cast failure") {
+    val message = intercept[AnalysisException] {
+      testData2.select(
+        greatest(lit(BigDecimal("0.000000000000000000001")), lit(0L), col("a"), col("b")))
+    }.message
+    assert(
+      message.contains("cannot resolve 'greatest(CAST(1E-21 AS DECIMAL(21,21)), 0L, `a`, `b`)'" +
+        " due to data type mismatch"))
+  }
+
   test("pmod") {
     val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
     checkAnswer(

From abbcb4e9884c7aafcb0705a7b9d0b80988d8a8c7 Mon Sep 17 00:00:00 2001
From: hyukjinkwon
Date: Thu, 21 Jul 2016 12:57:16 +0900
Subject: [PATCH 2/2] Remove the test checking exception with decimal and
 integer in least and greatest

---
 .../sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala | 1 -
 1 file changed, 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 76e42d9afa4c..3aefb3cfc333 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -216,7 +216,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
       assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
       assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
-      assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
       assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
     }
   }