Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 57 additions & 37 deletions src/Akka.Persistence.Azure/Journal/AzureTableStorageJournal.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public class AzureTableStorageJournal : AsyncWriteJournal
private readonly TableServiceClient _tableServiceClient;
private TableClient _tableStorage_DoNotUseDirectly;
private readonly Dictionary<string, ISet<IActorRef>> _tagSubscribers = new Dictionary<string, ISet<IActorRef>>();
private readonly CancellationTokenSource _shutdownCts;

public AzureTableStorageJournal(Config config = null)
{
Expand Down Expand Up @@ -81,6 +82,8 @@ public AzureTableStorageJournal(Config config = null)
options: _settings.TableClientOptions)
: new TableServiceClient(connectionString: _settings.ConnectionString);
}

_shutdownCts = new CancellationTokenSource();
}

public TableClient Table
Expand All @@ -107,9 +110,9 @@ public override async Task<long> ReadHighestSequenceNrAsync(

_log.Debug("Entering method ReadHighestSequenceNrAsync");

var seqNo = await HighestSequenceNumberQuery(persistenceId)
var seqNo = await HighestSequenceNumberQuery(persistenceId, null, _shutdownCts.Token)
.Select(entity => entity.GetInt64(HighestSequenceNrEntry.HighestSequenceNrKey).Value)
.AggregateAsync(0L, Math.Max);
.AggregateAsync(0L, Math.Max, cancellationToken: _shutdownCts.Token);

_log.Debug("Leaving method ReadHighestSequenceNrAsync with SeqNo [{0}] for PersistentId [{1}]", seqNo, persistenceId);

Expand All @@ -132,7 +135,8 @@ public override async Task ReplayMessagesAsync(
if (max == 0)
return;

var pages = PersistentJournalEntryReplayQuery(persistenceId, fromSequenceNr, toSequenceNr).AsPages().GetAsyncEnumerator();
var pages = PersistentJournalEntryReplayQuery(persistenceId, fromSequenceNr, toSequenceNr, null, _shutdownCts.Token)
.AsPages().GetAsyncEnumerator(_shutdownCts.Token);

ValueTask<bool>? nextTask = pages.MoveNextAsync();
var count = 0L;
Expand Down Expand Up @@ -206,7 +210,8 @@ protected override async Task DeleteMessagesToAsync(string persistenceId, long t

_log.Debug("Entering method DeleteMessagesToAsync for persistentId [{0}] and up to seqNo [{1}]", persistenceId, toSequenceNr);

var pages = PersistentJournalEntryDeleteQuery(persistenceId, toSequenceNr).AsPages().GetAsyncEnumerator();
var pages = PersistentJournalEntryDeleteQuery(persistenceId, toSequenceNr, null, _shutdownCts.Token)
.AsPages().GetAsyncEnumerator(_shutdownCts.Token);

ValueTask<bool>? nextTask = pages.MoveNextAsync();
while (nextTask.HasValue)
Expand All @@ -227,7 +232,7 @@ protected override async Task DeleteMessagesToAsync(string persistenceId, long t
if (currentPage.Values.Count > 0)
{
await Table.SubmitTransactionAsync(currentPage.Values
.Select(entity => new TableTransactionAction(TableTransactionActionType.Delete, entity)));
.Select(entity => new TableTransactionAction(TableTransactionActionType.Delete, entity)), _shutdownCts.Token);
}
}

Expand All @@ -238,7 +243,7 @@ protected override void PreStart()
{
_log.Debug("Initializing Azure Table Storage...");

InitCloudStorage(5)
InitCloudStorage(5, _shutdownCts.Token)
.ConfigureAwait(false).GetAwaiter().GetResult();

_log.Debug("Successfully started Azure Table Storage!");
Expand All @@ -247,20 +252,27 @@ protected override void PreStart()
base.PreStart();
}

protected override void PostStop()
{
_shutdownCts.Cancel();
_shutdownCts.Dispose();
base.PostStop();
}

protected override bool ReceivePluginInternal(object message)
{
switch (message)
{
case ReplayTaggedMessages replay:
ReplayTaggedMessagesAsync(replay)
ReplayTaggedMessagesAsync(replay, _shutdownCts.Token)
.PipeTo(replay.ReplyTo, success: h => new RecoverySuccess(h), failure: e => new ReplayMessagesFailure(e));
break;
case SubscribePersistenceId subscribe:
AddPersistenceIdSubscriber(Sender, subscribe.PersistenceId);
Context.Watch(Sender);
break;
case SubscribeAllPersistenceIds subscribe:
AddAllPersistenceIdSubscriber(Sender);
case SubscribeAllPersistenceIds _:
AddAllPersistenceIdSubscriber(Sender, _shutdownCts.Token);
Context.Watch(Sender);
break;
case SubscribeTag subscribe:
Expand Down Expand Up @@ -353,7 +365,7 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
if (_log.IsDebugEnabled && _settings.VerboseLogging)
_log.Debug("Attempting to write batch of {0} messages to Azure storage", batchItems.Count);

var response = await Table.SubmitTransactionAsync(batchItems);
var response = await Table.SubmitTransactionAsync(batchItems, _shutdownCts.Token);
if (_log.IsDebugEnabled && _settings.VerboseLogging)
{
foreach (var r in response.Value)
Expand Down Expand Up @@ -383,7 +395,7 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
new AllPersistenceIdsEntry(PartitionKeyEscapeHelper.Escape(item.Key)).WriteEntity()));
}

var allPersistenceResponse = await Table.SubmitTransactionAsync(allPersistenceIdsBatch);
var allPersistenceResponse = await Table.SubmitTransactionAsync(allPersistenceIdsBatch, _shutdownCts.Token);

if (_log.IsDebugEnabled && _settings.VerboseLogging)
foreach (var r in allPersistenceResponse.Value)
Expand All @@ -405,7 +417,7 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
eventTagsBatch.Add(new TableTransactionAction(TableTransactionActionType.UpsertReplace, item.WriteEntity()));
}

var eventTagsResponse = await Table.SubmitTransactionAsync(eventTagsBatch);
var eventTagsResponse = await Table.SubmitTransactionAsync(eventTagsBatch, _shutdownCts.Token);

if (_log.IsDebugEnabled && _settings.VerboseLogging)
foreach (var r in eventTagsResponse.Value)
Expand Down Expand Up @@ -439,8 +451,8 @@ protected override async Task<IImmutableList<Exception>> WriteMessagesAsync(IEnu
}

private AsyncPageable<TableEntity> GenerateAllPersistenceIdsQuery(
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{AllPersistenceIdsEntry.PartitionKeyValue}'",
Expand All @@ -452,8 +464,8 @@ private AsyncPageable<TableEntity> GenerateAllPersistenceIdsQuery(

private AsyncPageable<TableEntity> HighestSequenceNumberQuery(
string persistenceId,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistenceId)}' and " +
Expand All @@ -467,8 +479,8 @@ private AsyncPageable<TableEntity> HighestSequenceNumberQuery(
private AsyncPageable<TableEntity> PersistentJournalEntryDeleteQuery(
string persistenceId,
long toSequenceNr,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistenceId)}' and " +
Expand All @@ -483,8 +495,8 @@ private AsyncPageable<TableEntity> EventTagEntryDeleteQuery(
string persistenceId,
long fromSequenceNr,
long toSequenceNr,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{EventTagEntry.PartitionKeyValue}' and " +
Expand All @@ -500,8 +512,8 @@ private AsyncPageable<TableEntity> PersistentJournalEntryReplayQuery(
string persistentId,
long fromSequenceNumber,
long toSequenceNumber,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
var filter = $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(persistentId)}' and " +
$"RowKey ne '{HighestSequenceNrEntry.RowKeyValue}'";
Expand All @@ -517,8 +529,8 @@ private AsyncPageable<TableEntity> PersistentJournalEntryReplayQuery(

private AsyncPageable<TableEntity> TaggedMessageQuery(
ReplayTaggedMessages replay,
int? maxPerPage = null,
CancellationToken cancellationToken = default)
int? maxPerPage,
CancellationToken cancellationToken)
{
return Table.QueryAsync<TableEntity>(
filter: $"PartitionKey eq '{PartitionKeyEscapeHelper.Escape(EventTagEntry.GetPartitionKey(replay.Tag))}' and " +
Expand All @@ -529,13 +541,13 @@ private AsyncPageable<TableEntity> TaggedMessageQuery(
cancellationToken: cancellationToken);
}

private async Task AddAllPersistenceIdSubscriber(IActorRef subscriber)
private async Task AddAllPersistenceIdSubscriber(IActorRef subscriber, CancellationToken cancellationToken)
{
lock (_allPersistenceIdSubscribers)
{
_allPersistenceIdSubscribers.Add(subscriber);
}
subscriber.Tell(new CurrentPersistenceIds(await GetAllPersistenceIds()));
subscriber.Tell(new CurrentPersistenceIds(await GetAllPersistenceIds(cancellationToken)));
}

private void AddPersistenceIdSubscriber(IActorRef subscriber, string persistenceId)
Expand All @@ -560,18 +572,21 @@ private void AddTagSubscriber(IActorRef subscriber, string tag)
subscriptions.Add(subscriber);
}

private async Task<IEnumerable<string>> GetAllPersistenceIds()
private async Task<IEnumerable<string>> GetAllPersistenceIds(CancellationToken cancellationToken)
{
return await GenerateAllPersistenceIdsQuery().Select(item => item.RowKey).ToListAsync();
return await GenerateAllPersistenceIdsQuery(null, cancellationToken)
.Select(item => item.RowKey).ToListAsync(cancellationToken);
}

private async Task InitCloudStorage(int remainingTries)
private async Task InitCloudStorage(int remainingTries, CancellationToken cancellationToken)
{
try
{
var tableClient = _tableServiceClient.GetTableClient(_settings.TableName);

using (var cts = new CancellationTokenSource(_settings.ConnectTimeout))
var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(_settings.ConnectTimeout);
using (cts)
{
if (!_settings.AutoInitialize)
{
Expand Down Expand Up @@ -603,15 +618,19 @@ private async Task InitCloudStorage(int remainingTries)
_log.Error(ex, "[{0}] more tries to initialize table storage remaining...", remainingTries);
if (remainingTries == 0)
throw;
await Task.Delay(RetryInterval[remainingTries]);
await InitCloudStorage(remainingTries - 1);

await Task.Delay(RetryInterval[remainingTries], cancellationToken);
if (cancellationToken.IsCancellationRequested)
throw;

await InitCloudStorage(remainingTries - 1, cancellationToken);
}
}

private async Task<bool> IsTableExist(string name, CancellationToken token)
private async Task<bool> IsTableExist(string name, CancellationToken cancellationToken)
{
var tables = await _tableServiceClient.QueryAsync(t => t.Name == name, cancellationToken: token)
.ToListAsync(token)
var tables = await _tableServiceClient.QueryAsync(t => t.Name == name, cancellationToken: cancellationToken)
.ToListAsync(cancellationToken)
.ConfigureAwait(false);
return tables.Count > 0;
}
Expand Down Expand Up @@ -657,15 +676,16 @@ private void RemoveSubscriber(
/// Replays all events with given tag within provided boundaries from current database.
/// </summary>
/// <param name="replay">TBD</param>
/// <param name="cancellationToken"></param>
/// <returns>TBD</returns>
private async Task<long> ReplayTaggedMessagesAsync(ReplayTaggedMessages replay)
private async Task<long> ReplayTaggedMessagesAsync(ReplayTaggedMessages replay, CancellationToken cancellationToken)
{
// In order to actually break at the limit we ask for we have to
// keep a separate counter and track it ourselves.
var counter = 0;
var maxOrderingId = 0L;

var pages = TaggedMessageQuery(replay).AsPages().GetAsyncEnumerator();
var pages = TaggedMessageQuery(replay, null, cancellationToken).AsPages().GetAsyncEnumerator(cancellationToken);
ValueTask<bool>? nextTask = pages.MoveNextAsync();

while (nextTask != null)
Expand Down
Loading