From d16ae0fa75e31f3855de3481c71f4e65eb191f93 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Thu, 23 Feb 2017 23:35:16 -0800
Subject: [PATCH 1/7] [SPARK-17495] [SQL] Support Decimal type

---
 .../spark/sql/catalyst/expressions/hash.scala | 63 ++++++++++++++++++-
 .../expressions/HashExpressionsSuite.scala    | 43 ++++++++++++-
 2 files changed, 103 insertions(+), 3 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 2d9c2e42064b..52cca3a28748 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.math.{BigDecimal, RoundingMode}
 import java.security.{MessageDigest, NoSuchAlgorithmException}
 import java.util.zip.CRC32
 
@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
  * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
  *
- * TODO: Support Decimal and date related types
+ * TODO: Support date related types
  */
 @ExpressionDescription(
   usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def genHashBytes(b: String, result: String): String =
     s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
 
+  override protected def genHashDecimal(
+      ctx: CodegenContext,
+      d: DecimalType,
+      input: String,
+      result: String): String = {
+    s"""
+      $result = org.apache.spark.sql.catalyst.expressions.HiveHashFunction.normalizeDecimal(
+        $input.toJavaBigDecimal(), true).hashCode();"""
+  }
+
   override protected def genHashCalendarInterval(input: String, result: String): String = {
     s"""
       $result = (31 * $hasherClassName.hashInt($input.months)) +
@@ -732,6 +743,51 @@ object HiveHashFunction extends InterpretedHashFunction {
     HiveHasher.hashUnsafeBytes(base, offset, len)
   }
 
+  private val HiveDecimalMaxPrecision = 38
+  private val HiveDecimalMaxScale = 38
+
+  // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
+  def normalizeDecimal(input: BigDecimal, allowRounding: Boolean): BigDecimal = {
+    if (input == null) {
+      return null
+    }
+
+    def trimDecimal(input: BigDecimal) = {
+      var result = input
+      if (result.compareTo(BigDecimal.ZERO) == 0) {
+        // Special case for 0, because java doesn't strip zeros correctly on that number.
+        result = BigDecimal.ZERO
+      }
+      else {
+        result = result.stripTrailingZeros
+        if (result.scale < 0) {
+          // no negative scale decimals
+          result = result.setScale(0)
+        }
+      }
+      result
+    }
+
+    var result = trimDecimal(input)
+    val intDigits = result.precision - result.scale
+    if (intDigits > HiveDecimalMaxPrecision) {
+      return null
+    }
+
+    val maxScale =
+      Math.min(HiveDecimalMaxScale, Math.min(HiveDecimalMaxPrecision - intDigits, result.scale))
+    if (result.scale > maxScale) {
+      if (allowRounding) {
+        result = result.setScale(maxScale, RoundingMode.HALF_UP)
+        // Trimming is again necessary, because rounding may introduce new trailing 0's.
+        result = trimDecimal(result)
+      } else {
+        result = null
+      }
+    }
+    result
+  }
+
   override def hash(value: Any, dataType: DataType, seed: Long): Long = {
     value match {
       case null => 0
@@ -785,7 +841,10 @@ object HiveHashFunction extends InterpretedHashFunction {
         }
         result
 
-      case _ => super.hash(value, dataType, 0)
+      case d: Decimal =>
+        normalizeDecimal(d.toJavaBigDecimal, allowRounding = true).hashCode()
+
+      case _ => super.hash(value, dataType, seed)
     }
   }
 }
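Note: the normalization above is the crux of Hive compatibility. Decimals that compare equal (e.g. 1.23 vs 1.230) must hash equally, so trailing zeros are stripped and scales beyond what a 38-digit decimal can hold are rounded away before hashCode() is taken. Below is a minimal standalone sketch of that invariant on plain java.math.BigDecimal; the object/method names and the recursive re-trim are illustrative, and the overflow-to-null case is omitted:

    import java.math.{BigDecimal, RoundingMode}

    object NormalizeSketch {
      // Illustrative re-statement of the normalization in this patch.
      def normalize(d: BigDecimal): BigDecimal = {
        val trimmed =
          if (d.compareTo(BigDecimal.ZERO) == 0) BigDecimal.ZERO // zero would keep its scale otherwise
          else {
            val s = d.stripTrailingZeros
            if (s.scale < 0) s.setScale(0) else s // no negative-scale decimals
          }
        val intDigits = trimmed.precision - trimmed.scale
        val maxScale = Math.min(38, Math.min(38 - intDigits, trimmed.scale))
        if (trimmed.scale > maxScale) {
          // Rounding can introduce fresh trailing zeros, so normalize again.
          normalize(trimmed.setScale(maxScale, RoundingMode.HALF_UP))
        } else trimmed
      }

      def main(args: Array[String]): Unit = {
        val a = normalize(new BigDecimal("1.230"))
        val b = normalize(new BigDecimal("1.23"))
        assert(a.hashCode == b.hashCode) // equal values, equal hashes
      }
    }
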
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0cb3a79eee67..5f1348e15643 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }
 
-
   def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
     // Note : All expected hashes need to be computed using Hive 1.2.1
     val actual = HiveHashFunction.hash(input, dataType, seed = 0)
@@ -371,6 +370,48 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       new StructType().add("array", arrayOfString).add("map", mapOfString))
     .add("structOfUDT", structOfUDT))
 
+  test("hive-hash for decimal") {
+    def checkHiveHashForDecimal(
+        input: String,
+        precision: Int,
+        scale: Int,
+        expected: Long): Unit = {
+      val decimal = Decimal.apply(new java.math.BigDecimal(input))
+      decimal.changePrecision(precision, scale)
+      val decimalType = DataTypes.createDecimalType(precision, scale)
+      checkHiveHash(decimal, decimalType, expected)
+    }
+
+    checkHiveHashForDecimal("18", 38, 0, 558)
+    checkHiveHashForDecimal("-18", 38, 0, -558)
+    checkHiveHashForDecimal("-18", 38, 12, -558)
+    checkHiveHashForDecimal("18446744073709001000", 38, 19, -17070057)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 22, 17070057)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
+    checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
+    checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
+    checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
+    checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
+    checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
+    checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
+    checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+    checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
+    checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
+    checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+    checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666)
+    checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608)
+    checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666)
+  }
+
   test("SPARK-18207: Compute hash for a lot of expressions") {
     val N = 1000
     val wideRow = new GenericInternalRow(
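The expectations above encode the normalization invariant end to end: "-18" hashed as DECIMAL(38, 0) and as DECIMAL(38, 12) both come out as -558, because the representational scale is erased before hashing. A quick check in the spirit of the suite (a sketch, assuming access to the catalyst package as the suite has):

    import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
    import org.apache.spark.sql.types.{DataTypes, Decimal}

    val d0 = Decimal(new java.math.BigDecimal("-18"))
    assert(d0.changePrecision(38, 0))
    val d12 = Decimal(new java.math.BigDecimal("-18"))
    assert(d12.changePrecision(38, 12))
    // Same value at different scales must produce the same Hive hash (-558).
    assert(HiveHashFunction.hash(d0, DataTypes.createDecimalType(38, 0), 0) ==
      HiveHashFunction.hash(d12, DataTypes.createDecimalType(38, 12), 0))
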
From 7bc4eabdff7469766e29b96feac9e38f2266ea2f Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Tue, 28 Feb 2017 09:00:34 -0800
Subject: [PATCH 2/7] Handle unsuccessful precision change in test cases

---
 .../catalyst/expressions/HashExpressionsSuite.scala | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 5f1348e15643..0c77dc2709da 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -376,17 +376,20 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
         precision: Int,
         scale: Int,
         expected: Long): Unit = {
-      val decimal = Decimal.apply(new java.math.BigDecimal(input))
-      decimal.changePrecision(precision, scale)
       val decimalType = DataTypes.createDecimalType(precision, scale)
+      val decimal = {
+        val value = Decimal.apply(new java.math.BigDecimal(input))
+        if (value.changePrecision(precision, scale)) value else null
+      }
+
       checkHiveHash(decimal, decimalType, expected)
     }
 
     checkHiveHashForDecimal("18", 38, 0, 558)
     checkHiveHashForDecimal("-18", 38, 0, -558)
     checkHiveHashForDecimal("-18", 38, 12, -558)
-    checkHiveHashForDecimal("18446744073709001000", 38, 19, -17070057)
-    checkHiveHashForDecimal("-18446744073709001000", 38, 22, 17070057)
+    checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
+    checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
     checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
     checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
     checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
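The two flipped expectations follow from what changePrecision reports: 18446744073709001000 has 20 integral digits, so it cannot be represented as DECIMAL(38, 19) (20 + 19 > 38). The helper now hashes null in that case, and hashing null yields 0. A small sketch of the failing precision change:

    import org.apache.spark.sql.types.Decimal

    val value = Decimal(new java.math.BigDecimal("18446744073709001000"))
    // 20 integral digits cannot be held alongside scale 19 within precision 38,
    // so the change is rejected rather than silently truncated.
    assert(!value.changePrecision(38, 19))
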
From 2e4c3332fbcd9aa359e07ad8e35ab1d804f5a0dc Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Tue, 28 Feb 2017 13:28:50 -0800
Subject: [PATCH 3/7] review comment #2

---
 .../scala/org/apache/spark/sql/catalyst/expressions/hash.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 52cca3a28748..d8f6d24647f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -642,7 +642,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
       input: String,
       result: String): String = {
     s"""
-      $result = org.apache.spark.sql.catalyst.expressions.HiveHashFunction.normalizeDecimal(
+      $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
         $input.toJavaBigDecimal(), true).hashCode();"""
   }
 

From c0c8390e0bcb706c474db66b6326ee403cc6c58c Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Tue, 28 Feb 2017 22:32:04 -0800
Subject: [PATCH 4/7] review round #3

---
 .../org/apache/spark/sql/catalyst/expressions/hash.scala | 7 ++-----
 1 file changed, 2 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index d8f6d24647f3..1340d21353f0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -748,17 +748,14 @@ object HiveHashFunction extends InterpretedHashFunction {
 
   // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
   def normalizeDecimal(input: BigDecimal, allowRounding: Boolean): BigDecimal = {
-    if (input == null) {
-      return null
-    }
+    if (input == null) return null
 
     def trimDecimal(input: BigDecimal) = {
       var result = input
       if (result.compareTo(BigDecimal.ZERO) == 0) {
         // Special case for 0, because java doesn't strip zeros correctly on that number.
         result = BigDecimal.ZERO
-      }
-      else {
+      } else {
         result = result.stripTrailingZeros
         if (result.scale < 0) {
           // no negative scale decimals
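Patch 3's codegen change deserves a note: a Scala object compiles to a JVM class whose name ends in $ (the singleton), alongside a plain class of the same name carrying static forwarders, which is what generated Java source can call. Deriving the name from the class rather than hard-coding it keeps codegen correct if the object is moved or renamed. A tiny illustration, with a hypothetical object name:

    object Demo {
      def main(args: Array[String]): Unit = {
        val runtimeName = Demo.getClass.getName        // e.g. "Demo$"
        val codegenName = runtimeName.stripSuffix("$") // "Demo", callable from Java source
        println(s"$runtimeName -> $codegenName")
      }
    }
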
From f60ffb9f1b66ddd860b5c9da0c487ea2007ff981 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Sun, 5 Mar 2017 14:57:46 -0800
Subject: [PATCH 5/7] review #4

---
 .../spark/sql/catalyst/expressions/hash.scala | 22 ++++++++-----------
 1 file changed, 9 insertions(+), 13 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 1340d21353f0..fd176b625df5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -743,11 +743,11 @@ object HiveHashFunction extends InterpretedHashFunction {
     HiveHasher.hashUnsafeBytes(base, offset, len)
   }
 
-  private val HiveDecimalMaxPrecision = 38
-  private val HiveDecimalMaxScale = 38
+  private val HIVE_DECIMAL_MAX_PRECISION = 38
+  private val HIVE_DECIMAL_MAX_SCALE = 38
 
   // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
-  def normalizeDecimal(input: BigDecimal, allowRounding: Boolean): BigDecimal = {
+  def normalizeDecimal(input: BigDecimal): BigDecimal = {
     if (input == null) return null
 
     def trimDecimal(input: BigDecimal) = {
@@ -767,20 +767,16 @@ object HiveHashFunction extends InterpretedHashFunction {
     var result = trimDecimal(input)
     val intDigits = result.precision - result.scale
-    if (intDigits > HiveDecimalMaxPrecision) {
+    if (intDigits > HIVE_DECIMAL_MAX_PRECISION) {
       return null
     }
 
     val maxScale =
-      Math.min(HiveDecimalMaxScale, Math.min(HiveDecimalMaxPrecision - intDigits, result.scale))
+      Math.min(HIVE_DECIMAL_MAX_SCALE, Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
     if (result.scale > maxScale) {
-      if (allowRounding) {
-        result = result.setScale(maxScale, RoundingMode.HALF_UP)
-        // Trimming is again necessary, because rounding may introduce new trailing 0's.
-        result = trimDecimal(result)
-      } else {
-        result = null
-      }
+      result = result.setScale(maxScale, RoundingMode.HALF_UP)
+      // Trimming is again necessary, because rounding may introduce new trailing 0's.
+      result = trimDecimal(result)
     }
     result
   }
@@ -839,7 +835,7 @@ object HiveHashFunction extends InterpretedHashFunction {
       result
 
       case d: Decimal =>
-        normalizeDecimal(d.toJavaBigDecimal, allowRounding = true).hashCode()
+        normalizeDecimal(d.toJavaBigDecimal).hashCode()
 
       case _ => super.hash(value, dataType, seed)
     }
   }

From 65a09e940484b64212262fe17888ede6c5d8cc14 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Sun, 5 Mar 2017 14:58:33 -0800
Subject: [PATCH 6/7] checkstyle

---
 .../org/apache/spark/sql/catalyst/expressions/hash.scala | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index fd176b625df5..6ee432db3ea1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -771,8 +771,8 @@ object HiveHashFunction extends InterpretedHashFunction {
       return null
     }
 
-    val maxScale =
-      Math.min(HIVE_DECIMAL_MAX_SCALE, Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
+    val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE,
+      Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
     if (result.scale > maxScale) {
       result = result.setScale(maxScale, RoundingMode.HALF_UP)
       // Trimming is again necessary, because rounding may introduce new trailing 0's.

From 7c0b6c849bb2b3869a9c91560d130bb884e1532b Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Mon, 6 Mar 2017 07:26:13 -0800
Subject: [PATCH 7/7] fix generated code

---
 .../scala/org/apache/spark/sql/catalyst/expressions/hash.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 6ee432db3ea1..03101b4bfc5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -643,7 +643,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
       result: String): String = {
     s"""
       $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
-        $input.toJavaBigDecimal(), true).hashCode();"""
+        $input.toJavaBigDecimal()).hashCode();"""
   }
 
   override protected def genHashCalendarInterval(input: String, result: String): String = {
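Patch 5 dropped the allowRounding flag, but the codegen template from patch 3 was still emitting the old two-argument call; this final patch brings the generated code in line with the one-argument normalizeDecimal, so the interpreted and generated paths both reduce to normalizeDecimal(javaBigDecimal).hashCode(). A hedged usage sketch against the final API, with the expected value taken from the suite's Hive 1.2.1 table:

    import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
    import org.apache.spark.sql.types.{DataTypes, Decimal}

    val d = Decimal(new java.math.BigDecimal("18"))
    assert(d.changePrecision(38, 0))
    // 558 is the expectation the suite records for DECIMAL "18" at (38, 0).
    assert(HiveHashFunction.hash(d, DataTypes.createDecimalType(38, 0), 0) == 558)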