Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ class RocksDB(
* Drop uncommitted changes, and roll back to previous version.
*/
def rollback(): Unit = {
acquire()
numKeysOnWritingVersion = numKeysOnLoadedVersion
loadedVersion = -1L
changelogWriter.foreach(_.abort())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,22 +434,26 @@ case class StateStoreRestoreExec(
numColsPrefixKey = 0,
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val hasInput = iter.hasNext
if (!hasInput && keyExpressions.isEmpty) {
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
// the `HashAggregateExec` will output a 0 value for the partial merge. We need to
// restore the value, so that we don't overwrite our state with a 0 value, but rather
// merge the 0 with existing state.
store.iterator().map(_.value)
} else {
iter.flatMap { row =>
val key = stateManager.getKey(row.asInstanceOf[UnsafeRow])
val restoredRow = stateManager.get(store, key)
val outputRows = Option(restoredRow).toSeq :+ row
numOutputRows += outputRows.size
outputRows
}
val hasInput = iter.hasNext
val result = if (!hasInput && keyExpressions.isEmpty) {
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
// the `HashAggregateExec` will output a 0 value for the partial merge. We need to
// restore the value, so that we don't overwrite our state with a 0 value, but rather
// merge the 0 with existing state.
store.iterator().map(_.value)
} else {
iter.flatMap { row =>
val key = stateManager.getKey(row.asInstanceOf[UnsafeRow])
val restoredRow = stateManager.get(store, key)
val outputRows = Option(restoredRow).toSeq :+ row
numOutputRows += outputRows.size
outputRows
}
}
// SPARK-46547 - Release any locks/resources if required, to prevent
// deadlocks with the maintenance thread.
store.abort()
result
}
}

Expand Down