diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala index c9f5d3b3d92d..dfa1100c37a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/AggregateProcessor.scala @@ -145,10 +145,13 @@ private[window] final class AggregateProcessor( /** Update the buffer. */ def update(input: InternalRow): Unit = { - updateProjection(join(buffer, input)) + // TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that + // MutableProjection makes copies of the complex input objects it buffer. + val copy = input.copy() + updateProjection(join(buffer, copy)) var i = 0 while (i < numImperatives) { - imperatives(i).update(buffer, input) + imperatives(i).update(buffer, copy) i += 1 } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 1255c4910471..204858fa2978 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.types._ /** * Window function testing for DataFrame API. @@ -423,4 +424,48 @@ class DataFrameWindowFunctionsSuite extends QueryTest with SharedSQLContext { df.select(selectList: _*).where($"value" < 2), Seq(Row(3, "1", null, 3.0, 4.0, 3.0), Row(5, "1", false, 4.0, 5.0, 5.0))) } + + test("SPARK-21258: complex object in combination with spilling") { + // Make sure we trigger the spilling path. + withSQLConf(SQLConf.WINDOW_EXEC_BUFFER_SPILL_THRESHOLD.key -> "17") { + val sampleSchema = new StructType(). + add("f0", StringType). + add("f1", LongType). + add("f2", ArrayType(new StructType(). + add("f20", StringType))). + add("f3", ArrayType(new StructType(). + add("f30", StringType))) + + val w0 = Window.partitionBy("f0").orderBy("f1") + val w1 = w0.rowsBetween(Long.MinValue, Long.MaxValue) + + val c0 = first(struct($"f2", $"f3")).over(w0) as "c0" + val c1 = last(struct($"f2", $"f3")).over(w1) as "c1" + + val input = + """{"f1":1497820153720,"f2":[{"f20":"x","f21":0}],"f3":[{"f30":"x","f31":0}]} + |{"f1":1497802179638} + |{"f1":1497802189347} + |{"f1":1497802189593} + |{"f1":1497802189597} + |{"f1":1497802189599} + |{"f1":1497802192103} + |{"f1":1497802193414} + |{"f1":1497802193577} + |{"f1":1497802193709} + |{"f1":1497802202883} + |{"f1":1497802203006} + |{"f1":1497802203743} + |{"f1":1497802203834} + |{"f1":1497802203887} + |{"f1":1497802203893} + |{"f1":1497802203976} + |{"f1":1497820168098} + |""".stripMargin.split("\n").toSeq + + import testImplicits._ + + spark.read.schema(sampleSchema).json(input.toDS()).select(c0, c1).foreach { _ => () } + } + } }