diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index 80f5b3532c5e..8cb99a162ab2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -434,22 +434,22 @@ case class StateStoreRestoreExec(
       numColsPrefixKey = 0,
       session.sessionState,
       Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
-        val hasInput = iter.hasNext
-        if (!hasInput && keyExpressions.isEmpty) {
-          // If our `keyExpressions` are empty, we're getting a global aggregation. In that case
-          // the `HashAggregateExec` will output a 0 value for the partial merge. We need to
-          // restore the value, so that we don't overwrite our state with a 0 value, but rather
-          // merge the 0 with existing state.
-          store.iterator().map(_.value)
-        } else {
-          iter.flatMap { row =>
-            val key = stateManager.getKey(row.asInstanceOf[UnsafeRow])
-            val restoredRow = stateManager.get(store, key)
-            val outputRows = Option(restoredRow).toSeq :+ row
-            numOutputRows += outputRows.size
-            outputRows
-          }
+      val hasInput = iter.hasNext
+      if (!hasInput && keyExpressions.isEmpty) {
+        // If our `keyExpressions` are empty, we're getting a global aggregation. In that case
+        // the `HashAggregateExec` will output a 0 value for the partial merge. We need to
+        // restore the value, so that we don't overwrite our state with a 0 value, but rather
+        // merge the 0 with existing state.
+        store.iterator().map(_.value)
+      } else {
+        iter.flatMap { row =>
+          val key = stateManager.getKey(row.asInstanceOf[UnsafeRow])
+          val restoredRow = stateManager.get(store, key)
+          val outputRows = Option(restoredRow).toSeq :+ row
+          numOutputRows += outputRows.size
+          outputRows
         }
+      }
     }
   }
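
The comment in this hunk captures the key invariant: when `keyExpressions` is empty (a global aggregation) and the micro-batch has no input, `StateStoreRestoreExec` must still emit the previously saved state row, because `HashAggregateExec` produces a 0-valued partial aggregate regardless, and the subsequent merge would otherwise overwrite the state with 0. What follows is a minimal sketch of that restore-then-merge contract, using a toy in-memory map in place of Spark's `StateStore`/`StreamingAggregationStateManager`; `ToyStore` and `restore` are illustrative names, not Spark API.

// A minimal, runnable sketch of the restore contract above, assuming a toy
// in-memory store in place of Spark's StateStore/StreamingAggregationStateManager.
object RestoreSketch {
  import scala.collection.mutable

  type Key = String
  type Row = (Key, Long) // (grouping key, partial aggregate value)

  final class ToyStore(val state: mutable.Map[Key, Row])

  // keyed = false models an empty `keyExpressions` (global aggregation).
  def restore(store: ToyStore, iter: Iterator[Row], keyed: Boolean): Iterator[Row] =
    if (!iter.hasNext && !keyed) {
      // Global aggregation with an empty batch: emit the saved value so the
      // downstream merge combines it with the 0-valued partial from
      // HashAggregateExec instead of overwriting state with 0.
      store.state.valuesIterator
    } else {
      iter.flatMap { row =>
        // Keyed path: emit any previously stored row for this key ahead of
        // the incoming row, mirroring `Option(restoredRow).toSeq :+ row`.
        store.state.get(row._1).toSeq :+ row
      }
    }

  def main(args: Array[String]): Unit = {
    val store = new ToyStore(mutable.Map("" -> ("", 42L)))
    // Empty micro-batch under a global aggregation: the restored 42 survives
    // the merge with the 0 partial; without the restore it would become 0.
    val merged = (restore(store, Iterator.empty, keyed = false).map(_._2) ++ Iterator(0L)).sum
    println(s"global aggregate after empty batch: $merged") // prints 42
  }
}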