[SPARK-17495] [SQL] Support Decimal type in Hive-hash #17056
Changes from 4 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala

@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions

+import java.math.{BigDecimal, RoundingMode}
 import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32

@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
  * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
  *
- * TODO: Support Decimal and date related types
+ * TODO: Support date related types
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def genHashBytes(b: String, result: String): String =
     s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"

+  override protected def genHashDecimal(
+      ctx: CodegenContext,
+      d: DecimalType,
+      input: String,
+      result: String): String = {
+    s"""
+      $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
+        $input.toJavaBigDecimal(), true).hashCode();"""
+  }
+
   override protected def genHashCalendarInterval(input: String, result: String): String = {
     s"""
       $result = (31 * $hasherClassName.hashInt($input.months)) +
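For reference, HiveHashFunction.getClass.getName.stripSuffix("$") resolves to org.apache.spark.sql.catalyst.expressions.HiveHashFunction, so for a decimal child the template above expands into Java along these lines (value1/value2 are hypothetical placeholders; the real variable names are chosen by codegen):

    value2 = org.apache.spark.sql.catalyst.expressions.HiveHashFunction.normalizeDecimal(
      value1.toJavaBigDecimal(), true).hashCode();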
@@ -732,6 +743,48 @@ object HiveHashFunction extends InterpretedHashFunction {
     HiveHasher.hashUnsafeBytes(base, offset, len)
   }

+  private val HiveDecimalMaxPrecision = 38
+  private val HiveDecimalMaxScale = 38
+
+  // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
+  def normalizeDecimal(input: BigDecimal, allowRounding: Boolean): BigDecimal = {
+    if (input == null) return null
+
+    def trimDecimal(input: BigDecimal) = {
+      var result = input
+      if (result.compareTo(BigDecimal.ZERO) == 0) {
+        // Special case for 0, because java doesn't strip zeros correctly on that number.
+        result = BigDecimal.ZERO
+      } else {
+        result = result.stripTrailingZeros
+        if (result.scale < 0) {
+          // no negative scale decimals
+          result = result.setScale(0)
+        }
+      }
+      result
+    }
+
+    var result = trimDecimal(input)
+    val intDigits = result.precision - result.scale
+    if (intDigits > HiveDecimalMaxPrecision) {
+      return null
+    }
+
+    val maxScale =
+      Math.min(HiveDecimalMaxScale, Math.min(HiveDecimalMaxPrecision - intDigits, result.scale))
+    if (result.scale > maxScale) {
+      if (allowRounding) {
+        result = result.setScale(maxScale, RoundingMode.HALF_UP)
+        // Trimming is again necessary, because rounding may introduce new trailing 0's.
+        result = trimDecimal(result)
+      } else {
+        result = null
+      }
+    }
+    result
+  }
+
   override def hash(value: Any, dataType: DataType, seed: Long): Long = {
     value match {
       case null => 0
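To illustrate what this normalization does, here is a hand-derived sketch (values worked out from the rules above, e.g. in a Spark REPL with this patch applied; not Hive output):

    import java.math.BigDecimal
    import org.apache.spark.sql.catalyst.expressions.HiveHashFunction.normalizeDecimal

    normalizeDecimal(new BigDecimal("1.500"), allowRounding = true)  // 1.5: trailing zeros stripped
    normalizeDecimal(new BigDecimal("1500"), allowRounding = true)   // 1500: stripTrailingZeros gives 1.5E+3 (scale -2), so scale is reset to 0
    normalizeDecimal(new BigDecimal("0.000"), allowRounding = true)  // 0: zero is special-cased in trimDecimal
    normalizeDecimal(new BigDecimal("1E+39"), allowRounding = true)  // null: more than 38 integer digits
    // 6 integer digits + 34 fractional digits = 40 > 38, so the scale is capped at
    // 38 - 6 = 32 and the value is rounded HALF_UP:
    normalizeDecimal(new BigDecimal("123456.1234567890123456789012345678901234"), allowRounding = true)
    // => 123456.12345678901234567890123456789012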
@@ -785,7 +838,10 @@ object HiveHashFunction extends InterpretedHashFunction {
         }
         result

-      case _ => super.hash(value, dataType, 0)
+      case d: Decimal =>
+        normalizeDecimal(d.toJavaBigDecimal, allowRounding = true).hashCode()
+
+      case _ => super.hash(value, dataType, seed)
     }
   }
 }
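End to end, the interpreted path added above can be exercised directly; a minimal sketch, mirroring the new test below (input and expected value taken from it):

    import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
    import org.apache.spark.sql.types.{Decimal, DecimalType}

    // The seed argument is ignored for decimals: the hash is just the
    // normalized java.math.BigDecimal's hashCode().
    val h = HiveHashFunction.hash(Decimal(new java.math.BigDecimal("18")), DecimalType(38, 0), seed = 0)
    // h == 558, the Hive 1.2.1 value asserted in the test below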
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala

@@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }

-
   def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
     // Note : All expected hashes need to be computed using Hive 1.2.1
     val actual = HiveHashFunction.hash(input, dataType, seed = 0)
@@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       new StructType().add("array", arrayOfString).add("map", mapOfString))
     .add("structOfUDT", structOfUDT))

+  test("hive-hash for decimal") {
+    def checkHiveHashForDecimal(
+        input: String,
+        precision: Int,
+        scale: Int,
+        expected: Long): Unit = {
+      val decimalType = DataTypes.createDecimalType(precision, scale)
+      val decimal = {
+        val value = Decimal.apply(new java.math.BigDecimal(input))
+        if (value.changePrecision(precision, scale)) value else null
+      }
+
+      checkHiveHash(decimal, decimalType, expected)
+    }
+
+    checkHiveHashForDecimal("18", 38, 0, 558)
+
+    checkHiveHashForDecimal("-18", 38, 0, -558)
+    checkHiveHashForDecimal("-18", 38, 12, -558)
+    checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
+    checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
+    checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
+    checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
+    checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
+    checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
+    checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
+    checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+    checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
+    checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
+    checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666)
+  }
+
   test("SPARK-18207: Compute hash for a lot of expressions") {
     val N = 1000
     val wideRow = new GenericInternalRow(
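A note on the expected values above: they were produced by Hive 1.2.1 (per the comment in checkHiveHash), but for in-range inputs they coincide with java.math.BigDecimal.hashCode() after normalization, which in OpenJDK is effectively 31 * unscaledValue().hashCode() + scale. That makes the first vector easy to verify by hand:

    // Hand check of checkHiveHashForDecimal("18", 38, 0, 558), assuming OpenJDK's
    // BigInteger.hashCode, where BigInteger.valueOf(18).hashCode() == 18:
    //   31 * 18 + 0 == 558, and negating the input flips the sign to -558.
    new java.math.BigDecimal("18").hashCode()  // 558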
Review comment: where do we use ctx and d?
Reply: Neither is used here, but they are part of the method signature because the default implementation in the abstract class needs them:
spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala, line 321 in 3e40f6c
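For context, the shape of that member can be read off the HiveHash override earlier in this diff (the default body at the referenced line is not reproduced here):

    // Signature inferred from the override above; unlike HiveHash, the default
    // implementation in the abstract class does make use of ctx and d.
    protected def genHashDecimal(
        ctx: CodegenContext,
        d: DecimalType,
        input: String,
        result: String): String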