diff --git a/akka-persistence/src/main/scala/akka/persistence/snapshot/local/LocalSnapshotStore.scala b/akka-persistence/src/main/scala/akka/persistence/snapshot/local/LocalSnapshotStore.scala index 369fc856ac9..5b899d90c88 100644 --- a/akka-persistence/src/main/scala/akka/persistence/snapshot/local/LocalSnapshotStore.scala +++ b/akka-persistence/src/main/scala/akka/persistence/snapshot/local/LocalSnapshotStore.scala @@ -73,13 +73,15 @@ private[persistence] class LocalSnapshotStore extends SnapshotStore with ActorLo Try(withInputStream(md)(deserialize)) match { case Success(s) ⇒ Some(SelectedSnapshot(md, s.data)) case Failure(e) ⇒ - log.error(e, s"error loading snapshot ${md}") + log.error(e, s"Error loading snapshot [${md}]") load(metadata.init) // try older snapshot } } - private def save(metadata: SnapshotMetadata, snapshot: Any): Unit = - withOutputStream(metadata)(serialize(_, Snapshot(snapshot))) + protected def save(metadata: SnapshotMetadata, snapshot: Any): Unit = { + val tmpFile = withOutputStream(metadata)(serialize(_, Snapshot(snapshot))) + tmpFile.renameTo(snapshotFile(metadata)) + } protected def deserialize(inputStream: InputStream): Snapshot = serializationExtension.deserialize(streamToBytes(inputStream), classOf[Snapshot]).get @@ -87,8 +89,11 @@ private[persistence] class LocalSnapshotStore extends SnapshotStore with ActorLo protected def serialize(outputStream: OutputStream, snapshot: Snapshot): Unit = outputStream.write(serializationExtension.findSerializerFor(snapshot).toBinary(snapshot)) - private def withOutputStream(metadata: SnapshotMetadata)(p: (OutputStream) ⇒ Unit): Unit = - withStream(new BufferedOutputStream(new FileOutputStream(snapshotFile(metadata))), p) + protected def withOutputStream(metadata: SnapshotMetadata)(p: (OutputStream) ⇒ Unit): File = { + val tmpFile = snapshotFile(metadata, extension = "tmp") + withStream(new BufferedOutputStream(new FileOutputStream(tmpFile)), p) + tmpFile + } private def withInputStream[T](metadata: SnapshotMetadata)(p: (InputStream) ⇒ T): T = withStream(new BufferedInputStream(new FileInputStream(snapshotFile(metadata))), p) @@ -96,8 +101,8 @@ private[persistence] class LocalSnapshotStore extends SnapshotStore with ActorLo private def withStream[A <: Closeable, B](stream: A, p: A ⇒ B): B = try { p(stream) } finally { stream.close() } - private def snapshotFile(metadata: SnapshotMetadata): File = - new File(snapshotDir, s"snapshot-${URLEncoder.encode(metadata.processorId, "UTF-8")}-${metadata.sequenceNr}-${metadata.timestamp}") + private def snapshotFile(metadata: SnapshotMetadata, extension: String = ""): File = + new File(snapshotDir, s"snapshot-${URLEncoder.encode(metadata.processorId, "UTF-8")}-${metadata.sequenceNr}-${metadata.timestamp}${extension}") private def snapshotMetadata(processorId: String, criteria: SnapshotSelectionCriteria): immutable.Seq[SnapshotMetadata] = snapshotDir.listFiles(new SnapshotFilenameFilter(processorId)).map(_.getName).collect { diff --git a/akka-persistence/src/test/scala/akka/persistence/SnapshotFailureRobustnessSpec.scala b/akka-persistence/src/test/scala/akka/persistence/SnapshotFailureRobustnessSpec.scala new file mode 100644 index 00000000000..5612e205b06 --- /dev/null +++ b/akka-persistence/src/test/scala/akka/persistence/SnapshotFailureRobustnessSpec.scala @@ -0,0 +1,82 @@ +/** + * Copyright (C) 2009-2014 Typesafe Inc. + */ + +package akka.persistence + +import akka.actor.{ Props, ActorRef } +import akka.testkit.{ TestEvent, EventFilter, ImplicitSender, AkkaSpec } +import scala.concurrent.duration._ +import akka.persistence.snapshot.local.LocalSnapshotStore +import akka.persistence.serialization.Snapshot +import akka.event.Logging + +import scala.language.postfixOps + +object SnapshotFailureRobustnessSpec { + + class SaveSnapshotTestProcessor(name: String, probe: ActorRef) extends NamedProcessor(name) { + def receive = { + case Persistent(payload, snr) ⇒ saveSnapshot(payload) + case SaveSnapshotSuccess(md) ⇒ probe ! md.sequenceNr + case SnapshotOffer(md, s) ⇒ probe ! ((md, s)) + case other ⇒ probe ! other + } + } + + class LoadSnapshotTestProcessor(name: String, probe: ActorRef) extends NamedProcessor(name) { + def receive = { + case Persistent(payload, snr) ⇒ probe ! s"${payload}-${snr}" + case SnapshotOffer(md, s) ⇒ probe ! ((md, s)) + case other ⇒ probe ! other + } + override def preStart() = () + } + + class FailingLocalSnapshotStore extends LocalSnapshotStore { + override def save(metadata: SnapshotMetadata, snapshot: Any): Unit = { + if (metadata.sequenceNr == 2) { + val bytes = "b0rk".getBytes("UTF-8") + withOutputStream(metadata)(_.write(bytes)) + } else super.save(metadata, snapshot) + } + } +} + +class SnapshotFailureRobustnessSpec extends AkkaSpec(PersistenceSpec.config("leveldb", "SnapshotFailureRobustnessSpec", serialization = "off", extraConfig = Some( + """ + |akka.persistence.snapshot-store.local.class = "akka.persistence.SnapshotFailureRobustnessSpec$FailingLocalSnapshotStore" + """.stripMargin))) with PersistenceSpec with ImplicitSender { + + import SnapshotFailureRobustnessSpec._ + + "A processor with a failing snapshot" must { + "recover state starting from the most recent complete snapshot" in { + val sProcessor = system.actorOf(Props(classOf[SaveSnapshotTestProcessor], name, testActor)) + val processorId = name + + sProcessor ! Persistent("blahonga") + expectMsg(1) + sProcessor ! Persistent("kablama") + expectMsg(2) + system.eventStream.publish(TestEvent.Mute( + EventFilter.error(start = "Error loading snapshot ["))) + system.eventStream.subscribe(testActor, classOf[Logging.Error]) + try { + val lProcessor = system.actorOf(Props(classOf[LoadSnapshotTestProcessor], name, testActor)) + lProcessor ! Recover() + expectMsgPF() { + case (SnapshotMetadata(`processorId`, 1, timestamp), state) ⇒ + state should be("blahonga") + timestamp should be > (0L) + } + expectMsg("kablama-2") + expectNoMsg(1 second) + } finally { + system.eventStream.unsubscribe(testActor, classOf[Logging.Error]) + system.eventStream.publish(TestEvent.UnMute( + EventFilter.error(start = "Error loading snapshot ["))) + } + } + } +} diff --git a/project/AkkaBuild.scala b/project/AkkaBuild.scala index 86e496c0707..5f4ea117cc9 100644 --- a/project/AkkaBuild.scala +++ b/project/AkkaBuild.scala @@ -1012,7 +1012,10 @@ object AkkaBuild extends Build { ProblemFilters.exclude[MissingMethodProblem]("akka.remote.ReliableDeliverySupervisor#GotUid.copy"), ProblemFilters.exclude[MissingMethodProblem]("akka.remote.ReliableDeliverySupervisor#GotUid.this"), ProblemFilters.exclude[MissingTypesProblem]("akka.remote.ReliableDeliverySupervisor$GotUid$"), - ProblemFilters.exclude[MissingMethodProblem]("akka.remote.ReliableDeliverySupervisor#GotUid.apply") + ProblemFilters.exclude[MissingMethodProblem]("akka.remote.ReliableDeliverySupervisor#GotUid.apply"), + + // Change of private method to protected by #15212 + ProblemFilters.exclude[MissingMethodProblem]("akka.persistence.snapshot.local.LocalSnapshotStore.akka$persistence$snapshot$local$LocalSnapshotStore$$save") ) }