diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c7d0eba0964cc..4c89647077375 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -274,6 +274,7 @@ object FunctionRegistry {
     expression[Tan]("tan"),
     expression[Cot]("cot"),
     expression[Tanh]("tanh"),
+    expression[WidthBucket]("width_bucket"),
 
     expression[Add]("+"),
     expression[Subtract]("-"),
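The registration above is all that is needed for SQL name resolution: `expression[WidthBucket]("width_bucket")` binds the case class constructor to the function name. A hypothetical spark-shell check against a build that includes this patch (not part of the diff; `spark` is the usual session):

    // Hypothetical spark-shell session, assuming this patch is applied.
    spark.sql("DESCRIBE FUNCTION width_bucket").show(truncate = false)
    spark.sql("SELECT width_bucket(5.35, 0.024, 10.06, 5)").show()  // single row containing 3
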
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
index 66e6334e3a450..a669d6050e118 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.util.NumberConverter
+import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
@@ -1319,3 +1319,132 @@ case class BRound(child: Expression, scale: Expression)
     with Serializable with ImplicitCastInputTypes {
   def this(child: Expression) = this(child, Literal(0))
 }
+
+/**
+ * Returns the bucket number into which
+ * the value of this expression would fall after being evaluated.
+ *
+ * @param expr is the expression for which the histogram is being created
+ * @param minValue is an expression that resolves
+ *                 to the minimum end point of the acceptable range for expr
+ * @param maxValue is an expression that resolves
+ *                 to the maximum end point of the acceptable range for expr
+ * @param numBucket is an expression that resolves to
+ *                  a constant indicating the number of buckets
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(expr, min_value, max_value, num_bucket) - Returns the bucket number to which `expr` would be assigned in an equiwidth histogram with `num_bucket` buckets, in the range `min_value` to `max_value`.",
+  extended = """
+    Examples:
+      > SELECT _FUNC_(5.35, 0.024, 10.06, 5);
+       3
+  """)
+// scalastyle:on line.size.limit
+case class WidthBucket(
+    expr: Expression,
+    minValue: Expression,
+    maxValue: Expression,
+    numBucket: Expression) extends Expression with ImplicitCastInputTypes {
+
+  override def children: Seq[Expression] = Seq(expr, minValue, maxValue, numBucket)
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType)
+  override def dataType: DataType = LongType
+  override def nullable: Boolean = true
+
+  // `foldable` is deliberately not overridden: marking the whole expression foldable while
+  // `expr` still references input columns would let ConstantFolding evaluate it on an empty row.
+  // The fast path below only needs the bucket bounds and the bucket count to be constant.
+  private lazy val boundsFoldable: Boolean = children.drop(1).forall(_.foldable)
+
+  private lazy val _minValue: Any = minValue.eval()
+  private lazy val minValueV = _minValue.asInstanceOf[Double]
+
+  private lazy val _maxValue: Any = maxValue.eval()
+  private lazy val maxValueV = _maxValue.asInstanceOf[Double]
+
+  private lazy val _numBucket: Any = numBucket.eval()
+  private lazy val numBucketV = _numBucket.asInstanceOf[Long]
+
+  private val errMsg = "The argument [%d] of WIDTH_BUCKET function is NULL or invalid."
+
+  override def eval(input: InternalRow): Any = {
+    if (boundsFoldable) {
+      if (_minValue == null) {
+        throw new RuntimeException(errMsg.format(2))
+      } else if (_maxValue == null) {
+        throw new RuntimeException(errMsg.format(3))
+      } else if (_numBucket == null || numBucketV <= 0) {
+        throw new RuntimeException(errMsg.format(4))
+      } else {
+        val exprV = expr.eval(input)
+        if (exprV == null) {
+          null
+        } else {
+          MathUtils.widthBucket(exprV.asInstanceOf[Double], minValueV, maxValueV, numBucketV)
+        }
+      }
+    } else {
+      val evals = children.map(_.eval(input))
+      val invalid = evals.zipWithIndex.find { case (e, i) =>
+        (i > 0 && e == null) || (i == 3 && e.asInstanceOf[Long] <= 0)
+      }
+      if (invalid.nonEmpty) {
+        throw new RuntimeException(errMsg.format(invalid.get._2 + 1))
+      } else if (evals(0) == null) {
+        null
+      } else {
+        MathUtils.widthBucket(
+          evals(0).asInstanceOf[Double],
+          evals(1).asInstanceOf[Double],
+          evals(2).asInstanceOf[Double],
+          evals(3).asInstanceOf[Long])
+      }
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val mathUtils = MathUtils.getClass.getName.stripSuffix("$")
+    if (boundsFoldable) {
+      val exprV = expr.genCode(ctx)
+      ev.copy(code = code"""
+        if (${_minValue == null}) {
+          throw new RuntimeException(String.format("$errMsg", 2));
+        } else if (${_maxValue == null}) {
+          throw new RuntimeException(String.format("$errMsg", 3));
+        } else if (${_numBucket == null || numBucketV <= 0}) {
+          throw new RuntimeException(String.format("$errMsg", 4));
+        }
+        ${exprV.code}
+        boolean ${ev.isNull} = ${exprV.isNull};
+        long ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        if (!${ev.isNull}) {
+          ${ev.value} = $mathUtils.widthBucket(${exprV.value}, $minValueV, $maxValueV, $numBucketV);
+        }""")
+    } else {
+      val evals = children.map(_.genCode(ctx))
+      // Validate after evaluating the children; `<= 0` matches the interpreted path above.
+      val invalid = evals.zipWithIndex.map { case (e, i) =>
+        s"""
+          if (($i > 0 && ${e.isNull}) || ($i == 3 && ${e.value} <= 0)) {
+            throw new RuntimeException(String.format("$errMsg", ${i + 1}));
+          }
+        """
+      }
+
+      ev.copy(code = code"""
+        ${evals(0).code}
+        ${evals(1).code}
+        ${evals(2).code}
+        ${evals(3).code}
+        ${invalid.mkString("\n")}
+        boolean ${ev.isNull} = ${evals(0).isNull};
+        long ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+        if (!${evals(0).isNull}) {
+          ${ev.value} = $mathUtils.widthBucket(
+            ${evals(0).value}, ${evals(1).value}, ${evals(2).value}, ${evals(3).value});
+        }
+      """)
+    }
+  }
+}
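The evaluation contract implemented above: a NULL `expr` yields NULL, while a NULL `min_value`/`max_value`/`num_bucket`, or a non-positive `num_bucket`, raises a RuntimeException naming the 1-based argument position. A hypothetical REPL-style check of the interpreted path (mirrors the unit tests added below):

    import org.apache.spark.sql.catalyst.expressions._
    import org.apache.spark.sql.types.{DoubleType, LongType}

    val wb = WidthBucket(
      Literal.create(5.35, DoubleType),
      Literal.create(0.024, DoubleType),
      Literal.create(10.06, DoubleType),
      Literal.create(5L, LongType))
    assert(wb.eval(EmptyRow) == 3L)

    // A NULL input value propagates as NULL rather than raising an error...
    assert(wb.copy(expr = Literal.create(null, DoubleType)).eval(EmptyRow) == null)

    // ...but a NULL bound raises:
    //   "The argument [2] of WIDTH_BUCKET function is NULL or invalid."
    // wb.copy(minValue = Literal.create(null, DoubleType)).eval(EmptyRow)
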
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
new file mode 100644
index 0000000000000..472c9c6f618b1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/MathUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+object MathUtils {
+
+  /**
+   * Returns the bucket number into which the value `expr` falls.
+   *
+   * @param expr the value for which the bucket number is computed
+   * @param minValue the minimum end point of the acceptable range for `expr`
+   * @param maxValue the maximum end point of the acceptable range for `expr`
+   * @param numBucket the number of buckets
+   * @return a Long between 0 and numBucket + 1 obtained by mapping `expr` into buckets
+   *         defined by the range [minValue, maxValue]
+   */
+  def widthBucket(expr: Double, minValue: Double, maxValue: Double, numBucket: Long): Long = {
+    val lower: Double = Math.min(minValue, maxValue)
+    val upper: Double = Math.max(minValue, maxValue)
+
+    val result: Long = if (expr < lower) {
+      0
+    } else if (expr >= upper) {
+      numBucket + 1L
+    } else {
+      (numBucket.toDouble * (expr - lower) / (upper - lower) + 1).toLong
+    }
+
+    // When minValue > maxValue the bounds are reversed and the bucket numbering is mirrored.
+    if (minValue > maxValue) (numBucket - result) + 1 else result
+  }
+}
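For a value inside the range, `widthBucket` reduces to `floor(numBucket * (expr - lower) / (upper - lower) + 1)`; values below the range land in an underflow bucket numbered 0, values at or above the upper end in an overflow bucket numbered `numBucket + 1`, and the numbering is mirrored when the bounds are reversed. A standalone sketch of the same arithmetic (hypothetical object name, runnable without Spark):

    object WidthBucketSketch {
      // Same formula as MathUtils.widthBucket above.
      def widthBucket(expr: Double, minValue: Double, maxValue: Double, numBucket: Long): Long = {
        val lower = math.min(minValue, maxValue)
        val upper = math.max(minValue, maxValue)
        val result =
          if (expr < lower) 0L
          else if (expr >= upper) numBucket + 1L
          else (numBucket.toDouble * (expr - lower) / (upper - lower) + 1).toLong
        if (minValue > maxValue) numBucket - result + 1 else result
      }

      def main(args: Array[String]): Unit = {
        // floor(5 * (5.35 - 0.024) / (10.06 - 0.024) + 1) = floor(3.65...) = 3
        println(widthBucket(5.35, 0.024, 10.06, 5))  // 3
        println(widthBucket(-1, 0, 3.2, 4))          // 0 (underflow bucket)
        println(widthBucket(5000, 100, 5000, 10))    // 11 (overflow bucket, numBucket + 1)
      }
    }
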
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
index b4096f21bea3a..47127a4cf6e35 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathExpressionsSuite.scala
@@ -677,4 +677,68 @@ class MathExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(BRound(-0.35, 1), -0.4)
     checkEvaluation(BRound(-35, -1), -40)
   }
+
+  test("width_bucket") {
+    def checkWidthBucket(
+        expr: Double,
+        minValue: Double,
+        maxValue: Double,
+        numBucket: Long,
+        expected: Long): Unit = {
+      checkEvaluation(WidthBucket(Literal.create(expr, DoubleType),
+        Literal.create(minValue, DoubleType),
+        Literal.create(maxValue, DoubleType),
+        Literal.create(numBucket, LongType)),
+        expected)
+    }
+
+    checkWidthBucket(5.35, 0.024, 10.06, 5, 3)
+
+    checkWidthBucket(3.14, 0, 4, 3, 3)
+    checkWidthBucket(2, 0, 4, 3, 2)
+
+    checkWidthBucket(3.14, 4, 0, 3, 1)
+    checkWidthBucket(2, 4, 0, 3, 2)
+
+    // an underflow bucket numbered 0
+    checkWidthBucket(-1, 0, 3.2, 4, 0)
+    checkWidthBucket(1, 2, 3, 4, 0)
+
+    // an overflow bucket numbered num_buckets + 1
+    checkWidthBucket(-1, 3.2, 0, 4, 5)
+    checkWidthBucket(3, 2, 3, 2, 3)
+
+    // invalid arguments
+    val e1 = intercept[RuntimeException] {
+      WidthBucket(Literal.create(1.0, DoubleType),
+        Literal.create(null, DoubleType),
+        Literal.create(2.0, DoubleType),
+        Literal.create(5L, LongType)).eval(EmptyRow)
+    }
+    assert(e1.getMessage.contains("The argument [2] of WIDTH_BUCKET function is NULL or invalid."))
+
+    val e2 = intercept[RuntimeException] {
+      WidthBucket(Literal.create(1.0, DoubleType),
+        Literal.create(1.0, DoubleType),
+        Literal.create(null, DoubleType),
+        Literal.create(5L, LongType)).eval(EmptyRow)
+    }
+    assert(e2.getMessage.contains("The argument [3] of WIDTH_BUCKET function is NULL or invalid."))
+
+    val e3 = intercept[RuntimeException] {
+      WidthBucket(Literal.create(1.0, DoubleType),
+        Literal.create(1.0, DoubleType),
+        Literal.create(2.0, DoubleType),
+        Literal.create(null, LongType)).eval(EmptyRow)
+    }
+    assert(e3.getMessage.contains("The argument [4] of WIDTH_BUCKET function is NULL or invalid."))
+
+    val e4 = intercept[RuntimeException] {
+      WidthBucket(Literal.create(1.0, DoubleType),
+        Literal.create(1.0, DoubleType),
+        Literal.create(2.0, DoubleType),
+        Literal.create(-1L, LongType)).eval(EmptyRow)
+    }
+    assert(e4.getMessage.contains("The argument [4] of WIDTH_BUCKET function is NULL or invalid."))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala
new file mode 100644
index 0000000000000..a2f98a11b6829
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MathUtilsSuite.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.MathUtils._
+
+class MathUtilsSuite extends SparkFunSuite {
+
+  test("widthBucket") {
+    assert(widthBucket(5.35, 0.024, 10.06, 5) === 3)
+    assert(widthBucket(0, 1, 1, 1) === 0)
+    assert(widthBucket(20, 1, 1, 1) === 2)
+
+    // Test https://docs.oracle.com/cd/B28359_01/olap.111/b28126/dml_functions_2137.htm#OLADM717
+    // WIDTH_BUCKET(credit_limit, 100, 5000, 10)
+    assert(widthBucket(500, 100, 5000, 10) === 1)
+    assert(widthBucket(2300, 100, 5000, 10) === 5)
+    assert(widthBucket(3500, 100, 5000, 10) === 7)
+    assert(widthBucket(1200, 100, 5000, 10) === 3)
+    assert(widthBucket(1400, 100, 5000, 10) === 3)
+    assert(widthBucket(700, 100, 5000, 10) === 2)
+    assert(widthBucket(5000, 100, 5000, 10) === 11)
+    assert(widthBucket(1800, 100, 5000, 10) === 4)
+    assert(widthBucket(400, 100, 5000, 10) === 1)
+
+    // minValue == maxValue
+    assert(widthBucket(10, 4, 4, 15) === 16)
+  }
+}
diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
index 1e22ae2eefeb2..37eb0cfa04339 100644
--- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
+++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md
@@ -1,8 +1,8 @@
 ## Summary
-  - Number of queries: 333
-  - Number of expressions that missing example: 34
-  - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,count_min_sketch
+  - Number of queries: 334
+  - Number of expressions that missing example: 35
+  - Expressions missing examples: and,string,tinyint,double,smallint,date,decimal,boolean,float,binary,bigint,int,timestamp,cume_dist,dense_rank,input_file_block_length,input_file_block_start,input_file_name,lag,lead,monotonically_increasing_id,ntile,struct,!,not,or,percent_rank,rank,row_number,spark_partition_id,version,window,positive,width_bucket,count_min_sketch
 
 ## Schema of Built-in Functions
 | Class name | Function name or alias | Query example | Output schema |
 | ---------- | ---------------------- | ------------- | ------------- |
@@ -287,6 +287,7 @@
 | org.apache.spark.sql.catalyst.expressions.Uuid | uuid | SELECT uuid() | struct |
 | org.apache.spark.sql.catalyst.expressions.WeekDay | weekday | SELECT weekday('2009-07-30') | struct |
 | org.apache.spark.sql.catalyst.expressions.WeekOfYear | weekofyear | SELECT weekofyear('2008-02-20') | struct |
+| org.apache.spark.sql.catalyst.expressions.WidthBucket | width_bucket | N/A | N/A |
 | org.apache.spark.sql.catalyst.expressions.XxHash64 | xxhash64 | SELECT xxhash64('Spark', array(123), 2) | struct |
 | org.apache.spark.sql.catalyst.expressions.Year | year | SELECT year('2016-07-30') | struct |
 | org.apache.spark.sql.catalyst.expressions.ZipWith | zip_with | SELECT zip_with(array(1, 2, 3), array('a', 'b', 'c'), (x, y) -> (y, x)) | struct>> |
diff --git a/sql/core/src/test/resources/sql-tests/inputs/operators.sql b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
index 20bf0eb15c5b2..fabe0a5768471 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/operators.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/operators.sql
@@ -81,3 +81,15 @@ select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11);
 -- pmod
 select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null);
 select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint));
+
+-- width_bucket
+select width_bucket(5.35, 0.024, 10.06, 5);
+select width_bucket(5.35, 0.024, 10.06, 3 + 2);
+select width_bucket('5.35', '0.024', '10.06', '5');
+select width_bucket(5.35, 0.024, 10.06, 2.5);
+select width_bucket(5.35, 0.024, 10.06, 0.5);
+select width_bucket(null, 0.024, 10.06, 5);
+select width_bucket(5.35, null, 10.06, 5);
+select width_bucket(5.35, 0.024, null, -5);
+select width_bucket(5.35, 0.024, 10.06, null);
+select width_bucket(5.35, 0.024, 10.06, -5);
diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
index cf857cf9f98ad..a6ab37b22d615 100644
--- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 57
+-- Number of queries: 67
 
 
 -- !query
@@ -456,3 +456,88 @@ select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint)
 -- !query schema
 struct
 -- !query output
 NULL	NULL
+
+
+-- !query
+select width_bucket(5.35, 0.024, 10.06, 5)
+-- !query schema
+struct
+-- !query output
+3
+
+
+-- !query
+select width_bucket(5.35, 0.024, 10.06, 3 + 2)
+-- !query schema
+struct
+-- !query output
+3
+
+
+-- !query
+select width_bucket('5.35', '0.024', '10.06', '5')
+-- !query schema
+struct
+-- !query output
+3
+
+
+-- !query
+select width_bucket(5.35, 0.024, 10.06, 2.5)
+-- !query schema
+struct
+-- !query output
+2
+
+
+-- !query
+select width_bucket(5.35, 0.024, 10.06, 0.5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+The argument [4] of WIDTH_BUCKET function is NULL or invalid.
+
+
+-- !query
+select width_bucket(null, 0.024, 10.06, 5)
+-- !query schema
+struct
+-- !query output
+NULL
+
+
+-- !query
+select width_bucket(5.35, null, 10.06, 5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+The argument [2] of WIDTH_BUCKET function is NULL or invalid.
+
+
+-- !query
+select width_bucket(5.35, 0.024, null, -5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+The argument [3] of WIDTH_BUCKET function is NULL or invalid.
+
+
+-- !query
+select width_bucket(5.35, 0.024, 10.06, null)
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+The argument [4] of WIDTH_BUCKET function is NULL or invalid.
+
+
+-- !query
+select width_bucket(5.35, 0.024, 10.06, -5)
+-- !query schema
+struct<>
+-- !query output
+java.lang.RuntimeException
+The argument [4] of WIDTH_BUCKET function is NULL or invalid.
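End to end, the behavior pinned down by the golden files above can also be exercised through the DataFrame API. A hedged usage sketch (hypothetical session and object name; the expected buckets come from the Oracle-derived assertions in MathUtilsSuite):

    import org.apache.spark.sql.SparkSession

    object WidthBucketExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[1]")
          .appName("width_bucket example")
          .getOrCreate()
        import spark.implicits._

        // Ten equal-width buckets over [100, 5000]; 5000 itself falls in overflow bucket 11.
        Seq(500.0, 2300.0, 3500.0, 5000.0).toDF("credit_limit")
          .selectExpr("credit_limit", "width_bucket(credit_limit, 100, 5000, 10) AS bucket")
          .show()
        // Expected buckets: 500.0 -> 1, 2300.0 -> 5, 3500.0 -> 7, 5000.0 -> 11

        spark.stop()
      }
    }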