@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.math.{BigDecimal, RoundingMode}
import java.security.{MessageDigest, NoSuchAlgorithmException}
import java.util.zip.CRC32

@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
* We should use this hash function for both shuffle and bucket of Hive tables, so that
* we can guarantee shuffle and bucketing have same data distribution
*
* TODO: Support Decimal and date related types
* TODO: Support date related types
*/
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override protected def genHashBytes(b: String, result: String): String =
s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"

  override protected def genHashDecimal(
      ctx: CodegenContext,
      d: DecimalType,
      input: String,
      result: String): String = {
    s"""
      $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
        $input.toJavaBigDecimal(), true).hashCode();"""
  }

Contributor: where do we use ctx and d?

Contributor Author: They both aren't used, but they are part of the method signature because the default implementation in the abstract class needs them.
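A minimal standalone sketch of that point, with illustrative names rather than Spark's actual classes: the base class fixes the full parameter list, so a subclass override has to accept ctx and d even when it does not use them.

// Hypothetical reduction of the situation above (names are illustrative, not
// Spark's actual classes): the base class defines genHashDecimal with a default
// implementation that uses all four parameters, so an override must keep the
// same signature even if it only needs `input` and `result`.
abstract class HashCodegen {
  def genHashDecimal(ctx: AnyRef, d: AnyRef, input: String, result: String): String =
    s"$result = defaultDecimalHash($ctx, $d, $input);"  // default impl uses ctx and d
}

class HiveStyleHashCodegen extends HashCodegen {
  // `ctx` and `d` are unused here, but removing them would change the signature
  // and this method would no longer override the default implementation.
  override def genHashDecimal(ctx: AnyRef, d: AnyRef, input: String, result: String): String =
    s"$result = normalizeDecimal($input.toJavaBigDecimal(), true).hashCode();"
}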

override protected def genHashCalendarInterval(input: String, result: String): String = {
s"""
$result = (31 * $hasherClassName.hashInt($input.months)) +
@@ -732,6 +743,48 @@ object HiveHashFunction extends InterpretedHashFunction {
HiveHasher.hashUnsafeBytes(base, offset, len)
}

  private val HiveDecimalMaxPrecision = 38

Contributor: nit: HIVE_DECIMAL_MAX_PRECISION

Contributor Author: renamed

  private val HiveDecimalMaxScale = 38

  // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
  def normalizeDecimal(input: BigDecimal, allowRounding: Boolean): BigDecimal = {
    if (input == null) return null

    def trimDecimal(input: BigDecimal) = {
      var result = input
      if (result.compareTo(BigDecimal.ZERO) == 0) {
        // Special case for 0, because java doesn't strip zeros correctly on that number.
        result = BigDecimal.ZERO
      } else {
        result = result.stripTrailingZeros
        if (result.scale < 0) {
          // no negative scale decimals
          result = result.setScale(0)
        }
      }
      result
    }

    var result = trimDecimal(input)
    val intDigits = result.precision - result.scale
    if (intDigits > HiveDecimalMaxPrecision) {
      return null
    }

    val maxScale =
      Math.min(HiveDecimalMaxScale, Math.min(HiveDecimalMaxPrecision - intDigits, result.scale))
    if (result.scale > maxScale) {
      if (allowRounding) {
        result = result.setScale(maxScale, RoundingMode.HALF_UP)
        // Trimming is again necessary, because rounding may introduce new trailing 0's.
        result = trimDecimal(result)
      } else {
        result = null
      }
    }
    result
  }

Contributor: allowRounding will never be false?

Contributor Author: removed that param
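As an aside, a small standalone illustration (not part of the PR; the object name is made up) of why trailing zeros must be stripped before hashing: java.math.BigDecimal.hashCode depends on the scale, so numerically equal values with different scales hash differently until they are normalized.

import java.math.BigDecimal

object NormalizeIllustration {
  def main(args: Array[String]): Unit = {
    val a = new BigDecimal("123.4500")  // unscaled value 1234500, scale 4
    val b = new BigDecimal("123.45")    // unscaled value 12345, scale 2
    // Raw hashCode differs because it mixes in the scale:
    println(a.hashCode == b.hashCode)   // false
    // After stripping trailing zeros both values have scale 2 and hash identically,
    // which is the behaviour normalizeDecimal relies on:
    println(a.stripTrailingZeros.hashCode == b.stripTrailingZeros.hashCode)  // true
  }
}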

override def hash(value: Any, dataType: DataType, seed: Long): Long = {
value match {
case null => 0
@@ -785,7 +838,10 @@ object HiveHashFunction extends InterpretedHashFunction {
}
result

case _ => super.hash(value, dataType, 0)
case d: Decimal =>
normalizeDecimal(d.toJavaBigDecimal, allowRounding = true).hashCode()

case _ => super.hash(value, dataType, seed)
}
}
}
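A hypothetical check (not in the PR; the wrapping object is made up, the Spark calls are the ones used in the diff) of what the new Decimal branch buys: because the value is normalized before hashCode is taken, decimals that are numerically equal should hash to the same value regardless of how many trailing zeros they carry.

import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
import org.apache.spark.sql.types.{DataTypes, Decimal}

object DecimalHiveHashCheck {
  def main(args: Array[String]): Unit = {
    val d1 = Decimal(new java.math.BigDecimal("123.45"))
    val d2 = Decimal(new java.math.BigDecimal("123.4500"))
    // Both calls should print the same hash: normalizeDecimal strips d2's
    // trailing zeros before hashCode is taken.
    println(HiveHashFunction.hash(d1, DataTypes.createDecimalType(38, 2), 0))
    println(HiveHashFunction.hash(d2, DataTypes.createDecimalType(38, 4), 0))
  }
}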
@@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}


def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
// Note : All expected hashes need to be computed using Hive 1.2.1
val actual = HiveHashFunction.hash(input, dataType, seed = 0)
@@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))

test("hive-hash for decimal") {
def checkHiveHashForDecimal(
input: String,
precision: Int,
scale: Int,
expected: Long): Unit = {
val decimalType = DataTypes.createDecimalType(precision, scale)
val decimal = {
val value = Decimal.apply(new java.math.BigDecimal(input))
if (value.changePrecision(precision, scale)) value else null
}

checkHiveHash(decimal, decimalType, expected)
}

checkHiveHashForDecimal("18", 38, 0, 558)
Contributor: hmm it's hard to guarantee that we can produce the same hash value as Hive, can we run Hive in the test and compare the result with Spark?

Contributor Author (@tejasapatil, Feb 28, 2017): The expected values were generated using Hive 1.2.1. My original approach was to depend on Hive for generating the expected values, but per the discussion in a related PR I was advised to hardcode them instead, the main point being to reduce the dependency on Hive.
checkHiveHashForDecimal("-18", 38, 0, -558)
checkHiveHashForDecimal("-18", 38, 12, -558)
checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666)
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608)
checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666)
}

test("SPARK-18207: Compute hash for a lot of expressions") {
val N = 1000
val wideRow = new GenericInternalRow(