diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
index f653890f6c7ba..83020256ac462 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala
@@ -180,16 +180,23 @@ class WindowSpec private[sql](
   private def between(typ: FrameType, start: Long, end: Long): WindowSpec = {
     val boundaryStart = start match {
       case 0 => CurrentRow
-      case Long.MinValue => UnboundedPreceding
-      case x if x < 0 => ValuePreceding(-start.toInt)
-      case x if x > 0 => ValueFollowing(start.toInt)
+      // Boundaries are narrowed to Int offsets. A start at or below Int.MinValue has no
+      // positive Int counterpart (negating Int.MinValue overflows), so it is unbounded.
+      case x if x <= Int.MinValue => UnboundedPreceding
+      case x if x < 0 => ValuePreceding((-x).toInt)
+      case x if x > 0 && x <= Int.MaxValue => ValueFollowing(x.toInt)
+      case _ => throw new IllegalArgumentException(s"Boundary start($start) should not be " +
+        s"larger than Int.MaxValue(${Int.MaxValue}).")
     }
 
     val boundaryEnd = end match {
       case 0 => CurrentRow
-      case Long.MaxValue => UnboundedFollowing
-      case x if x < 0 => ValuePreceding(-end.toInt)
-      case x if x > 0 => ValueFollowing(end.toInt)
+      case x if x > Int.MaxValue => UnboundedFollowing
+      // Int.MinValue is excluded: negating it overflows, so it is rejected below.
+      case x if x < 0 && x > Int.MinValue => ValuePreceding((-x).toInt)
+      case x if x > 0 && x <= Int.MaxValue => ValueFollowing(x.toInt)
+      case _ => throw new IllegalArgumentException(s"Boundary end($end) should be " +
+        s"larger than Int.MinValue(${Int.MinValue}).")
     }
 
     new WindowSpec(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
index 1255c49104718..44208ad4ce561 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala
@@ -423,4 +423,118 @@ class 
DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext {
       df.select(selectList: _*).where($"value" < 2),
       Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0)))
   }
+
+  test("SPARK-19451: Underlying integer overflow in Window function") {
+    // Frame boundaries are passed as Long but narrowed to Int offsets when the frame is
+    // built. 2160000000L exceeds Int.MaxValue, so a plain toInt would wrap around to a
+    // negative offset; boundaries beyond the Int range must act as unbounded ones instead.
+    // Partition "a" holds ids (1, 1, 2) and partition "b" holds ids (1, 2, 3).
+    val df = Seq((1L, "a"), (1L, "a"), (2L, "a"), (1L, "b"), (2L, "b"), (3L, "b"))
+      .toDF("id", "category")
+
+    // range frames
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(-2160000000L, -1))),
+      Seq(
+        Row(1, "b", null), Row(2, "b", 1), Row(3, "b", 3),
+        Row(1, "a", null), Row(1, "a", null), Row(2, "a", 2)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(-2160000000L, 0))),
+      Seq(
+        Row(1, "b", 1), Row(2, "b", 3), Row(3, "b", 6),
+        Row(1, "a", 2), Row(1, "a", 2), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(-2160000000L, 2))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(-2160000000L, 2160000000L))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(-1, 2160000000L))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 5),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(0, 2160000000L))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 5), Row(3, "b", 3),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 2)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rangeBetween(2, 2160000000L))),
+      Seq(
+        Row(1, "b", 3), Row(2, "b", null), Row(3, "b", null),
+        Row(1, "a", null), Row(1, "a", null), Row(2, "a", null)))
+
+    // row frames
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(-2160000000L, -1))),
+      Seq(
+        Row(1, "b", null), Row(2, "b", 1), Row(3, "b", 3),
+        Row(1, "a", null), Row(1, "a", 1), Row(2, "a", 2)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(-2160000000L, 0))),
+      Seq(
+        Row(1, "b", 1), Row(2, "b", 3), Row(3, "b", 6),
+        Row(1, "a", 1), Row(1, "a", 2), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(-2160000000L, 2))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(-2160000000L, 2160000000L))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 6),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 4)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(-1, 2160000000L))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 6), Row(3, "b", 5),
+        Row(1, "a", 4), Row(1, "a", 4), Row(2, "a", 3)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(0, 2160000000L))),
+      Seq(
+        Row(1, "b", 6), Row(2, "b", 5), Row(3, "b", 3),
+        Row(1, "a", 4), Row(1, "a", 3), Row(2, "a", 2)))
+    checkAnswer(
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(2, 2160000000L))),
+      Seq(
+        Row(1, "b", 3), Row(2, "b", null), Row(3, "b", null),
+        Row(1, "a", 2), Row(1, "a", null), Row(2, "a", null)))
+
+    // A boundary on the wrong side of the Int range cannot be mapped to an unbounded frame
+    // and must be rejected. collect() forces evaluation so the check does not depend on
+    // whether the error is raised while building the frame or while running the query.
+
+    // End boundary below Int.MinValue: the start (-3160000000L) is treated as unbounded,
+    // so only the end boundary is invalid here.
+    intercept[IllegalArgumentException] {
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(-3160000000L, -2160000000L))).collect()
+    }
+
+    // Start boundary above Int.MaxValue.
+    intercept[IllegalArgumentException] {
+      df.select('id, 'category, sum("id").over(Window.partitionBy('category).orderBy('id)
+        .rowsBetween(2160000000L, 3160000000L))).collect()
+    }
+  }
 }