diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 604c52713e972..4490a86a762f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -764,12 +764,42 @@ case class Round(child: Expression, scale: Expression) }""" } case DoubleType => // if child eval to NaN or Infinity, just return it. + // The logic for rounding half-integers to even values is exemplified by the following + // table: + // + // x | x rounded to half-even | x * 2 | (x rounded to half-even) * 2 | (x * 2) & 3 + // ---------------------------------------------------------------------------------------- + // -4.5 | -4 | -9 | -8 | 3 + // -3.5 | -4 | -7 | -8 | 1 + // -2.5 | -2 | -5 | -6 | 3 + // -1.5 | -2 | -3 | -6 | 1 + // -0.5 | 0 | -1 | 0 | 3 + // 0.5 | 0 | 1 | 0 | 1 + // 1.5 | 2 | 3 | 4 | 3 + // 2.5 | 2 | 5 | 4 | 1 + // 3.5 | 4 | 7 | 8 | 3 + // 4.5 | 4 | 9 | 8 | 1 + // + // Therefore, looking at the last three columns above, if x has the form of ".5", + // then + // (x rounded to half-even) * 2 = (x * 2) + ((x * 2) & 3) - 2 + if (_scale == 0) { s""" if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){ ${ev.primitive} = ${ce.primitive}; } else { - ${ev.primitive} = Math.round(${ce.primitive}); + double timesTwo = ${ce.primitive} * 2; + long timesTwoRounded = Math.round(timesTwo); + if (timesTwo == timesTwoRounded) { + if ((timesTwoRounded & 1) == 0) { + ${ev.primitive} = timesTwoRounded >> 1; + } else { + ${ev.primitive} = (timesTwoRounded + (timesTwoRounded & 3) - 2) >> 1; + } + } else { + ${ev.primitive} = Math.round(${ce.primitive}); + } }""" } else { s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 90c59f240b542..45e728e8483ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -512,7 +512,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { Seq.fill[Short](7)(31415) val intResults: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, - 314159270) ++ Seq.fill(7)(314159265) + 314159260) ++ Seq.fill(7)(314159265) val longResults: Seq[Long] = Seq(31415926536000000L, 31415926535900000L, 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++