From b5365e28bf40448bb3cbd59668f316c7e5a3809a Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 14 Sep 2018 16:23:50 +0800 Subject: [PATCH 1/6] Support truncate number --- .../catalyst/analysis/FunctionRegistry.scala | 1 + .../expressions/mathExpressions.scala | 64 ++++++++++++++++++- .../expressions/MathExpressionsSuite.scala | 27 ++++++++ .../org/apache/spark/sql/functions.scala | 18 ++++++ .../resources/sql-tests/inputs/operators.sql | 6 ++ .../sql-tests/results/operators.sql.out | 34 +++++++++- 6 files changed, 148 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 77860e1584f4..4a2c596cbcac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -262,6 +262,7 @@ object FunctionRegistry { expression[Tan]("tan"), expression[Cot]("cot"), expression[Tanh]("tanh"), + expression[Truncate]("truncate"), expression[Add]("+"), expression[Subtract]("-"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c2e1720259b5..a63f7abe095c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.NumberConverter +import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1245,3 +1245,65 @@ case class BRound(child: Expression, scale: Expression) with Serializable with ImplicitCastInputTypes { def this(child: Expression) = this(child, Literal(0)) } + +/** + * The number truncated to scale decimal places. + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(number, scale) - Returns number truncated to scale decimal places. " + + "If scale is omitted, then number is truncated to 0 places. " + + "scale can be negative to truncate (make zero) scale digits left of the decimal point.", + examples = """ + Examples: + > SELECT _FUNC_(1234567891.1234567891, 4); + 1234567891.1234 + > SELECT _FUNC_(1234567891.1234567891, -4); + 1234560000 + > SELECT _FUNC_(1234567891.1234567891); + 1234567891 + """) +// scalastyle:on line.size.limit +case class Truncate(number: Expression, scale: Expression) + extends BinaryExpression with ImplicitCastInputTypes { + + override def left: Expression = number + override def right: Expression = scale + + override def inputTypes: Seq[AbstractDataType] = + Seq(TypeCollection(DoubleType, DecimalType), IntegerType) + + override def dataType: DataType = left.dataType + + private lazy val foldableTruncScale: Int = scale.eval().asInstanceOf[Int] + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val truncScale = if (scale.foldable) { + foldableTruncScale + } else { + scale.eval().asInstanceOf[Int] + } + number.dataType match { + case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], truncScale) + case DecimalType.Fixed(_, _) => + MathUtils.trunc(input1.asInstanceOf[Decimal].toJavaBigDecimal, truncScale) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val mu = MathUtils.getClass.getName.stripSuffix("$") + if (scale.foldable) { + val d = number.genCode(ctx) + ev.copy(code = code""" + ${d.code} + boolean ${ev.isNull} = ${d.isNull}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = $mu.trunc(${d.value}, $foldableTruncScale); + }""") + } else { + nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) => + s"${ev.value} = $mu.trunc($doubleVal, $truncParam);") + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 3a094079380f..d2f2bb8e65c4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -644,4 +644,31 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(BRound(-0.35, 1), -0.4) checkEvaluation(BRound(-35, -1), -40) } + + test("Truncate number") { + def testTruncate(input: Double, fmt: Int, expected: Double): Unit = { + checkEvaluation(Truncate(Literal.create(input, DoubleType), + Literal.create(fmt, IntegerType)), + expected) + checkEvaluation(Truncate(Literal.create(input, DoubleType), + NonFoldableLiteral.create(fmt, IntegerType)), + expected) + } + + testTruncate(1234567891.1234567891, 4, 1234567891.1234) + testTruncate(1234567891.1234567891, -4, 1234560000) + testTruncate(1234567891.1234567891, 0, 1234567891) + testTruncate(0.123, -1, 0) + testTruncate(0.123, 0, 0) + + checkEvaluation(Truncate(Literal.create(1D, DoubleType), + NonFoldableLiteral.create(null, IntegerType)), + null) + checkEvaluation(Truncate(Literal.create(null, DoubleType), + NonFoldableLiteral.create(1, IntegerType)), + null) + checkEvaluation(Truncate(Literal.create(null, DoubleType), + NonFoldableLiteral.create(null, IntegerType)), + null) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 10b67d7a1ca5..521d5710a959 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2214,6 +2214,24 @@ object functions { */ def radians(columnName: String): Column = radians(Column(columnName)) + /** + * Returns number truncated to the unit specified by the scale. + * + * For example, `truncate(1234567891.1234567891, 4)` returns 1234567891.1234 + * + * @param number The number to be truncated + * @param scale: A scale used to truncate number + * + * @return The number truncated to scale decimal places. + * If scale is omitted, then number is truncated to 0 places. + * scale can be negative to truncate (make zero) scale digits left of the decimal point. + * @group math_funcs + * @since 2.4.0 + */ + def truncate(number: Column, scale: Int): Column = withExpr { + Truncate(number.expr, Literal(scale)) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Misc functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 15d981985c55..70b43c48c45d 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -96,3 +96,9 @@ select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); -- pmod select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null); select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)); + +-- truncate +select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4); +select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4); +select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); +select truncate(cast(1234567891.1234567891 as long), 9.03) diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 840655b7a644..3464d7e0f84a 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 59 +-- Number of queries: 63 -- !query 0 @@ -484,3 +484,35 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint) struct -- !query 58 output NULL NULL + + +-- !query 59 +select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4) +-- !query 59 schema +struct +-- !query 59 output +1234560000 1234567891 1234567891.1234 + + +-- !query 60 +select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4) +-- !query 60 schema +struct +-- !query 60 output +1234560000 1234567891 1234567891 + + +-- !query 61 +select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4) +-- !query 61 schema +struct +-- !query 61 output +1.23456E9 1.234567891E9 1.234567891E9 + + +-- !query 62 +select truncate(cast(1234567891.1234567891 as long), 9.03) +-- !query 62 schema +struct +-- !query 62 output +1.234567891E9 From bf7103a6f68119db81b1d89561f9d1430d8a9084 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Fri, 14 Sep 2018 16:57:41 +0800 Subject: [PATCH 2/6] Add MathUtils --- .../spark/sql/catalyst/util/MathUtils.scala | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala new file mode 100644 index 000000000000..912d99d957f8 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.util + +import java.math.{BigDecimal => JBigDecimal} + +object MathUtils { + + /** + * Returns double type input truncated to scale decimal places. + */ + def trunc(input: Double, scale: Int): Double = { + trunc(JBigDecimal.valueOf(input), scale).doubleValue() + } + + /** + * Returns BigDecimal type input truncated to scale decimal places. + */ + def trunc(input: JBigDecimal, scale: Int): JBigDecimal = { + // Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0 + // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487) + val pow = if (scale >= 0) { + JBigDecimal.valueOf(Math.pow(10, scale)) + } else { + JBigDecimal.valueOf(Math.pow(10, Math.abs(scale))) + } + + if (scale > 0) { + val longValue = input.multiply(pow).longValue() + JBigDecimal.valueOf(longValue).divide(pow) + } else if (scale == 0) { + JBigDecimal.valueOf(input.longValue()) + } else { + val longValue = input.divide(pow).longValue() + JBigDecimal.valueOf(longValue).multiply(pow) + } + } +} From c7156943a2a32ba57e67aa6d8fa7035a09847e07 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 19 Sep 2018 18:13:22 +0800 Subject: [PATCH 3/6] Add float type. --- .../expressions/mathExpressions.scala | 47 +++++++++----- .../spark/sql/catalyst/util/MathUtils.scala | 24 ++++++- .../expressions/MathExpressionsSuite.scala | 63 +++++++++++++------ .../org/apache/spark/sql/functions.scala | 22 ++++--- .../resources/sql-tests/inputs/operators.sql | 3 +- .../sql-tests/results/operators.sql.out | 10 ++- 6 files changed, 119 insertions(+), 50 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index a63f7abe095c..32fa036747e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -1267,43 +1267,58 @@ case class BRound(child: Expression, scale: Expression) case class Truncate(number: Expression, scale: Expression) extends BinaryExpression with ImplicitCastInputTypes { + def this(number: Expression) = this(number, Literal(0)) + override def left: Expression = number override def right: Expression = scale override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, DecimalType), IntegerType) + Seq(TypeCollection(DoubleType, FloatType, DecimalType), IntegerType) + + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckSuccess => + if (scale.foldable) { + TypeCheckSuccess + } else { + TypeCheckFailure("Only foldable Expression is allowed for scale arguments") + } + case f => f + } + } override def dataType: DataType = left.dataType + override def nullable: Boolean = true + override def prettyName: String = "truncate" - private lazy val foldableTruncScale: Int = scale.eval().asInstanceOf[Int] + private lazy val scaleV: Any = scale.eval(EmptyRow) + private lazy val _scale: Int = scaleV.asInstanceOf[Int] protected override def nullSafeEval(input1: Any, input2: Any): Any = { - val truncScale = if (scale.foldable) { - foldableTruncScale - } else { - scale.eval().asInstanceOf[Int] - } number.dataType match { - case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], truncScale) - case DecimalType.Fixed(_, _) => - MathUtils.trunc(input1.asInstanceOf[Decimal].toJavaBigDecimal, truncScale) + case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], _scale) + case FloatType => MathUtils.trunc(input1.asInstanceOf[Float], _scale) + case DecimalType.Fixed(_, _) => MathUtils.trunc(input1.asInstanceOf[Decimal], _scale) } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val mu = MathUtils.getClass.getName.stripSuffix("$") - if (scale.foldable) { + + val javaType = CodeGenerator.javaType(dataType) + if (scaleV == null) { // if scale is null, no need to eval its child at all + ev.copy(code = code""" + boolean ${ev.isNull} = true; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") + } else { val d = number.genCode(ctx) ev.copy(code = code""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; - ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $mu.trunc(${d.value}, $foldableTruncScale); + ${ev.value} = $mu.trunc(${d.value}, ${_scale}); }""") - } else { - nullSafeCodeGen(ctx, ev, (doubleVal, truncParam) => - s"${ev.value} = $mu.trunc($doubleVal, $truncParam);") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala index 912d99d957f8..9f00af128666 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala @@ -18,19 +18,35 @@ package org.apache.spark.sql.catalyst.util import java.math.{BigDecimal => JBigDecimal} +import org.apache.spark.sql.types.Decimal + object MathUtils { /** * Returns double type input truncated to scale decimal places. */ def trunc(input: Double, scale: Int): Double = { - trunc(JBigDecimal.valueOf(input), scale).doubleValue() + trunc(JBigDecimal.valueOf(input), scale).toDouble + } + + /** + * Returns float type input truncated to scale decimal places. + */ + def trunc(input: Float, scale: Int): Float = { + trunc(JBigDecimal.valueOf(input), scale).toFloat + } + + /** + * Returns decimal type input truncated to scale decimal places. + */ + def trunc(input: Decimal, scale: Int): Decimal = { + trunc(input.toJavaBigDecimal, scale) } /** * Returns BigDecimal type input truncated to scale decimal places. */ - def trunc(input: JBigDecimal, scale: Int): JBigDecimal = { + def trunc(input: JBigDecimal, scale: Int): Decimal = { // Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0 // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487) val pow = if (scale >= 0) { @@ -39,7 +55,7 @@ object MathUtils { JBigDecimal.valueOf(Math.pow(10, Math.abs(scale))) } - if (scale > 0) { + val truncatedValue = if (scale > 0) { val longValue = input.multiply(pow).longValue() JBigDecimal.valueOf(longValue).divide(pow) } else if (scale == 0) { @@ -48,5 +64,7 @@ object MathUtils { val longValue = input.divide(pow).longValue() JBigDecimal.valueOf(longValue).multiply(pow) } + + Decimal(truncatedValue) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index d2f2bb8e65c4..320a2e821d8a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -646,29 +646,54 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } test("Truncate number") { - def testTruncate(input: Double, fmt: Int, expected: Double): Unit = { + assert(Truncate(Literal.create(123.123, DoubleType), + NonFoldableLiteral.create(1, IntegerType)).checkInputDataTypes().isFailure) + assert(Truncate(Literal.create(123.123, DoubleType), + Literal.create(1, IntegerType)).checkInputDataTypes().isSuccess) + + def testDouble(input: Any, scale: Any, expected: Any): Unit = { checkEvaluation(Truncate(Literal.create(input, DoubleType), - Literal.create(fmt, IntegerType)), + Literal.create(scale, IntegerType)), expected) - checkEvaluation(Truncate(Literal.create(input, DoubleType), - NonFoldableLiteral.create(fmt, IntegerType)), + } + + def testFloat(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, FloatType), + Literal.create(scale, IntegerType)), + expected) + } + + def testDecimal(input: Any, scale: Any, expected: Any): Unit = { + checkEvaluation(Truncate(Literal.create(input, DecimalType.DoubleDecimal), + Literal.create(scale, IntegerType)), expected) } - testTruncate(1234567891.1234567891, 4, 1234567891.1234) - testTruncate(1234567891.1234567891, -4, 1234560000) - testTruncate(1234567891.1234567891, 0, 1234567891) - testTruncate(0.123, -1, 0) - testTruncate(0.123, 0, 0) - - checkEvaluation(Truncate(Literal.create(1D, DoubleType), - NonFoldableLiteral.create(null, IntegerType)), - null) - checkEvaluation(Truncate(Literal.create(null, DoubleType), - NonFoldableLiteral.create(1, IntegerType)), - null) - checkEvaluation(Truncate(Literal.create(null, DoubleType), - NonFoldableLiteral.create(null, IntegerType)), - null) + testDouble(1234567891.1234567891D, 4, 1234567891.1234D) + testDouble(1234567891.1234567891D, -4, 1234560000D) + testDouble(1234567891.1234567891D, 0, 1234567891D) + testDouble(0.123D, -1, 0D) + testDouble(0.123D, 0, 0D) + testDouble(null, null, null) + testDouble(null, 0, null) + testDouble(1D, null, null) + + testFloat(1234567891.1234567891F, 4, 1234567891.1234F) + testFloat(1234567891.1234567891F, -4, 1234560000F) + testFloat(1234567891.1234567891F, 0, 1234567891F) + testFloat(0.123F, -1, 0F) + testFloat(0.123F, 0, 0F) + testFloat(null, null, null) + testFloat(null, 0, null) + testFloat(1D, null, null) + + testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234)) + testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000)) + testDecimal(Decimal(1234567891.1234567891), 0, Decimal(1234567891)) + testDecimal(Decimal(0.123), -1, Decimal(0)) + testDecimal(Decimal(0.123), 0, Decimal(0)) + testDecimal(null, null, null) + testDecimal(null, 0, null) + testDecimal(1D, null, null) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 521d5710a959..6e2df69d5d87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2215,21 +2215,23 @@ object functions { def radians(columnName: String): Column = radians(Column(columnName)) /** - * Returns number truncated to the unit specified by the scale. + * Returns the value of the column `e` truncated to 0 places. * - * For example, `truncate(1234567891.1234567891, 4)` returns 1234567891.1234 - * - * @param number The number to be truncated - * @param scale: A scale used to truncate number + * @group math_funcs + * @since 2.4.0 + */ + def truncate(e: Column): Column = truncate(e, 0) + + /** + * Returns the value of column `e` truncated to the unit specified by the scale. + * If scale is omitted, then the value of column `e` is truncated to 0 places. + * Scale can be negative to truncate (make zero) scale digits left of the decimal point. * - * @return The number truncated to scale decimal places. - * If scale is omitted, then number is truncated to 0 places. - * scale can be negative to truncate (make zero) scale digits left of the decimal point. * @group math_funcs * @since 2.4.0 */ - def truncate(number: Column, scale: Int): Column = withExpr { - Truncate(number.expr, Literal(scale)) + def truncate(e: Column, scale: Int): Column = withExpr { + Truncate(e.expr, Literal(scale)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 70b43c48c45d..39efa356a502 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -101,4 +101,5 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint) select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4); select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4); select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); -select truncate(cast(1234567891.1234567891 as long), 9.03) +select truncate(cast(1234567891.1234567891 as long), 9.03); +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 3464d7e0f84a..d27590193377 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 63 +-- Number of queries: 64 -- !query 0 @@ -516,3 +516,11 @@ select truncate(cast(1234567891.1234567891 as long), 9.03) struct -- !query 62 output 1.234567891E9 + + +-- !query 63 +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)) +-- !query 63 schema +struct +-- !query 63 output +1.234567891E9 1.23456794E9 1234567891 From 479b31fa046e8402f4f93cdbad5fe93ef1ea570f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 22 Sep 2018 01:08:02 +0800 Subject: [PATCH 4/6] Implements by BigDecimal.RoundingMode.DOWN --- .../expressions/mathExpressions.scala | 63 ++--------------- .../spark/sql/catalyst/util/MathUtils.scala | 70 ------------------- .../org/apache/spark/sql/types/Decimal.scala | 1 + .../expressions/MathExpressionsSuite.scala | 4 +- .../resources/sql-tests/inputs/operators.sql | 3 +- .../sql-tests/results/operators.sql.out | 50 ++++++++++++- 6 files changed, 59 insertions(+), 132 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 32fa036747e9..942a4d0f99a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter} +import org.apache.spark.sql.catalyst.util.NumberConverter import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -1264,61 +1264,8 @@ case class BRound(child: Expression, scale: Expression) 1234567891 """) // scalastyle:on line.size.limit -case class Truncate(number: Expression, scale: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - def this(number: Expression) = this(number, Literal(0)) - - override def left: Expression = number - override def right: Expression = scale - - override def inputTypes: Seq[AbstractDataType] = - Seq(TypeCollection(DoubleType, FloatType, DecimalType), IntegerType) - - override def checkInputDataTypes(): TypeCheckResult = { - super.checkInputDataTypes() match { - case TypeCheckSuccess => - if (scale.foldable) { - TypeCheckSuccess - } else { - TypeCheckFailure("Only foldable Expression is allowed for scale arguments") - } - case f => f - } - } - - override def dataType: DataType = left.dataType - override def nullable: Boolean = true - override def prettyName: String = "truncate" - - private lazy val scaleV: Any = scale.eval(EmptyRow) - private lazy val _scale: Int = scaleV.asInstanceOf[Int] - - protected override def nullSafeEval(input1: Any, input2: Any): Any = { - number.dataType match { - case DoubleType => MathUtils.trunc(input1.asInstanceOf[Double], _scale) - case FloatType => MathUtils.trunc(input1.asInstanceOf[Float], _scale) - case DecimalType.Fixed(_, _) => MathUtils.trunc(input1.asInstanceOf[Decimal], _scale) - } - } - - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mu = MathUtils.getClass.getName.stripSuffix("$") - - val javaType = CodeGenerator.javaType(dataType) - if (scaleV == null) { // if scale is null, no need to eval its child at all - ev.copy(code = code""" - boolean ${ev.isNull} = true; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};""") - } else { - val d = number.genCode(ctx) - ev.copy(code = code""" - ${d.code} - boolean ${ev.isNull} = ${d.isNull}; - $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${ev.value} = $mu.trunc(${d.value}, ${_scale}); - }""") - } - } +case class Truncate(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.DOWN, "ROUND_DOWN") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala deleted file mode 100644 index 9f00af128666..000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.util - -import java.math.{BigDecimal => JBigDecimal} - -import org.apache.spark.sql.types.Decimal - -object MathUtils { - - /** - * Returns double type input truncated to scale decimal places. - */ - def trunc(input: Double, scale: Int): Double = { - trunc(JBigDecimal.valueOf(input), scale).toDouble - } - - /** - * Returns float type input truncated to scale decimal places. - */ - def trunc(input: Float, scale: Int): Float = { - trunc(JBigDecimal.valueOf(input), scale).toFloat - } - - /** - * Returns decimal type input truncated to scale decimal places. - */ - def trunc(input: Decimal, scale: Int): Decimal = { - trunc(input.toJavaBigDecimal, scale) - } - - /** - * Returns BigDecimal type input truncated to scale decimal places. - */ - def trunc(input: JBigDecimal, scale: Int): Decimal = { - // Copy from (https://github.com/apache/hive/blob/release-2.3.0-rc0 - // /ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDFTrunc.java#L471-L487) - val pow = if (scale >= 0) { - JBigDecimal.valueOf(Math.pow(10, scale)) - } else { - JBigDecimal.valueOf(Math.pow(10, Math.abs(scale))) - } - - val truncatedValue = if (scale > 0) { - val longValue = input.multiply(pow).longValue() - JBigDecimal.valueOf(longValue).divide(pow) - } else if (scale == 0) { - JBigDecimal.valueOf(input.longValue()) - } else { - val longValue = input.divide(pow).longValue() - JBigDecimal.valueOf(longValue).multiply(pow) - } - - Decimal(truncatedValue) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index 9eed2eb20204..b0ffb816817f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -413,6 +413,7 @@ object Decimal { val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR + val ROUND_DOWN = BigDecimal.RoundingMode.DOWN /** Maximum number of decimal digits an Int can represent */ val MAX_INT_DIGITS = 9 diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 320a2e821d8a..459075cf8193 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -685,7 +685,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testFloat(0.123F, 0, 0F) testFloat(null, null, null) testFloat(null, 0, null) - testFloat(1D, null, null) + testFloat(1F, null, null) testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234)) testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000)) @@ -694,6 +694,6 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testDecimal(Decimal(0.123), 0, Decimal(0)) testDecimal(null, null, null) testDecimal(null, 0, null) - testDecimal(1D, null, null) + testDecimal(Decimal(1), null, null) } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index d14a7fdc9d2c..3c8f30eaa8b9 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -94,7 +94,8 @@ select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(n select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)); -- truncate -select truncate(1234567891.1234567891, -4), truncate(1234567891.1234567891, 0), truncate(1234567891.1234567891, 4); +select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4); +select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4); select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4); select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); select truncate(cast(1234567891.1234567891 as long), 9.03); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index fd1d0db9e3f7..9fff062490e1 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 55 +-- Number of queries: 61 -- !query 0 @@ -452,3 +452,51 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint) struct -- !query 54 output NULL NULL + + +-- !query 55 +select truncate(cast(1234567891.1234567891 as double), -4), truncate(cast(1234567891.1234567891 as double), 0), truncate(cast(1234567891.1234567891 as double), 4) +-- !query 55 schema +struct +-- !query 55 output +1.23456E9 1.234567891E9 1.2345678911234E9 + + +-- !query 56 +select truncate(cast(1234567891.1234567891 as float), -4), truncate(cast(1234567891.1234567891 as float), 0), truncate(cast(1234567891.1234567891 as float), 4) +-- !query 56 schema +struct +-- !query 56 output +1.23456E9 1.23456794E9 1.23456794E9 + + +-- !query 57 +select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(1234567891.1234567891 as decimal), 0), truncate(cast(1234567891.1234567891 as decimal), 4) +-- !query 57 schema +struct +-- !query 57 output +1234560000 1234567891 1234567891 + + +-- !query 58 +select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4) +-- !query 58 schema +struct +-- !query 58 output +1234560000 1234567891 1234567891 + + +-- !query 59 +select truncate(cast(1234567891.1234567891 as long), 9.03) +-- !query 59 schema +struct +-- !query 59 output +1234567891 + + +-- !query 60 +select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)) +-- !query 60 schema +struct +-- !query 60 output +1.234567891E9 1.23456794E9 1234567891 From b7e3460c2588be9f1f91259dc402de01ac87327c Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 22 Sep 2018 11:40:11 +0800 Subject: [PATCH 5/6] Add ROUND_DOWN to DecimalSuite --- .../scala/org/apache/spark/sql/types/Decimal.scala | 1 + .../catalyst/expressions/MathExpressionsSuite.scala | 3 +++ .../org/apache/spark/sql/types/DecimalSuite.scala | 2 +- .../src/test/resources/sql-tests/inputs/operators.sql | 1 + .../test/resources/sql-tests/results/operators.sql.out | 10 +++++++++- 5 files changed, 15 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index b0ffb816817f..eea471ca8cd3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -287,6 +287,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) { longVal += (if (droppedDigits < 0) -1L else 1L) } + case ROUND_DOWN => case _ => sys.error(s"Not supported rounding mode: $roundMode") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala index 459075cf8193..b374424b711e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala @@ -677,6 +677,7 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testDouble(null, null, null) testDouble(null, 0, null) testDouble(1D, null, null) + testDouble(-1234567891.1234567891D, 4, -1234567891.1234D) testFloat(1234567891.1234567891F, 4, 1234567891.1234F) testFloat(1234567891.1234567891F, -4, 1234560000F) @@ -686,8 +687,10 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { testFloat(null, null, null) testFloat(null, 0, null) testFloat(1F, null, null) + testFloat(-1234567891.1234567891F, 4, -1234567891.1234F) testDecimal(Decimal(1234567891.1234567891), 4, Decimal(1234567891.1234)) + testDecimal(Decimal(-1234567891.1234567891), 4, Decimal(-1234567891.1234)) testDecimal(Decimal(1234567891.1234567891), -4, Decimal(1234560000)) testDecimal(Decimal(1234567891.1234567891), 0, Decimal(1234567891)) testDecimal(Decimal(0.123), -1, Decimal(0)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 10de90c6a44c..962b7049fe62 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -204,7 +204,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } test("changePrecision/toPrecision on compact decimal should respect rounding mode") { - Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode => + Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN, ROUND_DOWN).foreach { mode => Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n => Seq("", "-").foreach { sign => val bd = BigDecimal(sign + n) diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql index 3c8f30eaa8b9..31d2a77240a2 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql @@ -100,3 +100,4 @@ select truncate(cast(1234567891.1234567891 as decimal), -4), truncate(cast(12345 select truncate(cast(1234567891.1234567891 as long), -4), truncate(cast(1234567891.1234567891 as long), 0), truncate(cast(1234567891.1234567891 as long), 4); select truncate(cast(1234567891.1234567891 as long), 9.03); select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891.1234567891 as float)), truncate(cast(1234567891.1234567891 as decimal)); +select truncate(cast(-1234567891.1234567891 as double), -4), truncate(cast(-1234567891.1234567891 as double), 4); diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 9fff062490e1..46ade81086e6 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 61 +-- Number of queries: 62 -- !query 0 @@ -500,3 +500,11 @@ select truncate(cast(1234567891.1234567891 as double)), truncate(cast(1234567891 struct -- !query 60 output 1.234567891E9 1.23456794E9 1234567891 + + +-- !query 61 +select truncate(cast(-1234567891.1234567891 as double), -4), truncate(cast(-1234567891.1234567891 as double), 4) +-- !query 61 schema +struct +-- !query 61 output +-1.23456E9 -1.2345678911234E9 From ae7eb73a94815f396f576747712be72071257b28 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 26 Sep 2018 09:58:06 +0800 Subject: [PATCH 6/6] @since 2.4.0 -> @since 2.5.0 --- sql/core/src/main/scala/org/apache/spark/sql/functions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6e2df69d5d87..aed0faefd450 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2218,7 +2218,7 @@ object functions { * Returns the value of the column `e` truncated to 0 places. * * @group math_funcs - * @since 2.4.0 + * @since 2.5.0 */ def truncate(e: Column): Column = truncate(e, 0) @@ -2228,7 +2228,7 @@ object functions { * Scale can be negative to truncate (make zero) scale digits left of the decimal point. * * @group math_funcs - * @since 2.4.0 + * @since 2.5.0 */ def truncate(e: Column, scale: Int): Column = withExpr { Truncate(e.expr, Literal(scale))