diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index bf46a3986213..66a36f4c2abd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -232,19 +232,20 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.ceil()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } @@ -348,19 +349,20 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO } override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(LongType, DoubleType, DecimalType)) + Seq(TypeCollection(DoubleType, DecimalType, LongType)) protected override def nullSafeEval(input: Any): Any = child.dataType match { case LongType => input.asInstanceOf[Long] case DoubleType => f(input.asInstanceOf[Double]).toLong - case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor + case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") - case DecimalType.Fixed(precision, scale) => + case DecimalType.Fixed(_, _) => defineCodeGen(ctx, ev, c => s"$c.floor()") + case LongType => defineCodeGen(ctx, ev, c => s"$c") case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 0896caeab8d7..31047f688600 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -262,7 +262,7 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val plan = testRelation2.select('c).orderBy(Floor('a).asc) val expected = testRelation2.select(c, a) - .orderBy(Floor(Cast(a, LongType, Option(TimeZone.getDefault().getID))).asc).select(c) + .orderBy(Floor(Cast(a, DoubleType, Option(TimeZone.getDefault().getID))).asc).select(c) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 8ed7a82b943b..6af0cde73538 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -258,6 +258,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Ceil, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Ceil(doublePi), 4L, EmptyRow) + checkEvaluation(Ceil(floatPi.toDouble), 4L, EmptyRow) + checkEvaluation(Ceil(longLit), longLit, EmptyRow) + checkEvaluation(Ceil(-doublePi), -3L, EmptyRow) + checkEvaluation(Ceil(-floatPi.toDouble), -3L, EmptyRow) + checkEvaluation(Ceil(-longLit), -longLit, EmptyRow) } test("floor") { @@ -268,6 +278,16 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 3)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(25, 0)) checkConsistencyBetweenInterpretedAndCodegen(Floor, DecimalType(5, 0)) + + val doublePi: Double = 3.1415 + val floatPi: Float = 3.1415f + val longLit: Long = 12345678901234567L + checkEvaluation(Floor(doublePi), 3L, EmptyRow) + checkEvaluation(Floor(floatPi.toDouble), 3L, EmptyRow) + checkEvaluation(Floor(longLit), longLit, EmptyRow) + checkEvaluation(Floor(-doublePi), -4L, EmptyRow) + checkEvaluation(Floor(-floatPi.toDouble), -4L, EmptyRow) + checkEvaluation(Floor(-longLit), -longLit, EmptyRow) } test("factorial") { diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index f7167472b05c..7e3b86b76a34 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -64,12 +64,9 @@ select cot(-1); select ceiling(0); select ceiling(1); select ceil(1234567890123456); -select ceil(12345678901234567); select ceiling(1234567890123456); -select ceiling(12345678901234567); -- floor select floor(0); select floor(1); select floor(1234567890123456); -select floor(12345678901234567); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fe52005aa91d..28cfb744193e 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 38 +-- Number of queries: 45 -- !query 0 @@ -321,7 +321,7 @@ struct -- !query 38 select ceiling(0) -- !query 38 schema -struct +struct -- !query 38 output 0 @@ -329,7 +329,7 @@ struct -- !query 39 select ceiling(1) -- !query 39 schema -struct +struct -- !query 39 output 1 @@ -343,56 +343,32 @@ struct -- !query 41 -select ceil(12345678901234567) +select ceiling(1234567890123456) -- !query 41 schema -struct +struct -- !query 41 output -12345678901234567 +1234567890123456 -- !query 42 -select ceiling(1234567890123456) +select floor(0) -- !query 42 schema -struct +struct -- !query 42 output -1234567890123456 +0 -- !query 43 -select ceiling(12345678901234567) +select floor(1) -- !query 43 schema -struct +struct -- !query 43 output -12345678901234567 - - --- !query 44 -select floor(0) --- !query 44 schema -struct --- !query 44 output -0 - - --- !query 45 -select floor(1) --- !query 45 schema -struct --- !query 45 output 1 --- !query 46 +-- !query 44 select floor(1234567890123456) --- !query 46 schema +-- !query 44 schema struct --- !query 46 output +-- !query 44 output 1234567890123456 - - --- !query 47 -select floor(12345678901234567) --- !query 47 schema -struct --- !query 47 output -12345678901234567