diff --git a/src/Akka.Persistence.Azure/Journal/AzureTableStorageJournal.cs b/src/Akka.Persistence.Azure/Journal/AzureTableStorageJournal.cs index 8b42f7b2..e80e8fbc 100644 --- a/src/Akka.Persistence.Azure/Journal/AzureTableStorageJournal.cs +++ b/src/Akka.Persistence.Azure/Journal/AzureTableStorageJournal.cs @@ -54,6 +54,7 @@ public class AzureTableStorageJournal : AsyncWriteJournal private readonly TableServiceClient _tableServiceClient; private TableClient _tableStorage_DoNotUseDirectly; private readonly Dictionary> _tagSubscribers = new Dictionary>(); + private readonly CancellationTokenSource _shutdownCts; public AzureTableStorageJournal(Config config = null) { @@ -81,6 +82,8 @@ public AzureTableStorageJournal(Config config = null) options: _settings.TableClientOptions) : new TableServiceClient(connectionString: _settings.ConnectionString); } + + _shutdownCts = new CancellationTokenSource(); } public TableClient Table @@ -107,9 +110,9 @@ public override async Task ReadHighestSequenceNrAsync( _log.Debug("Entering method ReadHighestSequenceNrAsync"); - var seqNo = await HighestSequenceNumberQuery(persistenceId) + var seqNo = await HighestSequenceNumberQuery(persistenceId, null, _shutdownCts.Token) .Select(entity => entity.GetInt64(HighestSequenceNrEntry.HighestSequenceNrKey).Value) - .AggregateAsync(0L, Math.Max); + .AggregateAsync(0L, Math.Max, cancellationToken: _shutdownCts.Token); _log.Debug("Leaving method ReadHighestSequenceNrAsync with SeqNo [{0}] for PersistentId [{1}]", seqNo, persistenceId); @@ -132,7 +135,8 @@ public override async Task ReplayMessagesAsync( if (max == 0) return; - var pages = PersistentJournalEntryReplayQuery(persistenceId, fromSequenceNr, toSequenceNr).AsPages().GetAsyncEnumerator(); + var pages = PersistentJournalEntryReplayQuery(persistenceId, fromSequenceNr, toSequenceNr, null, _shutdownCts.Token) + .AsPages().GetAsyncEnumerator(_shutdownCts.Token); ValueTask? nextTask = pages.MoveNextAsync(); var count = 0L; @@ -206,7 +210,8 @@ protected override async Task DeleteMessagesToAsync(string persistenceId, long t _log.Debug("Entering method DeleteMessagesToAsync for persistentId [{0}] and up to seqNo [{1}]", persistenceId, toSequenceNr); - var pages = PersistentJournalEntryDeleteQuery(persistenceId, toSequenceNr).AsPages().GetAsyncEnumerator(); + var pages = PersistentJournalEntryDeleteQuery(persistenceId, toSequenceNr, null, _shutdownCts.Token) + .AsPages().GetAsyncEnumerator(_shutdownCts.Token); ValueTask? nextTask = pages.MoveNextAsync(); while (nextTask.HasValue) @@ -227,7 +232,7 @@ protected override async Task DeleteMessagesToAsync(string persistenceId, long t if (currentPage.Values.Count > 0) { await Table.SubmitTransactionAsync(currentPage.Values - .Select(entity => new TableTransactionAction(TableTransactionActionType.Delete, entity))); + .Select(entity => new TableTransactionAction(TableTransactionActionType.Delete, entity)), _shutdownCts.Token); } } @@ -238,7 +243,7 @@ protected override void PreStart() { _log.Debug("Initializing Azure Table Storage..."); - InitCloudStorage(5) + InitCloudStorage(5, _shutdownCts.Token) .ConfigureAwait(false).GetAwaiter().GetResult(); _log.Debug("Successfully started Azure Table Storage!"); @@ -247,20 +252,27 @@ protected override void PreStart() base.PreStart(); } + protected override void PostStop() + { + _shutdownCts.Cancel(); + _shutdownCts.Dispose(); + base.PostStop(); + } + protected override bool ReceivePluginInternal(object message) { switch (message) { case ReplayTaggedMessages replay: - ReplayTaggedMessagesAsync(replay) + ReplayTaggedMessagesAsync(replay, _shutdownCts.Token) .PipeTo(replay.ReplyTo, success: h => new RecoverySuccess(h), failure: e => new ReplayMessagesFailure(e)); break; case SubscribePersistenceId subscribe: AddPersistenceIdSubscriber(Sender, subscribe.PersistenceId); Context.Watch(Sender); break; - case SubscribeAllPersistenceIds subscribe: - AddAllPersistenceIdSubscriber(Sender); + case SubscribeAllPersistenceIds _: + AddAllPersistenceIdSubscriber(Sender, _shutdownCts.Token); Context.Watch(Sender); break; case SubscribeTag subscribe: @@ -353,7 +365,7 @@ protected override async Task> WriteMessagesAsync(IEnu if (_log.IsDebugEnabled && _settings.VerboseLogging) _log.Debug("Attempting to write batch of {0} messages to Azure storage", batchItems.Count); - var response = await Table.SubmitTransactionAsync(batchItems); + var response = await Table.SubmitTransactionAsync(batchItems, _shutdownCts.Token); if (_log.IsDebugEnabled && _settings.VerboseLogging) { foreach (var r in response.Value) @@ -383,7 +395,7 @@ protected override async Task> WriteMessagesAsync(IEnu new AllPersistenceIdsEntry(PartitionKeyEscapeHelper.Escape(item.Key)).WriteEntity())); } - var allPersistenceResponse = await Table.SubmitTransactionAsync(allPersistenceIdsBatch); + var allPersistenceResponse = await Table.SubmitTransactionAsync(allPersistenceIdsBatch, _shutdownCts.Token); if (_log.IsDebugEnabled && _settings.VerboseLogging) foreach (var r in allPersistenceResponse.Value) @@ -405,7 +417,7 @@ protected override async Task> WriteMessagesAsync(IEnu eventTagsBatch.Add(new TableTransactionAction(TableTransactionActionType.UpsertReplace, item.WriteEntity())); } - var eventTagsResponse = await Table.SubmitTransactionAsync(eventTagsBatch); + var eventTagsResponse = await Table.SubmitTransactionAsync(eventTagsBatch, _shutdownCts.Token); if (_log.IsDebugEnabled && _settings.VerboseLogging) foreach (var r in eventTagsResponse.Value) @@ -439,8 +451,8 @@ protected override async Task> WriteMessagesAsync(IEnu } private AsyncPageable GenerateAllPersistenceIdsQuery( - int? maxPerPage = null, - CancellationToken cancellationToken = default) + int? maxPerPage, + CancellationToken cancellationToken) { return Table.QueryAsync( filter: $"PartitionKey eq '{AllPersistenceIdsEntry.PartitionKeyValue}'", @@ -452,8 +464,8 @@ private AsyncPageable GenerateAllPersistenceIdsQuery( private AsyncPageable HighestSequenceNumberQuery( string persistenceId, - int? maxPerPage = null, - CancellationToken cancellationToken = default) + int? maxPerPage, + CancellationToken cancellationToken) { return Table.QueryAsync( filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistenceId)}' and " + @@ -467,8 +479,8 @@ private AsyncPageable HighestSequenceNumberQuery( private AsyncPageable PersistentJournalEntryDeleteQuery( string persistenceId, long toSequenceNr, - int? maxPerPage = null, - CancellationToken cancellationToken = default) + int? maxPerPage, + CancellationToken cancellationToken) { return Table.QueryAsync( filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistenceId)}' and " + @@ -483,8 +495,8 @@ private AsyncPageable EventTagEntryDeleteQuery( string persistenceId, long fromSequenceNr, long toSequenceNr, - int? maxPerPage = null, - CancellationToken cancellationToken = default) + int? maxPerPage, + CancellationToken cancellationToken) { return Table.QueryAsync( filter: $"PartitionKey eq '{EventTagEntry.PartitionKeyValue}' and " + @@ -500,8 +512,8 @@ private AsyncPageable PersistentJournalEntryReplayQuery( string persistentId, long fromSequenceNumber, long toSequenceNumber, - int? maxPerPage = null, - CancellationToken cancellationToken = default) + int? maxPerPage, + CancellationToken cancellationToken) { var filter = $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistentId)}' and " + $"RowKey ne '{HighestSequenceNrEntry.RowKeyValue}'"; @@ -517,8 +529,8 @@ private AsyncPageable PersistentJournalEntryReplayQuery( private AsyncPageable TaggedMessageQuery( ReplayTaggedMessages replay, - int? maxPerPage = null, - CancellationToken cancellationToken = default) + int? maxPerPage, + CancellationToken cancellationToken) { return Table.QueryAsync( filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(EventTagEntry.GetPartitionKey(replay.Tag))}' and " + @@ -529,13 +541,13 @@ private AsyncPageable TaggedMessageQuery( cancellationToken: cancellationToken); } - private async Task AddAllPersistenceIdSubscriber(IActorRef subscriber) + private async Task AddAllPersistenceIdSubscriber(IActorRef subscriber, CancellationToken cancellationToken) { lock (_allPersistenceIdSubscribers) { _allPersistenceIdSubscribers.Add(subscriber); } - subscriber.Tell(new CurrentPersistenceIds(await GetAllPersistenceIds())); + subscriber.Tell(new CurrentPersistenceIds(await GetAllPersistenceIds(cancellationToken))); } private void AddPersistenceIdSubscriber(IActorRef subscriber, string persistenceId) @@ -560,18 +572,21 @@ private void AddTagSubscriber(IActorRef subscriber, string tag) subscriptions.Add(subscriber); } - private async Task> GetAllPersistenceIds() + private async Task> GetAllPersistenceIds(CancellationToken cancellationToken) { - return await GenerateAllPersistenceIdsQuery().Select(item => item.RowKey).ToListAsync(); + return await GenerateAllPersistenceIdsQuery(null, cancellationToken) + .Select(item => item.RowKey).ToListAsync(cancellationToken); } - private async Task InitCloudStorage(int remainingTries) + private async Task InitCloudStorage(int remainingTries, CancellationToken cancellationToken) { try { var tableClient = _tableServiceClient.GetTableClient(_settings.TableName); - using (var cts = new CancellationTokenSource(_settings.ConnectTimeout)) + var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(_settings.ConnectTimeout); + using (cts) { if (!_settings.AutoInitialize) { @@ -603,15 +618,19 @@ private async Task InitCloudStorage(int remainingTries) _log.Error(ex, "[{0}] more tries to initialize table storage remaining...", remainingTries); if (remainingTries == 0) throw; - await Task.Delay(RetryInterval[remainingTries]); - await InitCloudStorage(remainingTries - 1); + + await Task.Delay(RetryInterval[remainingTries], cancellationToken); + if (cancellationToken.IsCancellationRequested) + throw; + + await InitCloudStorage(remainingTries - 1, cancellationToken); } } - private async Task IsTableExist(string name, CancellationToken token) + private async Task IsTableExist(string name, CancellationToken cancellationToken) { - var tables = await _tableServiceClient.QueryAsync(t => t.Name == name, cancellationToken: token) - .ToListAsync(token) + var tables = await _tableServiceClient.QueryAsync(t => t.Name == name, cancellationToken: cancellationToken) + .ToListAsync(cancellationToken) .ConfigureAwait(false); return tables.Count > 0; } @@ -657,15 +676,16 @@ private void RemoveSubscriber( /// Replays all events with given tag within provided boundaries from current database. /// /// TBD + /// /// TBD - private async Task ReplayTaggedMessagesAsync(ReplayTaggedMessages replay) + private async Task ReplayTaggedMessagesAsync(ReplayTaggedMessages replay, CancellationToken cancellationToken) { // In order to actually break at the limit we ask for we have to // keep a separate counter and track it ourselves. var counter = 0; var maxOrderingId = 0L; - var pages = TaggedMessageQuery(replay).AsPages().GetAsyncEnumerator(); + var pages = TaggedMessageQuery(replay, null, cancellationToken).AsPages().GetAsyncEnumerator(cancellationToken); ValueTask? nextTask = pages.MoveNextAsync(); while (nextTask != null) diff --git a/src/Akka.Persistence.Azure/Snapshot/AzureBlobSnapshotStore.cs b/src/Akka.Persistence.Azure/Snapshot/AzureBlobSnapshotStore.cs index 94bf06b3..d2cdfe3f 100644 --- a/src/Akka.Persistence.Azure/Snapshot/AzureBlobSnapshotStore.cs +++ b/src/Akka.Persistence.Azure/Snapshot/AzureBlobSnapshotStore.cs @@ -47,6 +47,8 @@ public class AzureBlobSnapshotStore : SnapshotStore private readonly AzureBlobSnapshotStoreSettings _settings; private readonly BlobServiceClient _serviceClient; + private readonly CancellationTokenSource _shutdownCts; + public AzureBlobSnapshotStore(Config config = null) { _serialization = new SerializationHelper(Context.System); @@ -72,62 +74,71 @@ public AzureBlobSnapshotStore(Config config = null) : _serviceClient = new BlobServiceClient(connectionString: _settings.ConnectionString); } - _containerClient = new Lazy(() => InitCloudStorage(5).Result); + _shutdownCts = new CancellationTokenSource(); + _containerClient = new Lazy(() => + InitCloudStorage(5, _shutdownCts.Token).GetAwaiter().GetResult()); } public BlobContainerClient Container => _containerClient.Value; - private async Task InitCloudStorage(int remainingTries) + private async Task InitCloudStorage(int remainingTries, CancellationToken cancellationToken) { try { var blobClient = _serviceClient.GetBlobContainerClient(_settings.ContainerName); - using var cts = new CancellationTokenSource(_settings.ConnectTimeout); - if (!_settings.AutoInitialize) + var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken); + cts.CancelAfter(_settings.ConnectTimeout); + using (cts) { - var exists = await blobClient.ExistsAsync(cts.Token); - - if (!exists) + if (!_settings.AutoInitialize) { - remainingTries = 0; + var exists = await blobClient.ExistsAsync(cts.Token); - throw new Exception( - $"Container {_settings.ContainerName} doesn't exist. Either create it or turn auto-initialize on"); - } + if (!exists) + { + remainingTries = 0; + + throw new Exception( + $"Container {_settings.ContainerName} doesn't exist. Either create it or turn auto-initialize on"); + } - _log.Info("Successfully connected to existing container {0}", _settings.ContainerName); + _log.Info("Successfully connected to existing container {0}", _settings.ContainerName); - return blobClient; - } + return blobClient; + } - if (await blobClient.ExistsAsync(cts.Token)) - { - _log.Info("Successfully connected to existing container {0}", _settings.ContainerName); - } - else - { - try + if (await blobClient.ExistsAsync(cts.Token)) { - await blobClient.CreateAsync(_settings.ContainerPublicAccessType, - cancellationToken: cts.Token); - _log.Info("Created Azure Blob Container {0}", _settings.ContainerName); + _log.Info("Successfully connected to existing container {0}", _settings.ContainerName); } - catch (Exception e) + else { - throw new Exception($"Failed to create Azure Blob Container {_settings.ContainerName}", e); + try + { + await blobClient.CreateAsync(_settings.ContainerPublicAccessType, + cancellationToken: cts.Token); + _log.Info("Created Azure Blob Container {0}", _settings.ContainerName); + } + catch (Exception e) + { + throw new Exception($"Failed to create Azure Blob Container {_settings.ContainerName}", e); + } } - } - return blobClient; + return blobClient; + } } catch (Exception ex) { _log.Error(ex, "[{0}] more tries to initialize table storage remaining...", remainingTries); if (remainingTries == 0) throw; - await Task.Delay(RetryInterval[remainingTries]); - return await InitCloudStorage(remainingTries - 1); + await Task.Delay(RetryInterval[remainingTries], cancellationToken); + if (cancellationToken.IsCancellationRequested) + throw; + + return await InitCloudStorage(remainingTries - 1, cancellationToken); } } @@ -144,11 +155,20 @@ protected override void PreStart() base.PreStart(); } + protected override void PostStop() + { + _shutdownCts.Cancel(); + _shutdownCts.Dispose(); + base.PostStop(); + } - protected override async Task LoadAsync(string persistenceId, + protected override async Task LoadAsync( + string persistenceId, SnapshotSelectionCriteria criteria) { - using var cts = new CancellationTokenSource(_settings.RequestTimeout); + var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownCts.Token); + cts.CancelAfter(_settings.RequestTimeout); + using(cts) { var results = Container.GetBlobsAsync( prefix: SeqNoHelper.ToSnapshotSearchQuery(persistenceId), @@ -198,54 +218,66 @@ protected override async Task SaveAsync(SnapshotMetadata metadata, object snapsh var blobClient = Container.GetBlockBlobClient(metadata.ToSnapshotBlobId()); var snapshotData = _serialization.SnapshotToBytes(new Serialization.Snapshot(snapshot)); - using var cts = new CancellationTokenSource(_settings.RequestTimeout); - var blobMetadata = new Dictionary + var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownCts.Token); + cts.CancelAfter(_settings.RequestTimeout); + using (cts) { - [TimeStampMetaDataKey] = metadata.Timestamp.Ticks.ToString(), - /* - * N.B. No need to convert the key into the Journal format we use here. - * The blobs themselves don't have their sort order affected by - * the presence of this metadata, so we should just save the SeqNo - * in a format that can be easily deserialized later. - */ - [SeqNoMetaDataKey] = metadata.SequenceNr.ToString() - }; - - using var stream = new MemoryStream(snapshotData); - await blobClient.UploadAsync( - stream, - metadata: blobMetadata, - cancellationToken: cts.Token); + var blobMetadata = new Dictionary + { + [TimeStampMetaDataKey] = metadata.Timestamp.Ticks.ToString(), + /* + * N.B. No need to convert the key into the Journal format we use here. + * The blobs themselves don't have their sort order affected by + * the presence of this metadata, so we should just save the SeqNo + * in a format that can be easily deserialized later. + */ + [SeqNoMetaDataKey] = metadata.SequenceNr.ToString() + }; + + using var stream = new MemoryStream(snapshotData); + await blobClient.UploadAsync( + stream, + metadata: blobMetadata, + cancellationToken: cts.Token); + } } protected override async Task DeleteAsync(SnapshotMetadata metadata) { var blobClient = Container.GetBlobClient(metadata.ToSnapshotBlobId()); - using var cts = new CancellationTokenSource(_settings.RequestTimeout); - await blobClient.DeleteIfExistsAsync(cancellationToken: cts.Token); + var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownCts.Token); + cts.CancelAfter(_settings.RequestTimeout); + using (cts) + { + await blobClient.DeleteIfExistsAsync(cancellationToken: cts.Token); + } } protected override async Task DeleteAsync(string persistenceId, SnapshotSelectionCriteria criteria) { - using var cts = new CancellationTokenSource(_settings.RequestTimeout); - var items = Container.GetBlobsAsync( - prefix: SeqNoHelper.ToSnapshotSearchQuery(persistenceId), - traits: BlobTraits.Metadata, - cancellationToken: cts.Token); - - var filtered = items - .Where(x => FilterBlobSeqNo(criteria, x)) - .Where(x => FilterBlobTimestamp(criteria, x)); - - var deleteTasks = new List(); - await foreach (var blob in filtered.WithCancellation(cts.Token)) + var cts = CancellationTokenSource.CreateLinkedTokenSource(_shutdownCts.Token); + cts.CancelAfter(_settings.RequestTimeout); + using (cts) { - var blobClient = Container.GetBlobClient(blob.Name); - deleteTasks.Add(blobClient.DeleteIfExistsAsync(cancellationToken: cts.Token)); - } + var items = Container.GetBlobsAsync( + prefix: SeqNoHelper.ToSnapshotSearchQuery(persistenceId), + traits: BlobTraits.Metadata, + cancellationToken: cts.Token); + + var filtered = items + .Where(x => FilterBlobSeqNo(criteria, x)) + .Where(x => FilterBlobTimestamp(criteria, x)); - await Task.WhenAll(deleteTasks); + var deleteTasks = new List(); + await foreach (var blob in filtered.WithCancellation(cts.Token)) + { + var blobClient = Container.GetBlobClient(blob.Name); + deleteTasks.Add(blobClient.DeleteIfExistsAsync(cancellationToken: cts.Token)); + } + + await Task.WhenAll(deleteTasks); + } } private static bool FilterBlobSeqNo(SnapshotSelectionCriteria criteria, BlobItem x)