diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 2d9c2e42064b..03101b4bfc5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.{BigDecimal, RoundingMode}
 import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32
 
@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
  * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
  *
- * TODO: Support Decimal and date related types
+ * TODO: Support date related types
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def genHashBytes(b: String, result: String): String =
     s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
 
+  override protected def genHashDecimal(
+      ctx: CodegenContext,
+      d: DecimalType,
+      input: String,
+      result: String): String = {
+    s"""
+      $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
+        $input.toJavaBigDecimal()).hashCode();"""
+  }
+
   override protected def genHashCalendarInterval(input: String, result: String): String = {
     s"""
       $result = (31 * $hasherClassName.hashInt($input.months)) +
@@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction {
     HiveHasher.hashUnsafeBytes(base, offset, len)
   }
 
+  private val HIVE_DECIMAL_MAX_PRECISION = 38
+  private val HIVE_DECIMAL_MAX_SCALE = 38
+
+  // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
+  def normalizeDecimal(input: BigDecimal): BigDecimal = {
+    if (input == null) return null
+
+    def trimDecimal(input: BigDecimal) = {
+      var result = input
+      if (result.compareTo(BigDecimal.ZERO) == 0) {
+        // Special case for 0, because java doesn't strip zeros correctly on that number.
+        result = BigDecimal.ZERO
+      } else {
+        result = result.stripTrailingZeros
+        if (result.scale < 0) {
+          // no negative scale decimals
+          result = result.setScale(0)
+        }
+      }
+      result
+    }
+
+    var result = trimDecimal(input)
+    val intDigits = result.precision - result.scale
+    if (intDigits > HIVE_DECIMAL_MAX_PRECISION) {
+      return null
+    }
+
+    val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE,
+      Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
+    if (result.scale > maxScale) {
+      result = result.setScale(maxScale, RoundingMode.HALF_UP)
+      // Trimming is again necessary, because rounding may introduce new trailing 0's.
+      result = trimDecimal(result)
+    }
+    result
+  }
+
   override def hash(value: Any, dataType: DataType, seed: Long): Long = {
     value match {
       case null => 0
@@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction {
       }
       result
 
-      case _ => super.hash(value, dataType, 0)
+      case d: Decimal =>
+        normalizeDecimal(d.toJavaBigDecimal).hashCode()
+
+      case _ => super.hash(value, dataType, seed)
     }
   }
 }
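A note for reviewers on why `normalizeDecimal` must run before the `hashCode()` call: `java.math.BigDecimal` defines `equals` and `hashCode` over the pair (unscaled value, scale), so numerically equal decimals with different scales hash differently unless trailing zeros are stripped first. A self-contained sketch (not part of the patch) demonstrating the problem and the fix:

```scala
import java.math.BigDecimal

object DecimalHashMotivation {
  def main(args: Array[String]): Unit = {
    val a = new BigDecimal("18")    // unscaled value 18, scale 0
    val b = new BigDecimal("18.00") // unscaled value 1800, scale 2

    println(a.compareTo(b) == 0)      // true: numerically equal
    println(a == b)                   // false: equals() compares scales too
    println(a.hashCode == b.hashCode) // false: so naive hashing diverges

    // Stripping trailing zeros (the core of normalizeDecimal above) makes
    // the representations, and hence the hash codes, agree:
    println(a.stripTrailingZeros.hashCode == b.stripTrailingZeros.hashCode) // true
  }
}
```

This also explains the expected value 558 for "18" in the test below: `BigDecimal.hashCode` works out to `31 * 18 + 0` for an unscaled value of 18 at scale 0.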
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666) + } + test("SPARK-18207: Compute hash for a lot of expressions") { val N = 1000 val wideRow = new GenericInternalRow(