Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -125,23 +125,24 @@ internal bool ShouldHedge(RequestMessage request, CosmosClient client)
/// <param name="sender"></param>
/// <param name="client"></param>
/// <param name="request"></param>
/// <param name="cancellationToken"></param>
/// <param name="applicationProvidedCancellationToken"></param>
/// <returns>The response after executing cross region hedging</returns>
internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
Func<RequestMessage, CancellationToken, Task<ResponseMessage>> sender,
CosmosClient client,
RequestMessage request,
CancellationToken cancellationToken)
CancellationToken applicationProvidedCancellationToken)
{
if (!this.ShouldHedge(request, client)
|| client.DocumentClient.GlobalEndpointManager.ReadEndpoints.Count == 1)
{
return await sender(request, cancellationToken);
return await sender(request, applicationProvidedCancellationToken);
}

ITrace trace = request.Trace;

using (CancellationTokenSource cancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
using (CancellationTokenSource hedgeRequestsCancellationTokenSource =
CancellationTokenSource.CreateLinkedTokenSource(applicationProvidedCancellationToken))
{
using (CloneableStream clonedBody = (CloneableStream)(request.Content == null
? null
Expand All @@ -161,7 +162,7 @@ internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
{
TimeSpan awaitTime = requestNumber == 0 ? this.Threshold : this.ThresholdStep;

using (CancellationTokenSource timerTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken))
using (CancellationTokenSource timerTokenSource = CancellationTokenSource.CreateLinkedTokenSource(applicationProvidedCancellationToken))
{
CancellationToken timerToken = timerTokenSource.Token;
using (Task hedgeTimer = Task.Delay(awaitTime, timerToken))
Expand All @@ -173,32 +174,50 @@ internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
hedgeRegions: hedgeRegions,
requestNumber: requestNumber,
trace: trace,
cancellationToken: cancellationToken,
cancellationTokenSource: cancellationTokenSource);
hedgeRequestsCancellationTokenSource: hedgeRequestsCancellationTokenSource);

requestTasks.Add(requestTask);
requestTasks.Add(hedgeTimer);

Task completedTask = await Task.WhenAny(requestTasks);
requestTasks.Remove(completedTask);
Task completedTask;
do
{
completedTask = await Task.WhenAny(requestTasks);
requestTasks.Remove(completedTask);
}
while (
completedTask == hedgeTimer &&
// Ignore hedge timer signals if either the e2e timeout is hit
// or the hedgeTimer task failed (or more commonly since this is a linked CTS was cancelled)
// in both of these cases we do not want to spawn new hedge requests
// but just consolidate the outcome of previous requests
(completedTask.IsFaulted || completedTask.IsCanceled || applicationProvidedCancellationToken.IsCancellationRequested));

if (completedTask == hedgeTimer)
{
continue;
}

timerTokenSource.Cancel();
requestTasks.Remove(hedgeTimer);
timerTokenSource.Cancel();

if (completedTask.IsFaulted)
if (completedTask.IsFaulted || completedTask.IsCanceled)
{
AggregateException innerExceptions = completedTask.Exception.Flatten();
requestTasks.Remove(hedgeTimer);
timerTokenSource.Cancel();

if (applicationProvidedCancellationToken.IsCancellationRequested)
{
await (Task<HedgingResponse>)completedTask;
}

continue;
}

hedgeResponse = await (Task<HedgingResponse>)completedTask;
if (hedgeResponse.IsNonTransient)
{
cancellationTokenSource.Cancel();
hedgeRequestsCancellationTokenSource.Cancel();

((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeConfig,
Expand Down Expand Up @@ -227,12 +246,19 @@ internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
{
AggregateException innerExceptions = completedTask.Exception.Flatten();
lastException = innerExceptions.InnerExceptions.FirstOrDefault();
continue;
}

if (completedTask.IsCanceled)
{
lastException = new OperationCanceledException();
continue;
}

hedgeResponse = await (Task<HedgingResponse>)completedTask;
if (hedgeResponse.IsNonTransient || requestTasks.Count == 0)
{
cancellationTokenSource.Cancel();
hedgeRequestsCancellationTokenSource.Cancel();
((CosmosTraceDiagnostics)hedgeResponse.ResponseMessage.Diagnostics).Value.AddOrUpdateDatum(
HedgeConfig,
this.HedgeConfigText);
Expand All @@ -251,7 +277,16 @@ internal override async Task<ResponseMessage> ExecuteAvailabilityStrategyAsync(
throw lastException;
}

Debug.Assert(hedgeResponse != null);
if (hedgeResponse == null)
{
if (applicationProvidedCancellationToken.IsCancellationRequested)
{
throw new CosmosOperationCanceledException(new OperationCanceledException(), trace);
}

throw new InvalidOperationException("Cross-region hedging completed without producing a response.");
}

return hedgeResponse.ResponseMessage;
}
}
Expand All @@ -264,8 +299,7 @@ private async Task<HedgingResponse> CloneAndSendAsync(
IReadOnlyCollection<string> hedgeRegions,
int requestNumber,
ITrace trace,
CancellationToken cancellationToken,
CancellationTokenSource cancellationTokenSource)
CancellationTokenSource hedgeRequestsCancellationTokenSource)
{
RequestMessage clonedRequest;

Expand All @@ -287,8 +321,7 @@ private async Task<HedgingResponse> CloneAndSendAsync(
sender,
clonedRequest,
hedgeRegions.ElementAt(requestNumber),
cancellationToken,
cancellationTokenSource,
hedgeRequestsCancellationTokenSource,
trace);
}
}
Expand All @@ -297,27 +330,30 @@ private async Task<HedgingResponse> RequestSenderAndResultCheckAsync(
Func<RequestMessage, CancellationToken, Task<ResponseMessage>> sender,
RequestMessage request,
string targetRegionName,
CancellationToken cancellationToken,
CancellationTokenSource cancellationTokenSource,
CancellationTokenSource hedgeRequestsCancellationTokenSource,
ITrace trace)
{
try
{
ResponseMessage response = await sender.Invoke(request, cancellationToken);
ResponseMessage response = await sender.Invoke(request, hedgeRequestsCancellationTokenSource.Token);
if (IsFinalResult((int)response.StatusCode, (int)response.Headers.SubStatusCode))
{
if (!cancellationToken.IsCancellationRequested)
if (!hedgeRequestsCancellationTokenSource.IsCancellationRequested)
{
cancellationTokenSource.Cancel();
// App has not reached e2e timeout - we can cancel any still remaining
// hedge requests since we have a final response now
hedgeRequestsCancellationTokenSource.Cancel();
}

return new HedgingResponse(true, response, targetRegionName);
}

return new HedgingResponse(false, response, targetRegionName);
}
catch (OperationCanceledException oce) when (cancellationTokenSource.IsCancellationRequested)
catch (OperationCanceledException oce) when (hedgeRequestsCancellationTokenSource.IsCancellationRequested)
{
// hedgeRequestsCancellationTokenSource is a linked cancellation token source - so, would also signal
// cancellation on e2e timeout via app provided CT
throw new CosmosOperationCanceledException(oce, trace);
}
catch (Exception ex)
Expand Down
Loading