-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-21117][SQL] Built-in SQL Function Support - WIDTH_BUCKET #18323
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
20fe567
cb682a9
3a0eaf7
3a5d46e
7407541
fda866f
5c3ecee
9e2cabb
507fcfb
099db6d
01af62b
0355284
0940a49
e0478f5
2e2b2ca
ea5e4ae
31ee943
44ef7df
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult | |
| import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} | ||
| import org.apache.spark.sql.catalyst.expressions.codegen._ | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.Block._ | ||
| import org.apache.spark.sql.catalyst.util.NumberConverter | ||
| import org.apache.spark.sql.catalyst.util.{MathUtils, NumberConverter} | ||
| import org.apache.spark.sql.types._ | ||
| import org.apache.spark.unsafe.types.UTF8String | ||
|
|
||
|
|
@@ -1319,3 +1319,123 @@ case class BRound(child: Expression, scale: Expression) | |
| with Serializable with ImplicitCastInputTypes { | ||
| def this(child: Expression) = this(child, Literal(0)) | ||
| } | ||
|
|
||
| /** | ||
| * Returns the bucket number into which | ||
| * the value of this expression would fall after being evaluated. | ||
| * | ||
| * @param expr is the expression for which the histogram is being created | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this comment correct? How about "the expression for which the bucket number in the histogram would return"?
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| * @param minValue is an expression that resolves | ||
| * to the minimum end point of the acceptable range for expr | ||
| * @param maxValue is an expression that resolves | ||
| * to the maximum end point of the acceptable range for expr | ||
| * @param numBucket is an expression that resolves to | ||
| * a constant indicating the number of buckets | ||
| */ | ||
| // scalastyle:off line.size.limit | ||
| @ExpressionDescription( | ||
| usage = "_FUNC_(expr, min_value, max_value, num_bucket) - Returns the `bucket` to which operand would be assigned in an equidepth histogram with `num_bucket` buckets, in the range `min_value` to `max_value`.", | ||
| extended = """ | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| Examples: | ||
| > SELECT _FUNC_(5.35, 0.024, 10.06, 5); | ||
| 3 | ||
| """) | ||
| // scalastyle:on line.size.limit | ||
| case class WidthBucket( | ||
| expr: Expression, | ||
| minValue: Expression, | ||
| maxValue: Expression, | ||
| numBucket: Expression) extends Expression with ImplicitCastInputTypes { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cloud we do |
||
|
|
||
| override def children: Seq[Expression] = Seq(expr, minValue, maxValue, numBucket) | ||
| override def foldable: Boolean = children.drop(1).forall(_.foldable) | ||
| override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType, DoubleType, DoubleType, LongType) | ||
| override def dataType: DataType = LongType | ||
| override def nullable: Boolean = true | ||
|
|
||
| private lazy val _minValue: Any = minValue.eval() | ||
| private lazy val minValueV = _minValue.asInstanceOf[Double] | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if |
||
|
|
||
| private lazy val _maxValue: Any = maxValue.eval() | ||
| private lazy val maxValueV = _maxValue.asInstanceOf[Double] | ||
|
|
||
| private lazy val _numBucket: Any = numBucket.eval() | ||
| private lazy val numBucketV = _numBucket.asInstanceOf[Long] | ||
|
|
||
| private val errMsg = "The argument [%d] of WIDTH_BUCKET function is NULL or invalid." | ||
|
|
||
| override def eval(input: InternalRow): Any = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why you create |
||
|
|
||
| if (foldable) { | ||
| if (_minValue == null) { | ||
| throw new RuntimeException(errMsg.format(2)) | ||
| } else if (_maxValue == null) { | ||
| throw new RuntimeException(errMsg.format(3)) | ||
| } else if (_numBucket == null || numBucketV <= 0) { | ||
| throw new RuntimeException(errMsg.format(4)) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is the above handling for null valid? Except for null-intolerant expressions, usually any input of an expression is null, it returns null.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
| } else { | ||
| val exprV = expr.eval(input) | ||
| if (exprV == null) { | ||
| null | ||
| } else { | ||
| MathUtils.widthBucket(exprV.asInstanceOf[Double], minValueV, maxValueV, numBucketV) | ||
| } | ||
| } | ||
| } else { | ||
| val evals = children.map(_.eval(input)) | ||
| val invalid = evals.zipWithIndex.filter { case (e, i) => | ||
| (i > 0 && e == null) || (i == 3 && e.asInstanceOf[Long] <= 0) | ||
| } | ||
| if (invalid.nonEmpty) { | ||
| invalid.foreach(l => throw new RuntimeException(errMsg.format(l._2 + 1))) | ||
| } else if (evals(0) == null) { | ||
| null | ||
| } else { | ||
| MathUtils.widthBucket( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| evals(0).asInstanceOf[Double], | ||
| evals(1).asInstanceOf[Double], | ||
| evals(2).asInstanceOf[Double], | ||
| evals(3).asInstanceOf[Long]) | ||
| } | ||
| } | ||
| } | ||
|
|
||
| override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ditto. As you override |
||
| val mathUtils = MathUtils.getClass.getName.stripSuffix("$") | ||
| if (foldable) { | ||
| val exprV = expr.genCode(ctx) | ||
| ev.copy(code = code""" | ||
| if (${_minValue == null}) { | ||
| throw new RuntimeException(String.format("$errMsg", 2)); | ||
| } else if (${_maxValue == null}) { | ||
| throw new RuntimeException(String.format("$errMsg", 3)); | ||
| } else if (${_numBucket == null || numBucketV <= 0}) { | ||
| throw new RuntimeException(String.format("$errMsg", 4)); | ||
| } | ||
| ${exprV.code} | ||
| boolean ${ev.isNull} = ${exprV.isNull}; | ||
| long ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
| if (!${ev.isNull}) { | ||
| ${ev.value} = $mathUtils.widthBucket(${exprV.value}, $minValueV, $maxValueV, $numBucketV); | ||
| }""") | ||
| } else { | ||
| val evals = children.map(_.genCode(ctx)) | ||
| val invalid = evals.zipWithIndex.map { case (e, i) => | ||
| s""" | ||
| if (($i > 0 && ${e.isNull}) || ($i == 3 && ${e.value} < 0)) { | ||
| throw new RuntimeException(String.format("$errMsg", $i + 1)); | ||
| } | ||
| """} | ||
|
|
||
| ev.copy(code = code""" | ||
| ${invalid.map(_.stripMargin).mkString("\n")} | ||
| boolean ${ev.isNull} = ${evals(0).isNull}; | ||
| long ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; | ||
| if (!${evals(0).isNull}) { | ||
| ${ev.value} = $mathUtils.widthBucket( | ||
| ${evals(0).value}, ${evals(1).value}, ${evals(2).value}, ${evals(3).value}); | ||
| } | ||
| """) | ||
| } | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.util | ||
|
|
||
| import org.apache.spark.sql.AnalysisException | ||
|
|
||
| object MathUtils { | ||
|
|
||
| /** | ||
| * Returns the bucket number into which | ||
| * the value of this expression would fall after being evaluated. | ||
| * | ||
| * @param expr is the expression for which the histogram is being created | ||
| * @param minValue is an expression that resolves | ||
| * to the minimum end point of the acceptable range for expr | ||
| * @param maxValue is an expression that resolves | ||
| * to the maximum end point of the acceptable range for expr | ||
| * @param numBucket is an expression that resolves to | ||
| * a constant indicating the number of buckets | ||
| * @return Returns an long between 0 and numBucket+1 by mapping the expr into buckets defined by | ||
| * the range [minValue, maxValue]. | ||
| */ | ||
| def widthBucket(expr: Double, minValue: Double, maxValue: Double, numBucket: Long): Long = { | ||
| val lower: Double = Math.min(minValue, maxValue) | ||
| val upper: Double = Math.max(minValue, maxValue) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does other databases allow max value to appear first? i.e.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, Oracle support it. |
||
|
|
||
| val result: Long = if (expr < lower) { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. // an underflow bucket numbered 0
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added 2 test case: |
||
| 0 | ||
| } else if (expr >= upper) { | ||
| numBucket + 1L | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. // an overflow bucket numbered num_buckets+1
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added 2 test case: |
||
| } else { | ||
| (numBucket.toDouble * (expr - lower) / (upper - lower) + 1).toLong | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what if upper == lower? |
||
| } | ||
|
|
||
| if (minValue > maxValue) (numBucket - result) + 1 else result | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,46 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one or more | ||
| * contributor license agreements. See the NOTICE file distributed with | ||
| * this work for additional information regarding copyright ownership. | ||
| * The ASF licenses this file to You under the Apache License, Version 2.0 | ||
| * (the "License"); you may not use this file except in compliance with | ||
| * the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, software | ||
| * distributed under the License is distributed on an "AS IS" BASIS, | ||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| * See the License for the specific language governing permissions and | ||
| * limitations under the License. | ||
| */ | ||
|
|
||
| package org.apache.spark.sql.catalyst.util | ||
|
|
||
| import org.apache.spark.SparkFunSuite | ||
| import org.apache.spark.sql.AnalysisException | ||
| import org.apache.spark.sql.catalyst.util.MathUtils._ | ||
|
|
||
| class MathUtilsSuite extends SparkFunSuite { | ||
|
|
||
| test("widthBucket") { | ||
| assert(widthBucket(5.35, 0.024, 10.06, 5) === 3) | ||
| assert(widthBucket(0, 1, 1, 1) === 0) | ||
| assert(widthBucket(20, 1, 1, 1) === 2) | ||
|
|
||
| // Test https://docs.oracle.com/cd/B28359_01/olap.111/b28126/dml_functions_2137.htm#OLADM717 | ||
| // WIDTH_BUCKET(credit_limit, 100, 5000, 10) | ||
| assert(widthBucket(500, 100, 5000, 10) === 1) | ||
| assert(widthBucket(2300, 100, 5000, 10) === 5) | ||
| assert(widthBucket(3500, 100, 5000, 10) === 7) | ||
| assert(widthBucket(1200, 100, 5000, 10) === 3) | ||
| assert(widthBucket(1400, 100, 5000, 10) === 3) | ||
| assert(widthBucket(700, 100, 5000, 10) === 2) | ||
| assert(widthBucket(5000, 100, 5000, 10) === 11) | ||
| assert(widthBucket(1800, 100, 5000, 10) === 4) | ||
| assert(widthBucket(400, 100, 5000, 10) === 1) | ||
|
|
||
| // minValue == maxValue | ||
| assert(widthBucket(10, 4, 4, 15) === 16) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this behavior consistent with other databases? If so, then I'm fine with this and please ignore my previous comment.
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! |
||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -81,3 +81,15 @@ select positive('-1.11'), positive(-1.11), negative('-1.11'), negative(-1.11); | |
| -- pmod | ||
| select pmod(-7, 2), pmod(0, 2), pmod(7, 0), pmod(7, null), pmod(null, 2), pmod(null, null); | ||
| select pmod(cast(3.13 as decimal), cast(0 as decimal)), pmod(cast(2 as smallint), cast(0 as smallint)); | ||
|
|
||
| -- width_bucket | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need end-to-end tests here? I think we already cover these cases in other test suites.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Instead, we need to add a test for sql queries using this function. |
||
| select width_bucket(5.35, 0.024, 10.06, 5); | ||
| select width_bucket(5.35, 0.024, 10.06, 3 + 2); | ||
| select width_bucket('5.35', '0.024', '10.06', '5'); | ||
| select width_bucket(5.35, 0.024, 10.06, 2.5); | ||
| select width_bucket(5.35, 0.024, 10.06, 0.5); | ||
| select width_bucket(null, 0.024, 10.06, 5); | ||
| select width_bucket(5.35, null, 10.06, 5); | ||
| select width_bucket(5.35, 0.024, null, -5); | ||
| select width_bucket(5.35, 0.024, 10.06, null); | ||
| select width_bucket(5.35, 0.024, 10.06, -5); | ||


Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: how about the format and the rephrasing below?